diff --git a/communication/communication.cpp b/communication/communication.cpp index 678a682..54a0ee6 100644 --- a/communication/communication.cpp +++ b/communication/communication.cpp @@ -20,8 +20,6 @@ using namespace icsneo; int Communication::messageCallbackIDCounter = 1; Communication::~Communication() { - if(redirectingRead) - clearRedirectRead(); if(isOpen()) close(); } @@ -44,6 +42,11 @@ void Communication::spawnThreads() { void Communication::joinThreads() { closing = true; + + if(pauseReadTask) { + resumeReads(); + } + if(readTaskThread.joinable()) readTaskThread.join(); closing = false; @@ -96,23 +99,6 @@ bool Communication::sendCommand(ExtendedCommand cmd, std::vector argume return sendCommand(Command::Extended, arguments); } -bool Communication::redirectRead(std::function&&)> redirectTo) { - if(redirectingRead) - return false; - redirectionFn = redirectTo; - redirectingRead = true; - return true; -} - -void Communication::clearRedirectRead() { - if(!redirectingRead) - return; - // The mutex is required to clear the redirection, but not to set it - std::lock_guard lk(redirectingReadMutex); - redirectingRead = false; - redirectionFn = std::function&&)>(); -} - bool Communication::getSettingsSync(std::vector& data, std::chrono::milliseconds timeout) { static const std::shared_ptr filter = std::make_shared(Network::NetID::ReadSettings); std::shared_ptr msg = waitForMessageSync([this]() { @@ -261,44 +247,55 @@ void Communication::dispatchMessage(const std::shared_ptr& msg) { EventManager::GetInstance().downgradeErrorsOnCurrentThread(); } -void Communication::readTask() { - std::vector readBytes; +void Communication::pauseReads() { + std::unique_lock lk(pauseReadTaskMutex); + pauseReadTask = true; +} +void Communication::resumeReads() { + std::unique_lock lk(pauseReadTaskMutex); + if(!pauseReadTask) { + return; + } + pauseReadTask = false; + lk.unlock(); + + pauseReadTaskCv.notify_one(); +} + +bool Communication::readsArePaused() { + std::unique_lock lk(pauseReadTaskMutex); + return pauseReadTask; +} + +void Communication::readTask() { EventManager::GetInstance().downgradeErrorsOnCurrentThread(); while(!closing) { - readBytes.clear(); + if(pauseReadTask) { + std::unique_lock lk(pauseReadTaskMutex); + pauseReadTaskCv.wait(lk, [this]() { return !pauseReadTask; }); + } if(driver->readAvailable()) { - handleInput(*packetizer, readBytes); + if(pauseReadTask) { + /** + * Reads could have paused while the driver was not available + */ + continue; + } + handleInput(*packetizer); } } } -void Communication::handleInput(Packetizer& p, std::vector& readBytes) { - if(redirectingRead) { - // redirectingRead is an atomic so it can be set without acquiring a mutex - // However, we do not clear it without the mutex. The idea is that if another - // thread calls clearRedirectRead(), it will block until the redirectionFn - // finishes, and after that the redirectionFn will not be called again. - std::unique_lock lk(redirectingReadMutex); - // So after we acquire the mutex, we need to check the atomic again, and - // if it has become cleared, we *can not* run the redirectionFn. - if(redirectingRead) { - redirectionFn(std::move(readBytes)); - } else { - // The redirectionFn got cleared while we were acquiring the lock - lk.unlock(); // We don't need the lock anymore - handleInput(p, readBytes); // and we might as well process this input ourselves - } - } else { - if(p.input(driver->getReadBuffer())) { - for(const auto& packet : p.output()) { - std::shared_ptr msg; - if(!decoder->decode(msg, packet)) - continue; +void Communication::handleInput(Packetizer& p) { + if(p.input(driver->getReadBuffer())) { + for(const auto& packet : p.output()) { + std::shared_ptr msg; + if(!decoder->decode(msg, packet)) + continue; - dispatchMessage(msg); - } + dispatchMessage(msg); } } } diff --git a/communication/driver.cpp b/communication/driver.cpp index c2fb0f1..26673ac 100644 --- a/communication/driver.cpp +++ b/communication/driver.cpp @@ -8,6 +8,26 @@ using namespace icsneo; +bool Driver::writeToReadBuffer(const uint8_t* buf, size_t numReceived) { + bool ret = readBuffer.write(buf, numReceived); + + if(hasRxWaitRequest) { + rxWaitRequestCv.notify_one(); + } + + return ret; +} + +bool Driver::waitForRx(size_t minBytes, std::chrono::milliseconds timeout) { + std::unique_lock lk(rxWaitMutex); + hasRxWaitRequest = true; + + auto ret = rxWaitRequestCv.wait_for(lk, timeout, [this, minBytes]{ return readBuffer.size() >= minBytes; }); + hasRxWaitRequest = false; + + return ret; +} + bool Driver::readWait(std::vector& bytes, std::chrono::milliseconds timeout, size_t limit) { // A limit of zero indicates no limit if(limit == 0) @@ -16,14 +36,13 @@ bool Driver::readWait(std::vector& bytes, std::chrono::milliseconds tim if(limit > (readBuffer.size() + 4)) limit = (readBuffer.size() + 4); - bytes.resize(limit); // wait until we have enough data, or the timout occurs - const auto timeoutTime = std::chrono::steady_clock::now() + timeout; - while (readBuffer.size() < limit && std::chrono::steady_clock::now() < timeoutTime) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } + waitForRx(limit, timeout); + size_t actuallyRead = std::min(readBuffer.size(), limit); + bytes.resize(actuallyRead); + readBuffer.read(bytes.data(), 0, actuallyRead); readBuffer.pop(actuallyRead); bytes.resize(actuallyRead); diff --git a/communication/multichannelcommunication.cpp b/communication/multichannelcommunication.cpp index 70a6f91..b1050fa 100644 --- a/communication/multichannelcommunication.cpp +++ b/communication/multichannelcommunication.cpp @@ -177,8 +177,11 @@ void MultiChannelCommunication::vnetReadTask(size_t vnetIndex) { if(queue.wait_dequeue_timed(payloadBytes, std::chrono::milliseconds(250))) { if(closing) break; - - handleInput(*vnetPacketizer, payloadBytes); + + auto& ringBuffer = driver->getReadBuffer(); + ringBuffer.write(payloadBytes); + + handleInput(*vnetPacketizer); } } } \ No newline at end of file diff --git a/include/icsneo/communication/communication.h b/include/icsneo/communication/communication.h index 064ecd9..9f4b9f9 100644 --- a/include/icsneo/communication/communication.h +++ b/include/icsneo/communication/communication.h @@ -48,8 +48,10 @@ public: void awaitModeChangeComplete() { driver->awaitModeChangeComplete(); } bool rawWrite(const std::vector& bytes) { return driver->write(bytes); } virtual bool sendPacket(std::vector& bytes); - bool redirectRead(std::function&&)> redirectTo); - void clearRedirectRead(); + + void pauseReads(); + void resumeReads(); + bool readsArePaused(); void setWriteBlocks(bool blocks) { driver->writeBlocks = blocks; } @@ -90,12 +92,14 @@ protected: std::mutex messageCallbacksLock; std::map> messageCallbacks; std::atomic closing{false}; - std::atomic redirectingRead{false}; - std::function&&)> redirectionFn; - std::mutex redirectingReadMutex; // Don't allow read to be disabled while in the redirectionFn + + std::condition_variable pauseReadTaskCv; + std::mutex pauseReadTaskMutex; + std::atomic pauseReadTask = false; + std::mutex syncMessageMutex; - void handleInput(Packetizer& p, std::vector& readBytes); + void handleInput(Packetizer& p); private: std::thread readTaskThread; diff --git a/include/icsneo/communication/driver.h b/include/icsneo/communication/driver.h index 0de0a56..45e6230 100644 --- a/include/icsneo/communication/driver.h +++ b/include/icsneo/communication/driver.h @@ -26,6 +26,8 @@ public: virtual void awaitModeChangeComplete() {} virtual bool isDisconnected() { return disconnected; }; virtual bool close() = 0; + + bool waitForRx(size_t minBytes, std::chrono::milliseconds timeout); bool readWait(std::vector& bytes, std::chrono::milliseconds timeout = std::chrono::milliseconds(100), size_t limit = 0); bool write(const std::vector& bytes); virtual bool isEthernet() const { return false; } @@ -57,8 +59,12 @@ protected: virtual bool writeQueueAlmostFull() { return writeQueue.size_approx() > (writeQueueSize * 3 / 4); } virtual bool writeInternal(const std::vector& b) { return writeQueue.enqueue(WriteOperation(b)); } - + bool writeToReadBuffer(const uint8_t* buf, size_t numReceived); RingBuffer readBuffer = RingBuffer(ICSNEO_DRIVER_RINGBUFFER_SIZE); + std::atomic hasRxWaitRequest = false; + std::condition_variable rxWaitRequestCv; + std::mutex rxWaitMutex; + moodycamel::BlockingConcurrentQueue writeQueue; std::thread readThread, writeThread; std::atomic closing{false}; diff --git a/platform/ftd3xx.cpp b/platform/ftd3xx.cpp index a52cea8..fd0ce7a 100644 --- a/platform/ftd3xx.cpp +++ b/platform/ftd3xx.cpp @@ -136,7 +136,7 @@ void FTD3XX::readTask() { } FT_ReleaseOverlapped(*handle, &overlap); if(received > 0) { - readBuffer.write(buffer, received); + writeToReadBuffer(buffer, received); } } } diff --git a/platform/posix/cdcacm.cpp b/platform/posix/cdcacm.cpp index d517b2d..df496e5 100644 --- a/platform/posix/cdcacm.cpp +++ b/platform/posix/cdcacm.cpp @@ -190,7 +190,7 @@ void CDCACM::readTask() { } std::cout << std::dec << std::endl; #endif - readBuffer.write(readbuf, bytesRead); + writeToReadBuffer(readbuf, bytesRead); } else { if(modeChanging) { // We were expecting a disconnect for reenumeration diff --git a/platform/posix/firmio.cpp b/platform/posix/firmio.cpp index 9a1a4a9..d6215f4 100644 --- a/platform/posix/firmio.cpp +++ b/platform/posix/firmio.cpp @@ -242,7 +242,7 @@ void FirmIO::readTask() { // Translate the physical address back to our virtual address space uint8_t* addr = reinterpret_cast(msg.payload.data.addr - PHY_ADDR_BASE + vbase); - while (!readBuffer.write(addr, msg.payload.data.len)) { + while (!writeToReadBuffer(addr, msg.payload.data.len)) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); // back-off so reading thread can empty the buffer if (closing || isDisconnected()) { break; diff --git a/platform/posix/ftdi.cpp b/platform/posix/ftdi.cpp index 4a00147..24a4041 100644 --- a/platform/posix/ftdi.cpp +++ b/platform/posix/ftdi.cpp @@ -213,7 +213,7 @@ void FTDI::readTask() { } else report(APIEvent::Type::FailedToRead, APIEvent::Severity::EventWarning); } else - readBuffer.write(readbuf, readBytes); + writeToReadBuffer(readbuf, readBytes); } } diff --git a/platform/posix/pcap.cpp b/platform/posix/pcap.cpp index e7fc27e..4135bcf 100644 --- a/platform/posix/pcap.cpp +++ b/platform/posix/pcap.cpp @@ -284,7 +284,7 @@ void PCAP::readTask() { PCAP* driver = reinterpret_cast(obj); if(driver->ethPacketizer.inputUp({data, data + header->caplen})) { const auto bytes = driver->ethPacketizer.outputUp(); - driver->readBuffer.write(bytes.data(), bytes.size()); + driver->writeToReadBuffer(bytes.data(), bytes.size()); } }, (uint8_t*)this); } diff --git a/platform/tcp.cpp b/platform/tcp.cpp index 97b443d..966e7f1 100644 --- a/platform/tcp.cpp +++ b/platform/tcp.cpp @@ -533,7 +533,7 @@ void TCP::readTask() { uint8_t readbuf[READ_BUFFER_SIZE]; while(!closing) { if(const auto received = ::recv(*socket, (char*)readbuf, READ_BUFFER_SIZE, 0); received > 0) { - readBuffer.write(readbuf, received); + writeToReadBuffer(readbuf, received); } else { timeout.tv_sec = 0; timeout.tv_usec = 50'000; diff --git a/platform/windows/pcap.cpp b/platform/windows/pcap.cpp index f14f5f9..e76ff20 100644 --- a/platform/windows/pcap.cpp +++ b/platform/windows/pcap.cpp @@ -278,7 +278,7 @@ void PCAP::readTask() { if(ethPacketizer.inputUp({data, data + header->caplen})) { const auto bytes = ethPacketizer.outputUp(); - readBuffer.write(bytes.data(), bytes.size()); + writeToReadBuffer(bytes.data(), bytes.size()); } } } diff --git a/platform/windows/vcp.cpp b/platform/windows/vcp.cpp index d7d81a8..8d185dd 100644 --- a/platform/windows/vcp.cpp +++ b/platform/windows/vcp.cpp @@ -390,7 +390,7 @@ void VCP::readTask() { if(ReadFile(detail->handle, readbuf, READ_BUFFER_SIZE, nullptr, &detail->overlappedRead)) { if(GetOverlappedResult(detail->handle, &detail->overlappedRead, &bytesRead, FALSE)) { if(bytesRead) - readBuffer.write(readbuf, bytesRead); + writeToReadBuffer(readbuf, bytesRead); } continue; } @@ -413,7 +413,7 @@ void VCP::readTask() { auto ret = WaitForSingleObject(detail->overlappedRead.hEvent, 100); if(ret == WAIT_OBJECT_0) { if(GetOverlappedResult(detail->handle, &detail->overlappedRead, &bytesRead, FALSE)) { - readBuffer.write(readbuf, bytesRead); + writeToReadBuffer(readbuf, bytesRead); state = LAUNCH; } else report(APIEvent::Type::FailedToRead, APIEvent::Severity::Error);