From 103f938d693696c03b16327a2d749db083b72136 Mon Sep 17 00:00:00 2001 From: Paul Hollinsky Date: Thu, 14 Apr 2022 18:22:13 -0400 Subject: [PATCH] Disk: ReadDriver: Add unified cache Previously, we had to copy an entire block out of the old cache every time we wanted to read even a single byte from it. This ended up being a fairly significant performance issue, in addition to the fact that the caching code was duplicated. --- disk/diskreaddriver.cpp | 51 +++- disk/diskwritedriver.cpp | 5 + disk/extextractordiskreaddriver.cpp | 289 +++++++++--------- disk/neomemorydiskdriver.cpp | 53 ++-- disk/plasiondiskreaddriver.cpp | 84 +++-- include/icsneo/disk/diskreaddriver.h | 12 + .../icsneo/disk/extextractordiskreaddriver.h | 5 - include/icsneo/disk/neomemorydiskdriver.h | 5 - include/icsneo/disk/plasiondiskreaddriver.h | 5 - test/diskdriverreadtest.cpp | 52 ++++ 10 files changed, 316 insertions(+), 245 deletions(-) diff --git a/disk/diskreaddriver.cpp b/disk/diskreaddriver.cpp index 29da933..52fbf6d 100644 --- a/disk/diskreaddriver.cpp +++ b/disk/diskreaddriver.cpp @@ -9,14 +9,25 @@ optional ReadDriver::readLogicalDisk(Communication& com, device_eventh if(amount == 0) return 0; - optional ret; + pos += vsaOffset; + + // First read from the cache + optional ret = readFromCache(pos, into, amount); + if(ret == amount) // Full cache hit, we're done + return ret; + + const uint64_t totalAmount = amount; + if(ret.has_value()) { // Partial cache hit + pos += *ret; + into += *ret; + amount -= *ret; + } // Read into here if we can't read directly into the user buffer // That would be the case either if we don't want some at the // beginning or end of the block. std::vector alignedReadBuffer; - pos += vsaOffset; const uint32_t idealBlockSize = getBlockSizeBounds().second; const uint64_t startBlock = pos / idealBlockSize; const uint32_t posWithinFirstBlock = static_cast(pos % idealBlockSize); @@ -36,7 +47,7 @@ optional ReadDriver::readLogicalDisk(Communication& com, device_eventh const uint32_t posWithinCurrentBlock = (blocksProcessed ? 0 : posWithinFirstBlock); uint32_t curAmt = idealBlockSize - posWithinCurrentBlock; - const auto amountLeft = amount - ret.value_or(0); + const auto amountLeft = totalAmount - ret.value_or(0); if(curAmt > amountLeft) curAmt = static_cast(amountLeft); @@ -65,7 +76,41 @@ optional ReadDriver::readLogicalDisk(Communication& com, device_eventh ret.emplace(); *ret += std::min(*readAmount, curAmt); blocksProcessed++; + + if(blocksProcessed == blocks) { + // Last block, add to the cache + if(useAlignedReadBuffer) { + cache = std::move(alignedReadBuffer); + } else { + if(cache.size() != idealBlockSize) + cache.resize(idealBlockSize); + memcpy(cache.data(), into + intoOffset, idealBlockSize); + } + cachePos = currentBlock * idealBlockSize; + cachedAt = std::chrono::steady_clock::now(); + } } return ret; +} + +void ReadDriver::invalidateCache(uint64_t pos, uint64_t amount) { + if(pos <= cachePos + cache.size() && pos + amount >= cachePos) + cache.clear(); +} + +optional ReadDriver::readFromCache(uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds staleAfter) { + if(cache.empty()) + return nullopt; // Nothing in the cache + + if(cachedAt + staleAfter < std::chrono::steady_clock::now()) + return nullopt; // Cache is stale + + if(pos > cachePos + cache.size() || pos < cachePos) + return nullopt; // Cache miss + + const auto cacheOffset = pos - cachePos; + const auto copyAmount = std::min(cache.size() - cacheOffset, amount); + memcpy(into, cache.data() + cacheOffset, static_cast(copyAmount)); + return copyAmount; } \ No newline at end of file diff --git a/disk/diskwritedriver.cpp b/disk/diskwritedriver.cpp index 3a0980b..3c89918 100644 --- a/disk/diskwritedriver.cpp +++ b/disk/diskwritedriver.cpp @@ -75,6 +75,7 @@ optional WriteDriver::writeLogicalDisk(Communication& com, device_even if(bytesTransferred == RetryAtomic) { // The user may want to log these events in order to see how many atomic misses they are getting report(APIEvent::Type::AtomicOperationRetried, APIEvent::Severity::EventInfo); + readDriver.invalidateCache(currentBlock * idealBlockSize, idealBlockSize); continue; } @@ -93,5 +94,9 @@ optional WriteDriver::writeLogicalDisk(Communication& com, device_even blocksProcessed++; } + // No matter how much succeeded, to be safe, we'll invalidate anything + // we may have even tried to write, since it may have succeeded without + // notifying, etc. + readDriver.invalidateCache(pos, amount); return ret; } \ No newline at end of file diff --git a/disk/extextractordiskreaddriver.cpp b/disk/extextractordiskreaddriver.cpp index 876629c..b29d115 100644 --- a/disk/extextractordiskreaddriver.cpp +++ b/disk/extextractordiskreaddriver.cpp @@ -37,177 +37,170 @@ optional ExtExtractorDiskReadDriver::attemptReadLogicalDiskAligned(Com uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds timeout) { static std::shared_ptr NeoMemorySDRead = std::make_shared(Network::NetID::NeoMemorySDRead); - if(cachePos != pos || std::chrono::steady_clock::now() > cachedAt + CacheTime) { - uint64_t sector = pos / SectorSize; + uint64_t sector = pos / SectorSize; - uint64_t largeSectorCount = amount / SectorSize; - uint32_t sectorCount = uint32_t(largeSectorCount); - if (largeSectorCount != uint64_t(sectorCount)) - return nullopt; + uint64_t largeSectorCount = amount / SectorSize; + uint32_t sectorCount = uint32_t(largeSectorCount); + if (largeSectorCount != uint64_t(sectorCount)) + return nullopt; - // The cache does not have this data, go get it - std::mutex m; - std::condition_variable cv; - uint16_t receiving = 0; // How much are we about to get before another header or completion - uint64_t received = 0; - uint16_t receivedCurrent = 0; - size_t skipping = 0; - std::vector header; - std::unique_lock lk(m); - bool error = !com.redirectRead([&](std::vector&& data) { - std::unique_lock lk2(m); - if(error) { - lk2.unlock(); - cv.notify_all(); - return; - } + std::mutex m; + std::condition_variable cv; + uint16_t receiving = 0; // How much are we about to get before another header or completion + uint64_t received = 0; + uint16_t receivedCurrent = 0; + size_t skipping = 0; + std::vector header; + std::unique_lock lk(m); + bool error = !com.redirectRead([&](std::vector&& data) { + std::unique_lock lk2(m); + if(error) { + lk2.unlock(); + cv.notify_all(); + return; + } - if(skipping > data.size()) { - skipping -= data.size(); - return; - } - size_t offset = skipping; - skipping = 0; - while(offset < data.size()) { - size_t left = data.size() - offset; + if(skipping > data.size()) { + skipping -= data.size(); + return; + } + size_t offset = skipping; + skipping = 0; + while(offset < data.size()) { + size_t left = data.size() - offset; + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Going to process " << left << " bytes" << std::endl; + #endif + if(header.size() != HeaderLength) { + if(header.empty() && left && data[offset] != 0xaa) { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Incorrect header " << int(data[offset]) << ' ' << int(offset) << std::endl; + #endif + error = true; + lk2.unlock(); + cv.notify_all(); + return; + } + + // Did we get a correct header and at least one byte of data? + const auto begin = data.begin() + offset; + int32_t headerLeft = int32_t(HeaderLength - header.size()); + if(int32_t(left) < headerLeft) { + // Not enough data here, grab what header we can and continue + header.insert(header.end(), begin, data.end()); + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got " << int(left) << " bytes of header at " << offset << " (incomplete " << + header.size() << ')' << std::endl; + #endif + return; + } + header.insert(header.end(), begin, begin + headerLeft); #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Going to process " << left << " bytes" << std::endl; + std::cout << "Got " << int(headerLeft) << " bytes of header at " << offset << " (complete " << + header.size() << ')' << std::endl; #endif - if(header.size() != HeaderLength) { - if(header.empty() && left && data[offset] != 0xaa) { + offset += headerLeft; + + if(header[1] == uint8_t(Network::NetID::RED)) { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got extended response " << int(offset) << std::endl; + #endif + // This is the extended command response, not all devices send this + // If we got it, we need to figure out how much more data to ignore + uint16_t length = (header[2] + (header[3] << 8)); + // Try for another header after this, regardless how much we choose + // to skip and how we skip it + header.clear(); + if(length <= 6) { #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Incorrect header " << int(data[offset]) << ' ' << int(offset) << std::endl; + std::cout << "Incorrect extended response length " << int(length) << ' ' << int(offset) << std::endl; #endif error = true; lk2.unlock(); cv.notify_all(); return; } - - // Did we get a correct header and at least one byte of data? - const auto begin = data.begin() + offset; - int32_t headerLeft = int32_t(HeaderLength - header.size()); - if(int32_t(left) < headerLeft) { - // Not enough data here, grab what header we can and continue - header.insert(header.end(), begin, data.end()); - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Got " << int(left) << " bytes of header at " << offset << " (incomplete " << - header.size() << ')' << std::endl; - #endif + length -= 7; + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Skipping " << int(length) << ' ' << int(left) << std::endl; + #endif + if(left < length) { + skipping = length - left; return; } - header.insert(header.end(), begin, begin + headerLeft); - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Got " << int(headerLeft) << " bytes of header at " << offset << " (complete " << - header.size() << ')' << std::endl; - #endif - offset += headerLeft; - - if(header[1] == uint8_t(Network::NetID::RED)) { - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Got extended response " << int(offset) << std::endl; - #endif - // This is the extended command response, not all devices send this - // If we got it, we need to figure out how much more data to ignore - uint16_t length = (header[2] + (header[3] << 8)); - // Try for another header after this, regardless how much we choose - // to skip and how we skip it - header.clear(); - if(length <= 6) { - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Incorrect extended response length " << int(length) << ' ' << int(offset) << std::endl; - #endif - error = true; - lk2.unlock(); - cv.notify_all(); - return; - } - length -= 7; - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Skipping " << int(length) << ' ' << int(left) << std::endl; - #endif - if(left < length) { - skipping = length - left; - return; - } - offset += length; - continue; - } - - // The device tells us how much it's sending us before the next header - receiving = (header[5] | (header[6] << 8)); - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Started packet of size " << receiving << " bytes" << std::endl; - #endif + offset += length; + continue; } - left = data.size() - offset; - auto count = uint16_t(std::min(std::min(receiving - receivedCurrent, left), amount - received)); + // The device tells us how much it's sending us before the next header + receiving = (header[5] | (header[6] << 8)); #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "With " << int(left) << " bytes " << int(offset) << std::endl; + std::cout << "Started packet of size " << receiving << " bytes" << std::endl; #endif - memcpy(cache.data() + received, data.data() + offset, count); - received += count; - receivedCurrent += count; - offset += count; - - if(amount == received) { - if(receivedCurrent % 2 == 0) - offset++; - header.clear(); // Now we will need another header - lk2.unlock(); - cv.notify_all(); - lk2.lock(); - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Finished!" << std::endl; - #endif - } - else if(receivedCurrent == receiving) { - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Got " << count << " bytes, " << receivedCurrent << " byte packet " << received << - " complete of " << amount << std::endl; - #endif - if(receivedCurrent % 2 == 0) - offset++; - header.clear(); // Now we will need another header - receivedCurrent = 0; - } else { - #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS - std::cout << "Got " << count << " bytes, incomplete (of " << receiving << " bytes)" << std::endl; - #endif - } } - }); - Lifetime clearRedirect([&com, &lk] { lk.unlock(); com.clearRedirectRead(); }); - if(error) - return nullopt; + left = data.size() - offset; + auto count = uint16_t(std::min(std::min(receiving - receivedCurrent, left), amount - received)); + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "With " << int(left) << " bytes " << int(offset) << std::endl; + #endif + memcpy(into + received, data.data() + offset, count); + received += count; + receivedCurrent += count; + offset += count; - error = !com.sendCommand(ExtendedCommand::Extract, { - uint8_t(sector & 0xff), - uint8_t((sector >> 8) & 0xff), - uint8_t((sector >> 16) & 0xff), - uint8_t((sector >> 24) & 0xff), - uint8_t((sector >> 32) & 0xff), - uint8_t((sector >> 40) & 0xff), - uint8_t((sector >> 48) & 0xff), - uint8_t((sector >> 56) & 0xff), - uint8_t(sectorCount & 0xff), - uint8_t((sectorCount >> 8) & 0xff), - uint8_t((sectorCount >> 16) & 0xff), - uint8_t((sectorCount >> 24) & 0xff), - }); - if(error) - return nullopt; + if(amount == received) { + if(receivedCurrent % 2 == 0) + offset++; + header.clear(); // Now we will need another header + lk2.unlock(); + cv.notify_all(); + lk2.lock(); + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Finished!" << std::endl; + #endif + } + else if(receivedCurrent == receiving) { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got " << count << " bytes, " << receivedCurrent << " byte packet " << received << + " complete of " << amount << std::endl; + #endif + if(receivedCurrent % 2 == 0) + offset++; + header.clear(); // Now we will need another header + receivedCurrent = 0; + } else { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got " << count << " bytes, incomplete (of " << receiving << " bytes)" << std::endl; + #endif + } + } + }); + Lifetime clearRedirect([&com, &lk] { lk.unlock(); com.clearRedirectRead(); }); - bool hitTimeout = !cv.wait_for(lk, timeout, [&]() { return error || amount == received; }); - if(hitTimeout || error) - return nullopt; + if(error) + return nullopt; - cachedAt = std::chrono::steady_clock::now(); - cachePos = pos; - } + error = !com.sendCommand(ExtendedCommand::Extract, { + uint8_t(sector & 0xff), + uint8_t((sector >> 8) & 0xff), + uint8_t((sector >> 16) & 0xff), + uint8_t((sector >> 24) & 0xff), + uint8_t((sector >> 32) & 0xff), + uint8_t((sector >> 40) & 0xff), + uint8_t((sector >> 48) & 0xff), + uint8_t((sector >> 56) & 0xff), + uint8_t(sectorCount & 0xff), + uint8_t((sectorCount >> 8) & 0xff), + uint8_t((sectorCount >> 16) & 0xff), + uint8_t((sectorCount >> 24) & 0xff), + }); + if(error) + return nullopt; + + bool hitTimeout = !cv.wait_for(lk, timeout, [&]() { return error || amount == received; }); + if(hitTimeout || error) + return nullopt; - memcpy(into, cache.data(), size_t(amount)); return amount; } diff --git a/disk/neomemorydiskdriver.cpp b/disk/neomemorydiskdriver.cpp index a376080..a43f86c 100644 --- a/disk/neomemorydiskdriver.cpp +++ b/disk/neomemorydiskdriver.cpp @@ -15,38 +15,31 @@ optional NeoMemoryDiskDriver::readLogicalDiskAligned(Communication& co if(amount != SectorSize) return nullopt; - if(cachePos != pos || std::chrono::steady_clock::now() > cachedAt + CacheTime) { - // The cache does not have this data, go get it - const uint64_t currentSector = pos / SectorSize; - auto msg = com.waitForMessageSync([¤tSector, &com] { - return com.sendCommand(Command::NeoReadMemory, { - MemoryTypeSD, - uint8_t(currentSector & 0xFF), - uint8_t((currentSector >> 8) & 0xFF), - uint8_t((currentSector >> 16) & 0xFF), - uint8_t((currentSector >> 24) & 0xFF), - uint8_t(SectorSize & 0xFF), - uint8_t((SectorSize >> 8) & 0xFF), - uint8_t((SectorSize >> 16) & 0xFF), - uint8_t((SectorSize >> 24) & 0xFF) - }); - }, NeoMemorySDRead, timeout); + const uint64_t currentSector = pos / SectorSize; + auto msg = com.waitForMessageSync([¤tSector, &com] { + return com.sendCommand(Command::NeoReadMemory, { + MemoryTypeSD, + uint8_t(currentSector & 0xFF), + uint8_t((currentSector >> 8) & 0xFF), + uint8_t((currentSector >> 16) & 0xFF), + uint8_t((currentSector >> 24) & 0xFF), + uint8_t(SectorSize & 0xFF), + uint8_t((SectorSize >> 8) & 0xFF), + uint8_t((SectorSize >> 16) & 0xFF), + uint8_t((SectorSize >> 24) & 0xFF) + }); + }, NeoMemorySDRead, timeout); - if(!msg) - return 0; + if(!msg) + return 0; - const auto sdmsg = std::dynamic_pointer_cast(msg); - if(!sdmsg || sdmsg->data.size() != SectorSize) { - report(APIEvent::Type::PacketDecodingError, APIEvent::Severity::Error); - return nullopt; - } - - memcpy(cache.data(), sdmsg->data.data(), SectorSize); - cachedAt = std::chrono::steady_clock::now(); - cachePos = pos; + const auto sdmsg = std::dynamic_pointer_cast(msg); + if(!sdmsg || sdmsg->data.size() != SectorSize) { + report(APIEvent::Type::PacketDecodingError, APIEvent::Severity::Error); + return nullopt; } - memcpy(into, cache.data(), SectorSize); + memcpy(into, sdmsg->data.data(), SectorSize); return SectorSize; } @@ -61,10 +54,6 @@ optional NeoMemoryDiskDriver::writeLogicalDiskAligned(Communication& c if(amount != SectorSize) return nullopt; - // Clear the cache if we're writing to the cached sector - if(pos == cachePos) - cachedAt = std::chrono::time_point(); - // Requesting an atomic operation, but neoMemory does not support it // Continue on anyway but warn the caller if(atomicBuf != nullptr) diff --git a/disk/plasiondiskreaddriver.cpp b/disk/plasiondiskreaddriver.cpp index 5a05e07..56748b8 100644 --- a/disk/plasiondiskreaddriver.cpp +++ b/disk/plasiondiskreaddriver.cpp @@ -19,60 +19,50 @@ optional PlasionDiskReadDriver::readLogicalDiskAligned(Communication& if(pos % getBlockSizeBounds().first != 0) return nullopt; - if(cachePos != pos || std::chrono::steady_clock::now() > cachedAt + CacheTime) { - uint64_t largeSector = pos / SectorSize; - uint32_t sector = uint32_t(largeSector); - if (largeSector != uint64_t(sector)) - return nullopt; + uint64_t largeSector = pos / SectorSize; + uint32_t sector = uint32_t(largeSector); + if (largeSector != uint64_t(sector)) + return nullopt; - // The cache does not have this data, go get it - std::mutex m; - std::condition_variable cv; - uint32_t copied = 0; - bool error = false; + std::mutex m; + std::condition_variable cv; + uint32_t copied = 0; + bool error = false; + std::unique_lock lk(m); + auto cb = com.addMessageCallback(MessageCallback([&](std::shared_ptr msg) { std::unique_lock lk(m); - auto cb = com.addMessageCallback(MessageCallback([&](std::shared_ptr msg) { - std::unique_lock lk(m); - const auto sdmsg = std::dynamic_pointer_cast(msg); - if(!sdmsg || cache.size() < copied + sdmsg->data.size()) { - error = true; - lk.unlock(); - cv.notify_all(); - return; - } + const auto sdmsg = std::dynamic_pointer_cast(msg); + if(!sdmsg || amount < copied + sdmsg->data.size()) { + error = true; + lk.unlock(); + cv.notify_all(); + return; + } - // Invalidate the cache here in case we fail half-way through - cachedAt = std::chrono::steady_clock::time_point(); + memcpy(into + copied, sdmsg->data.data(), sdmsg->data.size()); + copied += uint32_t(sdmsg->data.size()); + if(copied == amount) { + lk.unlock(); + cv.notify_all(); + } + }, NeoMemorySDRead)); - memcpy(cache.data() + copied, sdmsg->data.data(), sdmsg->data.size()); - copied += uint32_t(sdmsg->data.size()); - if(copied == amount) { - lk.unlock(); - cv.notify_all(); - } - }, NeoMemorySDRead)); + com.rawWrite({ + uint8_t(MultiChannelCommunication::CommandType::HostPC_from_SDCC1), + uint8_t(sector & 0xFF), + uint8_t((sector >> 8) & 0xFF), + uint8_t((sector >> 16) & 0xFF), + uint8_t((sector >> 24) & 0xFF), + uint8_t(amount & 0xFF), + uint8_t((amount >> 8) & 0xFF), + }); - com.rawWrite({ - uint8_t(MultiChannelCommunication::CommandType::HostPC_from_SDCC1), - uint8_t(sector & 0xFF), - uint8_t((sector >> 8) & 0xFF), - uint8_t((sector >> 16) & 0xFF), - uint8_t((sector >> 24) & 0xFF), - uint8_t(amount & 0xFF), - uint8_t((amount >> 8) & 0xFF), - }); + bool hitTimeout = !cv.wait_for(lk, timeout, [&copied, &error, &amount] { return error || copied == amount; }); + com.removeMessageCallback(cb); - bool hitTimeout = !cv.wait_for(lk, timeout, [&copied, &error, &amount] { return error || copied == amount; }); - com.removeMessageCallback(cb); + if(hitTimeout) + return nullopt; - if(hitTimeout) - return nullopt; - - cachedAt = std::chrono::steady_clock::now(); - cachePos = pos; - } - - memcpy(into, cache.data(), size_t(amount)); return amount; } \ No newline at end of file diff --git a/include/icsneo/disk/diskreaddriver.h b/include/icsneo/disk/diskreaddriver.h index 5b52a9b..d77aac6 100644 --- a/include/icsneo/disk/diskreaddriver.h +++ b/include/icsneo/disk/diskreaddriver.h @@ -22,6 +22,9 @@ public: virtual optional readLogicalDisk(Communication& com, device_eventhandler_t report, uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds timeout = DefaultTimeout); + void invalidateCache(uint64_t pos = 0, + uint64_t amount = std::numeric_limits::max() /* large value, but avoid overflow */); + protected: /** * Perform a read which the driver can do in one shot. @@ -31,6 +34,15 @@ protected: */ virtual optional readLogicalDiskAligned(Communication& com, device_eventhandler_t report, uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds timeout) = 0; + +private: + std::vector cache; + uint64_t cachePos = 0; + std::chrono::time_point cachedAt; + + static constexpr const std::chrono::milliseconds CacheTime = std::chrono::milliseconds(1000); + + optional readFromCache(uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds staleAfter = CacheTime); }; } // namespace Disk diff --git a/include/icsneo/disk/extextractordiskreaddriver.h b/include/icsneo/disk/extextractordiskreaddriver.h index 3809a21..46befb7 100644 --- a/include/icsneo/disk/extextractordiskreaddriver.h +++ b/include/icsneo/disk/extextractordiskreaddriver.h @@ -24,13 +24,8 @@ public: private: static constexpr const uint32_t MaxSize = Disk::SectorSize * 512; - static constexpr const std::chrono::seconds CacheTime = std::chrono::seconds(1); static constexpr const uint8_t HeaderLength = 7; - std::array cache; - uint64_t cachePos = 0; - std::chrono::time_point cachedAt; - Access getPossibleAccess() const override { return Access::EntireCard; } optional readLogicalDiskAligned(Communication& com, device_eventhandler_t report, diff --git a/include/icsneo/disk/neomemorydiskdriver.h b/include/icsneo/disk/neomemorydiskdriver.h index 086a17e..a6b2e66 100644 --- a/include/icsneo/disk/neomemorydiskdriver.h +++ b/include/icsneo/disk/neomemorydiskdriver.h @@ -27,11 +27,6 @@ public: private: static constexpr const uint8_t MemoryTypeSD = 0x01; // Logical Disk - static constexpr const std::chrono::seconds CacheTime = std::chrono::seconds(1); - - std::array cache; - uint64_t cachePos = 0; - std::chrono::time_point cachedAt; Access getPossibleAccess() const override { return Access::VSA; } diff --git a/include/icsneo/disk/plasiondiskreaddriver.h b/include/icsneo/disk/plasiondiskreaddriver.h index ee94265..00a16d8 100644 --- a/include/icsneo/disk/plasiondiskreaddriver.h +++ b/include/icsneo/disk/plasiondiskreaddriver.h @@ -24,11 +24,6 @@ public: private: static constexpr const uint32_t MaxSize = 65024; - static constexpr const std::chrono::seconds CacheTime = std::chrono::seconds(1); - - std::array cache; - uint64_t cachePos = 0; - std::chrono::time_point cachedAt; Access getPossibleAccess() const override { return Access::EntireCard; } diff --git a/test/diskdriverreadtest.cpp b/test/diskdriverreadtest.cpp index c6070b8..34ecfb1 100644 --- a/test/diskdriverreadtest.cpp +++ b/test/diskdriverreadtest.cpp @@ -61,4 +61,56 @@ TEST_F(DiskDriverTest, ReadBadStartingPos) { const auto amountRead = readLogicalDisk(2000, buf.data(), buf.size()); EXPECT_FALSE(amountRead.has_value()); EXPECT_EQ(driver->readCalls, 1u); // One to check EOF +} + +TEST_F(DiskDriverTest, ReadCache) { + std::array buf; + buf.fill(0u); + auto amountRead = readLogicalDisk(1, buf.data(), buf.size()); + EXPECT_TRUE(amountRead.has_value()); + EXPECT_EQ(amountRead, buf.size()); + EXPECT_EQ(buf[0], TEST_STRING[1]); + EXPECT_EQ(buf[110], 111u); + EXPECT_EQ(driver->readCalls, 1u); + + // Subsequent reads (within the same second) should hit the cache + amountRead = readLogicalDisk(1, buf.data(), buf.size()); + EXPECT_EQ(driver->readCalls, 1u); + + // The underlying data can be changed + driver->mockDisk[1] = 'J'; + + // But the same data should be returned from the cache + amountRead = readLogicalDisk(1, buf.data(), buf.size()); + EXPECT_TRUE(amountRead.has_value()); + EXPECT_EQ(amountRead, buf.size()); + EXPECT_EQ(buf[0], TEST_STRING[1]); + EXPECT_EQ(buf[110], 111u); + EXPECT_EQ(driver->readCalls, 1u); + + driver->invalidateCache(0, 0xfffff); + + // After invalidating the cache (or waiting for it to expire), the underlying data will be read + amountRead = readLogicalDisk(1, buf.data(), buf.size()); + EXPECT_TRUE(amountRead.has_value()); + EXPECT_EQ(amountRead, buf.size()); + EXPECT_EQ(buf[0], 'J'); + EXPECT_EQ(buf[110], 111u); + EXPECT_EQ(driver->readCalls, 2u); +} + +TEST_F(DiskDriverTest, ReadCacheLong) { + std::array buf; + buf.fill(0u); + auto amountRead = readLogicalDisk(300, buf.data(), buf.size()); + EXPECT_TRUE(amountRead.has_value()); + EXPECT_EQ(amountRead, buf.size()); + EXPECT_EQ(buf[0], 300 & 0xFF); + EXPECT_EQ(buf[110], 410 & 0xFF); + EXPECT_EQ(driver->readCalls, 3u); + + // Re-read the end, it will be in the cache + amountRead = readLogicalDisk(780, buf.data() + 480, buf.size() - 480); + EXPECT_EQ(buf[490], 790 & 0xFF); + EXPECT_EQ(driver->readCalls, 3u); } \ No newline at end of file