From f2f188991880581df68829c53c6c923715e4760c Mon Sep 17 00:00:00 2001 From: Martin Bauer Date: Thu, 25 Jun 2020 22:01:53 +0200 Subject: [PATCH] Websocket connection --- firmware/lib/session/SessionManager.h | 2 + .../lib/session/SimpleMeasurementSession.h | 7 +- firmware/platformio.ini | 5 +- firmware/src/SwimTrackerConfig.h | 8 +- firmware/src/WebsocketServer.h | 237 ++++++++++++++++++ firmware/src/firmware_main.cpp | 38 +++ firmware/websocket_test.py | 36 +++ 7 files changed, 325 insertions(+), 8 deletions(-) create mode 100644 firmware/src/WebsocketServer.h create mode 100644 firmware/websocket_test.py diff --git a/firmware/lib/session/SessionManager.h b/firmware/lib/session/SessionManager.h index ea974be..e9f812c 100644 --- a/firmware/lib/session/SessionManager.h +++ b/firmware/lib/session/SessionManager.h @@ -8,6 +8,8 @@ template class SessionManager { public: + using MeasurementType = typename SessionT::MeasurementType; + SessionManager(int scaleDoutPin, int scaleSckPin, uint8_t tareAvgCount); void begin(); diff --git a/firmware/lib/session/SimpleMeasurementSession.h b/firmware/lib/session/SimpleMeasurementSession.h index 87081e7..8d9fbd8 100644 --- a/firmware/lib/session/SimpleMeasurementSession.h +++ b/firmware/lib/session/SimpleMeasurementSession.h @@ -7,7 +7,7 @@ class SimpleMeasurementSession { public: using ChunkT = SessionChunk; - + using MeasurementType = Measurement_T; // save interval in number of measurements (by default every minute) SimpleMeasurementSession(uint32_t saveInterval = 10 * 60) : chunk(nullptr), saveInterval_(saveInterval) @@ -66,6 +66,11 @@ public: encoder.sendArray(chunk->getDataPointer() + startIdx, numElementsToSend); } + Measurement_T *getDataPointer() + { + return chunk->getDataPointer(); + } + private: void saveToFileSystem() { diff --git a/firmware/platformio.ini b/firmware/platformio.ini index 4542eaa..3eb5573 100644 --- a/firmware/platformio.ini +++ b/firmware/platformio.ini @@ -14,7 +14,7 @@ data_dir = data [env:esp32] platform = espressif32 platform_packages = - framework-arduinoespressif32 @ https://github.com/espressif/arduino-esp32.git + framework-arduinoespressif32 @ https://github.com/espressif/arduino-esp32 board = esp-wrover-kit #platform = espressif8266 #board = esp_wroom_02 @@ -25,10 +25,9 @@ monitor_port = /dev/ttyUSB0 upload_port = /dev/ttyUSB0 monitor_speed = 115200 lib_deps = - https://github.com/mabau/ESPAsyncWebServer.git - AsyncTCP NTPClient HX711@0.7.4 + https://github.com/gilmaimon/ArduinoWebsockets.git src_filter = +<*> - [env:native] diff --git a/firmware/src/SwimTrackerConfig.h b/firmware/src/SwimTrackerConfig.h index 859d337..3799757 100644 --- a/firmware/src/SwimTrackerConfig.h +++ b/firmware/src/SwimTrackerConfig.h @@ -4,13 +4,13 @@ // Uncomment for Version 2.0 where load cell is connected differently -#define _HW_V_20 -//#define NEW_HEAVY_LOAD_CELL +//#define _HW_V_20 +#define NEW_HEAVY_LOAD_CELL // ------------------------------------------ WiFi --------------------------------------------------------------------------------- -//const char *CONFIG_WIFI_SSID = "WLAN"; -const char *CONFIG_WIFI_SSID = "RepeaterWZ"; +const char *CONFIG_WIFI_SSID = "WLAN"; +//const char *CONFIG_WIFI_SSID = "RepeaterWZ"; const char *CONFIG_WIFI_PASSWORD = "Bau3rWLAN"; const char *CONFIG_HOSTNAME = "smartswim"; diff --git a/firmware/src/WebsocketServer.h b/firmware/src/WebsocketServer.h new file mode 100644 index 0000000..f7ab0da --- /dev/null +++ b/firmware/src/WebsocketServer.h @@ -0,0 +1,237 @@ + +#pragma once +#include "Dtypes.h" +#include + +template +class SessionManager; + +static constexpr int MAX_CONNECTIONS = 3; +static constexpr int NUM_DATA_CHUNK_SIZE = 1; + +template +class WebsocketServer +{ +public: + WebsocketServer(SessionManager &sessionManager, int port) + : sessionManager_(sessionManager), nextFreeClient_(0), port_(port), + numSentMeasurements_(0), running_(false) + { + } + + void begin() + { + server_.listen(port_); + } + + void iteration(); + +private: + void reportSessionUpdate(); + + void sendMessageOnConnection(websockets::WebsocketsClient &client); + void sendSessionStartMessages(); + void sendSessionStopMessages(); + void sendNewDataMessages(); + + SessionManager &sessionManager_; + int nextFreeClient_; + int port_; + + size_t sentMessageCount_; + websockets::WebsocketsServer server_; + websockets::WebsocketsClient clients_[MAX_CONNECTIONS]; + + // previous session state + size_t numSentMeasurements_; + bool running_; +}; + +using websockets::WebsocketsClient; + +// ------------------------------------- Message types & classes --------------------------- + +enum MessageType +{ + INITIAL_INFO = 1, + SESSION_STARTED = 2, + SESSION_STOPPED = 3, + SESSION_NEW_DATA = 4 +}; + +#pragma pack(push, 1) +class SessionStartedMessage +{ +public: + SessionStartedMessage(uint32_t id) : messageType_(SESSION_STARTED), sessionId_(id) {} + void send(WebsocketsClient &c) const + { + c.sendBinary((const char *)(this), sizeof(*this)); + } + +private: + uint8_t messageType_; + uint32_t sessionId_; +}; + +class SessionStoppedMessage +{ +public: + SessionStoppedMessage() : messageType_(SESSION_STOPPED) {} + void send(WebsocketsClient &c) const + { + c.sendBinary((const char *)(this), sizeof(*this)); + } + +private: + uint8_t messageType_; +}; + +template +class SessionNewDataMessage +{ +public: + // typically a message contains NUM_DATA_CHUNK_SIZE measurements + // if some measurements are skipped, because loop() takes too long + // there might actually be more measurements, to be safe there is an + // additional factor here + static constexpr size_t MAX_MEASUREMENTS = 4 * NUM_DATA_CHUNK_SIZE; + + SessionNewDataMessage(MeasurementT *ptr, size_t numMeasurements) + : messageType_(SESSION_NEW_DATA), numMeasurements_(min(numMeasurements, MAX_MEASUREMENTS)) + { + memcpy(measurements_, ptr, sizeof(MeasurementT) * numMeasurements_); + } + + void send(WebsocketsClient &c) const + { + c.sendBinary((const char *)(this), numBytes()); + } + + size_t numMeasurements() const + { + return numMeasurements_; + } + +private: + size_t numBytes() const { return sizeof(uint8_t) + numMeasurements() * sizeof(MeasurementT); } + + // data to be sent + uint8_t messageType_; + MeasurementT measurements_[MAX_MEASUREMENTS]; + + // book-keeping + size_t numMeasurements_; +}; +#pragma pack(pop) + +// ------------------------------------- WebsocketServer members --------------------------- + +template +void WebsocketServer::iteration() +{ + if (server_.poll()) + { + clients_[nextFreeClient_] = server_.accept(); + //clients_[nextFreeClient_].onMessage(onMessage); // TODO + Serial.println("new websocket connection"); + sendMessageOnConnection(clients_[nextFreeClient_]); + nextFreeClient_ = (nextFreeClient_ + 1) % MAX_CONNECTIONS; + } + + for (int i = 0; i < MAX_CONNECTIONS; ++i) + clients_[i].poll(); + + reportSessionUpdate(); +} + +template +void WebsocketServer::reportSessionUpdate() +{ + auto &session = sessionManager_.session(); + + // start/stop messages + if (!running_ && sessionManager_.isMeasuring()) + sendSessionStartMessages(); + else if (running_ && !sessionManager_.isMeasuring()) + sendSessionStopMessages(); + + // new data + if (session.numMeasurements() - (NUM_DATA_CHUNK_SIZE - 1) > numSentMeasurements_) + { + sendNewDataMessages(); + } +} + +template +void WebsocketServer::sendSessionStartMessages() +{ + SessionStartedMessage msg(sessionManager_.session().getStartTime()); + for (auto &c : clients_) + if (c.available()) + msg.send(c); + running_ = sessionManager_.isMeasuring(); +} + +template +void WebsocketServer::sendSessionStopMessages() +{ + SessionStoppedMessage msg; + for (auto &c : clients_) + if (c.available()) + msg.send(c); + running_ = sessionManager_.isMeasuring(); +} + +template +void WebsocketServer::sendNewDataMessages() +{ + using MeasurementT = typename SessionT::MeasurementType; + auto &session = sessionManager_.session(); + MeasurementT *dataToSend = session.getDataPointer() + numSentMeasurements_; + size_t numMeasurementsToSend = session.numMeasurements() - numSentMeasurements_; + SessionNewDataMessage msg(dataToSend, numMeasurementsToSend); + + for (auto &c : clients_) + if (c.available()) + msg.send(c); + + numSentMeasurements_ += msg.numMeasurements(); +} + +template +void WebsocketServer::sendMessageOnConnection(WebsocketsClient &client) +{ + using MeasurementT = typename SessionT::MeasurementType; + + // Message format: + // - uint8_t messageType + // - uint8_t running + // - uint32_t sessionId + // - MeasurementT [] measurements (if running) + + auto &session = sessionManager_.session(); + const auto numMeasurements = session.numMeasurements(); + const auto sessionId = session.getStartTime(); + + const size_t msgSize = sizeof(uint8_t) + sizeof(uint8_t) + sizeof(sessionId) + sizeof(MeasurementT) * numMeasurements; + char *msg = (char *)heap_caps_malloc(msgSize, MALLOC_CAP_SPIRAM); + + char *writeHead = msg; + + *writeHead = INITIAL_INFO; + writeHead += sizeof(uint8_t); + + *writeHead = sessionManager_.isMeasuring(); + writeHead += sizeof(uint8_t); + + *((uint32_t *)writeHead) = sessionManager_.isMeasuring() ? sessionId : 0; + writeHead += sizeof(uint32_t); + + assert(writeHead - msg == msgSize - sizeof(MeasurementT) * numMeasurements); + + memcpy(writeHead, session.getDataPointer(), sizeof(MeasurementT) * numMeasurements); + client.sendBinary(msg, msgSize); + + free(msg); +} diff --git a/firmware/src/firmware_main.cpp b/firmware/src/firmware_main.cpp index 17a1c75..db22030 100644 --- a/firmware/src/firmware_main.cpp +++ b/firmware/src/firmware_main.cpp @@ -4,6 +4,8 @@ #include "SwimTrackerConfig.h" #include +#include + // Own libs #include "MockScale.h" @@ -14,11 +16,13 @@ #include "SimpleMeasurementSession.h" #include "EspHttp.h" #include "WebDAV.h" +#include "WebsocketServer.h" using Session_T = SimpleMeasurementSession; SessionManager sessionManager(CONFIG_SCALE_DOUT_PIN, CONFIG_SCALE_SCK_PIN, CONFIG_TARE_AVG_COUNT); EspHttp espHttpServer; +WebsocketServer webSocketServer(sessionManager, 81); template void httpSetup(SessionManager *sessionManager) @@ -175,9 +179,43 @@ void setup() // HTTP & Websocket server httpSetup(&sessionManager); + + webSocketServer.begin(); } +int measurementsSent = 0; + void loop() { sessionManager.iteration(); + webSocketServer.iteration(); + /* + if (webSocketServer.poll()) + { + websocketClients[nextFreeWebsocketClient] = webSocketServer.accept(); + websocketClients[nextFreeWebsocketClient].onMessage(onMessage); + Serial.println("Websocket connection"); + nextFreeWebsocketClient = (nextFreeWebsocketClient + 1) % MAX_WEBSOCKET_CONNECTIONS; + } + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + //if (websocketClients[i].available()) { + //Serial.printf("Polling client %d\n", i); + websocketClients[i].poll(); + //} + + auto &session = sessionManager.session(); + if (session.numMeasurements() < measurementsSent) + measurementsSent = 0; + else if (session.numMeasurements() > measurementsSent) + { + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (websocketClients[i].available()) + { + auto dataToSend = (const char*)(session.getDataPointer() + measurementsSent); + auto numBytes = (session.numMeasurements() - measurementsSent) * sizeof(MeasurementT); + Serial.printf("Sent %d bytes via websocket\n", numBytes); + websocketClients[i].sendBinary(dataToSend, numBytes); + measurementsSent = session.numMeasurements(); + } + }*/ } diff --git a/firmware/websocket_test.py b/firmware/websocket_test.py new file mode 100644 index 0000000..d37fd32 --- /dev/null +++ b/firmware/websocket_test.py @@ -0,0 +1,36 @@ +import asyncio +import websockets +import struct +import numpy as np + +INITIAL_INFO = 1 +SESSION_STARTED = 2 +SESSION_STOPPED = 3 +SESSION_NEW_DATA = 4 + + +async def hello(): + uri = "ws://192.168.178.110:81" + async with websockets.connect(uri) as websocket: + while True: + res = await websocket.recv() + msg_type = struct.unpack("