diff --git a/CMakeLists.txt b/CMakeLists.txt index 32165d4..a62235b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,6 +154,7 @@ set(SRC_FILES disk/nulldiskdriver.cpp disk/neomemorydiskreaddriver.cpp disk/plasiondiskreaddriver.cpp + disk/extextractordiskreaddriver.cpp disk/fat.cpp ${PLATFORM_SRC} ) diff --git a/device/device.cpp b/device/device.cpp index 67758a2..22b8965 100644 --- a/device/device.cpp +++ b/device/device.cpp @@ -486,6 +486,9 @@ optional Device::readLogicalDisk(uint64_t pos, uint8_t* into, uint64_t return nullopt; } + // This is needed for certain read drivers which take over the communication stream + const auto lifetime = suppressDisconnects(); + return diskReadDriver->readLogicalDisk(*com, report, pos, into, amount, timeout); } diff --git a/disk/extextractordiskreaddriver.cpp b/disk/extextractordiskreaddriver.cpp new file mode 100644 index 0000000..9ccf86b --- /dev/null +++ b/disk/extextractordiskreaddriver.cpp @@ -0,0 +1,142 @@ +#include "icsneo/disk/extextractordiskreaddriver.h" +#include "icsneo/communication/message/neoreadmemorysdmessage.h" +#include "icsneo/communication/multichannelcommunication.h" +#include "icsneo/api/lifetime.h" +#include + +#ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS +#include +#endif + +using namespace icsneo; +using namespace icsneo::Disk; + +optional ExtExtractorDiskReadDriver::readLogicalDiskAligned(Communication& com, device_eventhandler_t report, + uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds timeout) { + static std::shared_ptr NeoMemorySDRead = std::make_shared(Network::NetID::NeoMemorySDRead); + + if(amount > getBlockSizeBounds().second) + return nullopt; + + if(amount % getBlockSizeBounds().first != 0) + return nullopt; + + if(pos % getBlockSizeBounds().first != 0) + return nullopt; + + if(cachePos != pos || std::chrono::steady_clock::now() > cachedAt + CacheTime) { + uint64_t sector = pos / SectorSize; + + uint64_t largeSectorCount = amount / SectorSize; + uint32_t sectorCount = uint32_t(largeSectorCount); + if (largeSectorCount != uint64_t(sectorCount)) + return nullopt; + + // The cache does not have this data, go get it + std::mutex m; + std::condition_variable cv; + uint16_t receiving = 0; // How much are we about to get before another header or completion + uint64_t received = 0; + uint16_t receivedCurrent = 0; + uint8_t gotHeaderBytes = 0; + std::unique_lock lk(m); + bool error = !com.redirectRead([&](std::vector&& data) { + std::unique_lock lk2(m); + size_t offset = 0; + while(offset < data.size()) { + size_t left = data.size() - offset; + if(gotHeaderBytes != headerLength) { + if(gotHeaderBytes == 0 && left && data[offset] != 0xaa) { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Incorrect header " << int(data[offset]) << ' ' << int(offset) << std::endl; + #endif + offset++; + continue; + } + + // Did we get a correct header and at least one byte of data? + if(int32_t(left) < (headerLength - gotHeaderBytes)) { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got " << int(left) << " bytes of header at " << offset << std::endl; + #endif + gotHeaderBytes += uint8_t(left); + return; + } + + // The device tells us how much it's sending us before the next header + receiving = (data[offset + headerLength-2-gotHeaderBytes] | (data[offset + headerLength-1-gotHeaderBytes] << 8)); + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Started packet of size " << receiving << " bytes" << std::endl; + #endif + + // Skip the header and any bytes necessary for unaligned read + offset += headerLength - gotHeaderBytes; + gotHeaderBytes = headerLength; + } + + const auto available = left - offset; + auto count = uint16_t(std::min(std::min(receiving - receivedCurrent, available), amount - received)); + memcpy(cache.data() + received, data.data() + offset, count); + received += count; + receivedCurrent += count; + offset += count; + + if(amount == received) { + if(receivedCurrent % 2 == 0) + offset++; + lk2.unlock(); + cv.notify_all(); + lk2.lock(); + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Finished!" << std::endl; + #endif + } + else if(receivedCurrent == receiving) { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got " << count << " bytes, " << receivedCurrent << " byte packet " << received << + " complete of " << amount << std::endl; + #endif + if(receivedCurrent % 2 == 0) + offset++; + gotHeaderBytes = 0; // Now we will need another header + receivedCurrent = 0; + } else { + #ifdef ICSNEO_EXTENDED_EXTRACTOR_DEBUG_PRINTS + std::cout << "Got " << count << " bytes, incomplete (of " << receiving << " bytes)" << std::endl; + #endif + } + } + }); + Lifetime clearRedirect([&com, &lk] { lk.unlock(); com.clearRedirectRead(); }); + + if(error) + return nullopt; + + error = !com.sendCommand(ExtendedCommand::Extract, { + uint8_t(sector & 0xff), + uint8_t((sector >> 8) & 0xff), + uint8_t((sector >> 16) & 0xff), + uint8_t((sector >> 24) & 0xff), + uint8_t((sector >> 32) & 0xff), + uint8_t((sector >> 40) & 0xff), + uint8_t((sector >> 48) & 0xff), + uint8_t((sector >> 56) & 0xff), + uint8_t(sectorCount & 0xff), + uint8_t((sectorCount >> 8) & 0xff), + uint8_t((sectorCount >> 16) & 0xff), + uint8_t((sectorCount >> 24) & 0xff), + }); + if(error) + return nullopt; + + bool hitTimeout = !cv.wait_for(lk, timeout, [&]() { return error || amount == received; }); + if(hitTimeout || error) + return nullopt; + + cachedAt = std::chrono::steady_clock::now(); + cachePos = pos; + } + + memcpy(into, cache.data(), size_t(amount)); + return amount; +} \ No newline at end of file diff --git a/include/icsneo/device/tree/neovired2/neovired2.h b/include/icsneo/device/tree/neovired2/neovired2.h index 6093e2c..22ab7ef 100644 --- a/include/icsneo/device/tree/neovired2/neovired2.h +++ b/include/icsneo/device/tree/neovired2/neovired2.h @@ -4,6 +4,7 @@ #include "icsneo/device/device.h" #include "icsneo/device/devicetype.h" #include "icsneo/platform/pcap.h" +#include "icsneo/disk/extextractordiskreaddriver.h" #include "icsneo/device/tree/neovired2/neovired2settings.h" namespace icsneo { @@ -67,7 +68,7 @@ public: protected: NeoVIRED2(neodevice_t neodevice) : Device(neodevice) { - initialize(); + initialize(); getWritableNeoDevice().type = DEVICE_TYPE; productId = PRODUCT_ID; } diff --git a/include/icsneo/disk/extextractordiskreaddriver.h b/include/icsneo/disk/extextractordiskreaddriver.h new file mode 100644 index 0000000..539afb2 --- /dev/null +++ b/include/icsneo/disk/extextractordiskreaddriver.h @@ -0,0 +1,48 @@ +#ifndef __EXTEXTRACTORDISKREADDRIVER_H__ +#define __EXTEXTRACTORDISKREADDRIVER_H__ + +#ifdef __cplusplus + +#include "icsneo/disk/diskreaddriver.h" +#include +#include + +namespace icsneo { + +namespace Disk { + +/** + * A disk read driver which uses the extended extractor command set to read from the disk + */ +class ExtExtractorDiskReadDriver : public ReadDriver { +public: + Access getAccess() const override { return Access::EntireCard; } + std::pair getBlockSizeBounds() const override { + static_assert(SectorSize <= std::numeric_limits::max(), "Incorrect sector size"); + static_assert(SectorSize >= std::numeric_limits::min(), "Incorrect sector size"); + return { static_cast(SectorSize), static_cast(MaxSize) }; + } + + uint8_t getHeaderLength() const { return headerLength; } + void setHeaderLength(uint8_t length) { headerLength = length; } + +private: + static constexpr const uint32_t MaxSize = Disk::SectorSize * 512; + static constexpr const std::chrono::seconds CacheTime = std::chrono::seconds(1); + + std::array cache; + uint64_t cachePos = 0; + std::chrono::time_point cachedAt; + + uint8_t headerLength = 7; // Correct for Ethernet + + optional readLogicalDiskAligned(Communication& com, device_eventhandler_t report, + uint64_t pos, uint8_t* into, uint64_t amount, std::chrono::milliseconds timeout) override; +}; + +} // namespace Disk + +} // namespace icsneo + +#endif // __cplusplus +#endif // __EXTEXTRACTORDISKREADDRIVER_H__ \ No newline at end of file