diff --git a/firmware/lib/wifimanager/WifiManager.cpp b/firmware/lib/wifimanager/WifiManager.cpp index 12214ab..7b4d7bf 100644 --- a/firmware/lib/wifimanager/WifiManager.cpp +++ b/firmware/lib/wifimanager/WifiManager.cpp @@ -78,12 +78,13 @@ void WifiManager::resetToApProvisioning() prefs_.remove("staPassword"); } -void WifiManager::wifiWatchdog() +void WifiManager::iteration() { if (state_ == STA && WiFi.status() != WL_CONNECTED) { startWifi(); Serial.println("Connection lost - Restarting WIFI"); } + } diff --git a/firmware/lib/wifimanager/WifiManager.h b/firmware/lib/wifimanager/WifiManager.h index 501e256..09f28d4 100644 --- a/firmware/lib/wifimanager/WifiManager.h +++ b/firmware/lib/wifimanager/WifiManager.h @@ -14,7 +14,7 @@ * * When operating in access point mode, the device IP is 192.168.42.1 * - * call wifiWatchdog regularly to reconnect in station mode if connection was lost + * call iteration() regularly, it has a wifiWatchdog to reconnect in station mode if connection was lost */ class WifiManager { @@ -34,7 +34,7 @@ public: void setApCredentials(const char *password); void resetToApProvisioning(); - void wifiWatchdog(); + void iteration(); bool inProvisioningMode() const { return state_ == INVALID || state_ == AP_PROVISIONING; } diff --git a/firmware/src/MessageCodes.h b/firmware/src/MessageCodes.h index 1cb4b9e..f8c4325 100644 --- a/firmware/src/MessageCodes.h +++ b/firmware/src/MessageCodes.h @@ -12,6 +12,8 @@ enum class MessageCode : uint8_t SESSION_NEW_DATA = 5, ANSWER_USER_LIST = 6, ANSWER_SESSION_LIST = 7, + WIFI_STATE_RESPONSE = 8, + WIFI_SCAN_RESPONSE = 9, // from frontend to device START_SESSION = 128, @@ -21,4 +23,5 @@ enum class MessageCode : uint8_t QUERY_SESSION_LIST = 132, WIFI_STATE_SET = 133, WIFI_STATE_GET = 134, + WIFI_TRIGGER_SCAN = 135, }; diff --git a/firmware/src/SwimTrackerConfig.h b/firmware/src/SwimTrackerConfig.h index 4265a76..f8e57b1 100644 --- a/firmware/src/SwimTrackerConfig.h +++ b/firmware/src/SwimTrackerConfig.h @@ -2,13 +2,11 @@ #include - // Uncomment for Version 2.0 where load cell is connected differently -//#define _HW_V_20 +#define _HW_V_20 #define NEW_HEAVY_LOAD_CELL - -const char *CONFIG_HOSTNAME = "swimtracker"; +constexpr const char *CONFIG_HOSTNAME = "swimtracker"; // ------------------------------------- Hardware & Measurement Settings ------------------------------------------------------------ @@ -16,32 +14,31 @@ const uint8_t CONFIG_MEASUREMENT_AVG_COUNT = 1; // number of measurements in const uint8_t CONFIG_TARE_AVG_COUNT = 20; // number of measurements in tare-phase (to find 0 ) const int CONFIG_MEASURE_DELAY = 100; // interval in ms between measurements const uint32_t CONFIG_SESSION_MAX_LENGTH_HOURS = 3; // maximum length of one session -const char *CONFIG_DATA_PATH = "/dat"; // folder in SPIFFS file system to store measurement data +constexpr const char *CONFIG_DATA_PATH = "/dat"; // folder in SPIFFS file system to store measurement data using MeasurementT = uint16_t; // data type for one measurement #ifdef NEW_HEAVY_LOAD_CELL -const int CONFIG_VALUE_RIGHT_SHIFT = 3; // uint32 measurements are divided by this power, before stored in uint16_t +const int CONFIG_VALUE_RIGHT_SHIFT = 3; // uint32 measurements are divided by this power, before stored in uint16_t #else const int CONFIG_VALUE_RIGHT_SHIFT = 7; #endif -const MeasurementT CONFIG_KG_FACTOR_INV = 701; // after shifting - how many "measurement units" are one kg +const MeasurementT CONFIG_KG_FACTOR_INV = 701; // after shifting - how many "measurement units" are one kg +static constexpr int MAX_WEBSOCKET_CONNECTIONS = 3; // maximal number of websocket connections maintained at the same time -const char * UPDATE_URL = "https://swimtracker-update.bauer.tech/firmware.bin"; - +constexpr const char *UPDATE_URL = "https://swimtracker-update.bauer.tech/firmware.bin"; // auto start/stop -MeasurementT CONFIG_AUTO_START_MIN_THRESHOLD = CONFIG_KG_FACTOR_INV * 1; -MeasurementT CONFIG_AUTO_START_MAX_THRESHOLD = CONFIG_KG_FACTOR_INV * 3; -uint32_t CONFIG_AUTO_START_MAX_MEASUREMENTS_BETWEEN_PEAKS = (1000 / CONFIG_MEASURE_DELAY) * 6; +constexpr MeasurementT CONFIG_AUTO_START_MIN_THRESHOLD = CONFIG_KG_FACTOR_INV * 1; +constexpr MeasurementT CONFIG_AUTO_START_MAX_THRESHOLD = CONFIG_KG_FACTOR_INV * 3; +constexpr uint32_t CONFIG_AUTO_START_MAX_MEASUREMENTS_BETWEEN_PEAKS = (1000 / CONFIG_MEASURE_DELAY) * 6; -MeasurementT CONFIG_AUTO_STOP_THRESHOLD = CONFIG_KG_FACTOR_INV * 1; +constexpr MeasurementT CONFIG_AUTO_STOP_THRESHOLD = CONFIG_KG_FACTOR_INV * 1; //uint32_t CONFIG_AUTO_STOP_NUM_MEASUREMENTS = (1000 / CONFIG_MEASURE_DELAY) * 60 * 15; -uint32_t CONFIG_AUTO_STOP_NUM_MEASUREMENTS = (1000 / CONFIG_MEASURE_DELAY) * 30; - +constexpr uint32_t CONFIG_AUTO_STOP_NUM_MEASUREMENTS = (1000 / CONFIG_MEASURE_DELAY) * 30; // ------------------------------------- Derived Settings ----------------------------------------------------------------------------- -const uint32_t CONFIG_SESSION_MAX_SIZE = CONFIG_SESSION_MAX_LENGTH_HOURS * 3600 * (1000 / CONFIG_MEASURE_DELAY) * sizeof(uint16_t); +constexpr uint32_t CONFIG_SESSION_MAX_SIZE = CONFIG_SESSION_MAX_LENGTH_HOURS * 3600 * (1000 / CONFIG_MEASURE_DELAY) * sizeof(uint16_t); static_assert(CONFIG_SESSION_MAX_SIZE < 1024 * 1024, "Measurement data takes more than 1MiB space"); #ifdef _HW_V_20 diff --git a/firmware/src/WebsocketServer.h b/firmware/src/WebsocketServer.h index ee49939..d1ba245 100644 --- a/firmware/src/WebsocketServer.h +++ b/firmware/src/WebsocketServer.h @@ -2,21 +2,23 @@ #pragma once #include "Dtypes.h" #include "UserDB.h" +#include "MessageCodes.h" + #include +#include + template class SessionManager; -static constexpr int MAX_CONNECTIONS = 3; static constexpr int NUM_DATA_CHUNK_SIZE = 1; -template +template class WebsocketServer { public: - WebsocketServer(SessionManager &sessionManager, UserStorage &userStorage, int port) - : sessionManager_(sessionManager), userStorage_(userStorage), nextFreeClient_(0), port_(port), - running_(false) + WebsocketServer(int port, ApiManagerTuple &tuple) + : port_(port), nextFreeClient_(0), apiManagers_(tuple) { } @@ -25,328 +27,145 @@ public: server_.listen(port_); } - void iteration(); - -private: - void reportSessionUpdate(); - - void sendMessageOnConnection(websockets::WebsocketsClient &client); - void sendSessionStartMessages(); - void sendSessionStopMessages(); - void sendNewDataMessages(); - - void sendUserList(websockets::WebsocketsClient &client); - void sendSessionList(websockets::WebsocketsClient &client, const String &userId); - - SessionManager &sessionManager_; - UserStorage &userStorage_; - - int nextFreeClient_; - int port_; - - size_t sentMessageCount_; - websockets::WebsocketsServer server_; - websockets::WebsocketsClient clients_[MAX_CONNECTIONS]; - - // previous session state - size_t numSentMeasurements_[MAX_CONNECTIONS]; - bool running_; -}; - -using websockets::WebsocketsClient; - -// ------------------------------------- Message types & classes --------------------------- - -enum MessageType -{ - // from swim tracker device to frontend - INITIAL_INFO = 1, - SESSION_STARTED = 2, - SESSION_STOPPED = 3, - SESSION_NEW_DATA = 4, - ANSWER_USER_LIST = 5, - ANSWER_SESSION_LIST = 6, - - // from frontend to device - START_SESSION = 7, - STOP_SESSION = 8, - TARE = 9, - QUERY_USER_LIST = 10, - QUERY_SESSION_LIST = 11, -}; - -#pragma pack(push, 1) -class SessionStartedMessage -{ -public: - SessionStartedMessage(uint32_t id) : messageType_(SESSION_STARTED), sessionId_(id) {} - void send(WebsocketsClient &c) const + void iteration() { - c.sendBinary((const char *)(this), sizeof(*this)); - } - -private: - uint8_t messageType_; - uint32_t sessionId_; -}; - -class SessionStoppedMessage -{ -public: - SessionStoppedMessage() : messageType_(SESSION_STOPPED) {} - void send(WebsocketsClient &c) const - { - c.sendBinary((const char *)(this), sizeof(*this)); - } - -private: - uint8_t messageType_; -}; - -template -class SessionNewDataMessage -{ -public: - // typically a message contains NUM_DATA_CHUNK_SIZE measurements - // if some measurements are skipped, because loop() takes too long - // there might actually be more measurements, to be safe there is an - // additional factor here - static constexpr size_t MAX_MEASUREMENTS = 4 * NUM_DATA_CHUNK_SIZE; - - SessionNewDataMessage(MeasurementT *ptr, size_t numMeasurements) - : messageType_(SESSION_NEW_DATA), numMeasurements_(min(numMeasurements, MAX_MEASUREMENTS)) - { - memcpy(measurements_, ptr, sizeof(MeasurementT) * numMeasurements_); - } - - void send(WebsocketsClient &c) const - { - c.sendBinary((const char *)(this), numBytes()); - } - - size_t numMeasurements() const - { - return numMeasurements_; - } - -private: - size_t numBytes() const { return sizeof(uint8_t) + numMeasurements() * sizeof(MeasurementT); } - - // data to be sent - uint8_t messageType_; - MeasurementT measurements_[MAX_MEASUREMENTS]; - - // book-keeping - size_t numMeasurements_; -}; - -#pragma pack(pop) - -// ------------------------------------- WebsocketServer members --------------------------- - -template -void WebsocketServer::iteration() -{ - using namespace websockets; - - auto onMessage = [this](WebsocketsClient &client, WebsocketsMessage message) { - if (message.isPing()) - client.pong(); - else if (message.isBinary()) + using namespace websockets; + const auto onMessage = [this](WebsocketsClient &client, WebsocketsMessage message) { - const char *data = message.c_str(); - const size_t length = message.length(); - if (length < 1) + if (message.isPing()) + client.pong(); + else if (message.isBinary()) { - client.close(CloseReason_UnsupportedData); - return; - } + const char *data = message.c_str(); + const size_t length = message.length(); - uint8_t opCode = uint8_t(data[0]); - switch (opCode) - { - case START_SESSION: - this->sessionManager_.startMeasurements(); - break; - case STOP_SESSION: - this->sessionManager_.stopMeasurements(); - break; - case TARE: - this->sessionManager_.tare(); - break; - case QUERY_USER_LIST: - this->sendUserList(client); - break; - case QUERY_SESSION_LIST: - { - StaticJsonDocument doc; - deserializeMsgPack(doc, data, length); - String userId = doc.as(); - if (userId.length() > 0) - this->sendSessionList(client, userId); + MessageCode msgCode = MessageCode((uint8_t)(data[0])); + this->handlMessageImpl(client, msgCode, data + 1, length - 1); } - break; - default: + else client.close(CloseReason_UnsupportedData); - return; - } + }; + + if (server_.poll()) + { + Serial.println("new websocket connection"); + clients_[nextFreeClient_] = server_.accept(); + clients_[nextFreeClient_].onMessage(onMessage); + this->onClientConnectImpl(clients_[nextFreeClient_]); + nextFreeClient_ = (nextFreeClient_ + 1) % MAX_WEBSOCKET_CONNECTIONS; } - }; - if (server_.poll()) - { - clients_[nextFreeClient_] = server_.accept(); - clients_[nextFreeClient_].onMessage(onMessage); - Serial.println("new websocket connection"); - sendMessageOnConnection(clients_[nextFreeClient_]); - numSentMeasurements_[nextFreeClient_] = sessionManager_.session().numMeasurements(); - nextFreeClient_ = (nextFreeClient_ + 1) % MAX_CONNECTIONS; + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + clients_[i].poll(); + + this->iterationImpl<>(); } - for (int i = 0; i < MAX_CONNECTIONS; ++i) - clients_[i].poll(); - - reportSessionUpdate(); -} - -template -void WebsocketServer::reportSessionUpdate() -{ - if (!running_ && sessionManager_.isMeasuring()) + template + void sendToAll(MessageCode msgCode, const JsonDocument &content) { - sendSessionStartMessages(); - for (int i = 0; i < MAX_CONNECTIONS; ++i) - numSentMeasurements_[i] = 0; + char buffer[bufferSize]; + buffer[0] = (char)(msgCode); + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), bufferSize - sizeof(msgCode)); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (clients_[i].available()) + clients_[i].sendBinary(buffer, bytesWritten); } - else if (running_ && !sessionManager_.isMeasuring()) + + void sendToAll(MessageCode msgCode, const JsonDocument &content) { - sendSessionStopMessages(); - for (int i = 0; i < MAX_CONNECTIONS; ++i) - numSentMeasurements_[i] = 0; - } - sendNewDataMessages(); -} + size_t expectedSize = measureMsgPack(content); + char *buffer = (char *)malloc(expectedSize + sizeof(msgCode)); + buffer[0] = (char)(msgCode); + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), expectedSize); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (clients_[i].available()) + clients_[i].sendBinary(buffer, bytesWritten + sizeof(msgCode)); -template -void WebsocketServer::sendUserList(websockets::WebsocketsClient &client) -{ - const auto numUsers = userStorage_.numUsers(); - constexpr size_t constantSlack = 64; - DynamicJsonDocument result(JSON_ARRAY_SIZE(numUsers) + numUsers * (USER_STRING_ID_MAX_LEN + 2) + constantSlack); - JsonArray arr = result.to(); - for (auto userIt = userStorage_.beginWithoutUnassigned(); userIt != userStorage_.end(); ++userIt) - arr.add(userIt->stringId()); - - char buffer[MAX_USERS * (USER_STRING_ID_MAX_LEN + 1) + constantSlack]; - size_t bytesWritten = serializeMsgPack(result, buffer, sizeof(buffer)); - client.sendBinary(buffer, bytesWritten); -} - -template -void WebsocketServer::sendSessionList(websockets::WebsocketsClient &client, const String &userId) -{ - - User *user = userStorage_.getUserInfo(userId); - if (user != nullptr) - { - DynamicJsonDocument result(JSON_ARRAY_SIZE(user->numSessions()) + user->numSessions() * (sizeof(SessionIdType) + 8)); - JsonArray arr = result.to(); - for (SessionIdType *sIt = user->sessionBegin(); sIt != user->sessionEnd(); ++sIt) - arr.add(*sIt); - - size_t bytesToWrite = measureMsgPack(result); - char *buffer = (char *)malloc(bytesToWrite); - size_t bytesWritten = serializeMsgPack(result, buffer, bytesToWrite); - assert(bytesWritten <= bytesToWrite); - client.sendBinary(buffer, bytesWritten); free(buffer); } - else + + void sendToAll(MessageCode msgCode) { - DynamicJsonDocument result(JSON_ARRAY_SIZE(1) + 8); - result.to(); - char buffer[32]; - size_t bytesWritten = serializeMsgPack(result, buffer, sizeof(buffer)); - client.sendBinary(buffer, bytesWritten); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (clients_[i].available()) + clients_[i].sendBinary((const char*)&msgCode, sizeof(MessageCode)); } -} -template -void WebsocketServer::sendSessionStartMessages() -{ - SessionStartedMessage msg(sessionManager_.session().getStartTime()); - for (auto &c : clients_) - if (c.available()) - msg.send(c); - running_ = sessionManager_.isMeasuring(); -} + websockets::WebsocketsClient &client(size_t i) { return clients_[i]; } -template -void WebsocketServer::sendSessionStopMessages() -{ - SessionStoppedMessage msg; - for (auto &c : clients_) - if (c.available()) - msg.send(c); - running_ = sessionManager_.isMeasuring(); -} - -template -void WebsocketServer::sendNewDataMessages() -{ - using MeasurementT = typename SessionT::MeasurementType; - auto &session = sessionManager_.session(); - - for (int i = 0; i < MAX_CONNECTIONS; ++i) +private: + // -- Tuple calls + template ::value - 1, typename std::enable_if::type = true> + void iterationImpl() { - auto &c = clients_[i]; - if (c.available()) - { - MeasurementT *dataToSend = session.getDataPointer() + numSentMeasurements_[i]; - int32_t numMeasurementsToSend = int32_t(session.numMeasurements()) - int32_t(numSentMeasurements_[i]); - if (numMeasurementsToSend > 0) - { - SessionNewDataMessage msg(dataToSend, numMeasurementsToSend); - msg.send(c); - numSentMeasurements_[i] += msg.numMeasurements(); - } - } + std::get(apiManagers_).iteration(*this); + iterationImpl(); } -} + template ::type = true> + void iterationImpl() {} -template -void WebsocketServer::sendMessageOnConnection(WebsocketsClient &client) + template ::value - 1, typename std::enable_if::type = true> + bool handlMessageImpl(websockets::WebsocketsClient &client, MessageCode code, const char *payload, size_t size) + { + bool handled = std::get(apiManagers_).handleMessage(client, code, payload, size); + if (handled) + return true; + else + return handlMessageImpl(client, code, payload, size); + } + template ::type = true> + bool handlMessageImpl(websockets::WebsocketsClient &, MessageCode, const char *, size_t) { return false; } + + template ::value - 1, typename std::enable_if::type = true> + void onClientConnectImpl(websockets::WebsocketsClient &client) + { + std::get(apiManagers_).onClientConnect(client); + onClientConnectImpl(client); + } + template ::type = true> + void onClientConnectImpl(websockets::WebsocketsClient &client) {} + + // -- Members + + int port_; + int nextFreeClient_; + + ApiManagerTuple apiManagers_; + + websockets::WebsocketsServer server_; + websockets::WebsocketsClient clients_[MAX_WEBSOCKET_CONNECTIONS]; +}; + +template +inline WebsocketServer> makeWebsocketServer(int port, ApiManagers... managers) { - using MeasurementT = typename SessionT::MeasurementType; - - // Message format: - // - uint8_t messageType - // - uint8_t running - // - uint32_t sessionId - // - MeasurementT [] measurements (if running) - - auto &session = sessionManager_.session(); - const auto numMeasurements = session.numMeasurements(); - const auto sessionId = session.getStartTime(); - - const size_t msgSize = sizeof(uint8_t) + sizeof(uint8_t) + sizeof(sessionId) + sizeof(MeasurementT) * numMeasurements; - char *msg = (char *)heap_caps_malloc(msgSize, MALLOC_CAP_SPIRAM); - - char *writeHead = msg; - - *writeHead = INITIAL_INFO; - writeHead += sizeof(uint8_t); - - *writeHead = sessionManager_.isMeasuring(); - writeHead += sizeof(uint8_t); - - *((uint32_t *)writeHead) = sessionManager_.isMeasuring() ? sessionId : 0; - writeHead += sizeof(uint32_t); - - assert(writeHead - msg == msgSize - sizeof(MeasurementT) * numMeasurements); - - memcpy(writeHead, session.getDataPointer(), sizeof(MeasurementT) * numMeasurements); - client.sendBinary(msg, msgSize); - - free(msg); + auto tuple = std::make_tuple(managers...); + return WebsocketServer(port, tuple); +} + +template +inline void sendToClient(websockets::WebsocketsClient &client, MessageCode msgCode, const JsonDocument &content) +{ + char buffer[bufferSize]; + buffer[0] = (char)(msgCode); + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), bufferSize - sizeof(msgCode)); + client.sendBinary(buffer, bytesWritten + 1); +} + +inline void sendToClient(websockets::WebsocketsClient &client, MessageCode msgCode, const JsonDocument &content) +{ + size_t expectedSize = measureMsgPack(content); + char *buffer = (char *)malloc(expectedSize + sizeof(msgCode)); + buffer[0] = static_cast(msgCode); + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), expectedSize); + client.sendBinary(buffer, bytesWritten + sizeof(msgCode)); + free(buffer); +} + +inline void sendErrorToClient(websockets::WebsocketsClient &client, const char *msg) +{ + DynamicJsonDocument doc(strlen(msg) + 64); + doc["msg"] = msg; + sendToClient(client, MessageCode::ERROR, doc); } diff --git a/firmware/src/WebsocketServerOld.h b/firmware/src/WebsocketServerOld.h new file mode 100644 index 0000000..eb4611c --- /dev/null +++ b/firmware/src/WebsocketServerOld.h @@ -0,0 +1,436 @@ + +#pragma once +#include "Dtypes.h" +#include "UserDB.h" +#include "MessageCodes.h" + +#include + +template +class SessionManager; + +static constexpr int NUM_DATA_CHUNK_SIZE = 1; + +template +class WebsocketInterface +{ +public: + WebsocketInterface(int port) : port_(port), nextFreeClient_(0) + { + } + + void begin() + { + server_.listen(port_); + } + + void iteration(); + + template + void sendToAll(MessageCode msgCode, const JsonDocument &content) + { + char buffer[bufferSize]; + buffer[0] = msgCode; + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), bufferSize - sizeof(msgCode)); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (clients_[i].available()) + clients_[i].sendBinary(buffer, bytesWritten); + } + + void sendToAll(MessageCode msgCode, const JsonDocument &content) + { + size_t expectedSize = measureMsgPack(content); + char *buffer = (char *)malloc(expectedSize + sizeof(msgCode)); + buffer[0] = msgCode; + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), expectedSize); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (clients_[i].available()) + clients_[i].sendBinary(buffer, bytesWritten + sizeof(msgCode)); + + free(buffer); + } + + void sendToAll(MessageCode msgCode) + { + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + if (clients_[i].available()) + clients_[i].sendBinary(&msgCode, sizeof(MessageCode)); + } + + websockets::WebsocketsClient &client(size_t i) { return clients_[i]; } + +private: + int port_; + int nextFreeClient_; + + websockets::WebsocketsServer server_; + websockets::WebsocketsClient clients_[MAX_WEBSOCKET_CONNECTIONS]; +}; + +template +inline void sendToClient(websockets::WebsocketsClient &client, MessageCode msgCode, const JsonDocument &content) +{ + char buffer[bufferSize]; + buffer[0] = msgCode; + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), bufferSize - sizeof(msgCode)); + client.sendBinary(buffer, bytesWritten); +} + +inline void sendToClient(websockets::WebsocketsClient &client, MessageCode msgCode, const JsonDocument &content) +{ + size_t expectedSize = measureMsgPack(content); + char *buffer = (char *)malloc(expectedSize + sizeof(msgCode)); + buffer[0] = static_cast(msgCode); + size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), expectedSize); + client.sendBinary(buffer, bytesWritten + sizeof(msgCode)); + free(buffer); +} + +inline void sendErrorToClient(websockets::WebsocketsClient &client, const char *msg) +{ + DynamicJsonDocument doc(strlen(msg) + 64); + doc["msg"] = msg; + sendToClient(client, MessageCode::ERROR, doc); +} + +template +class WebsocketServer +{ +public: + WebsocketServer(SessionManager &sessionManager, UserStorage &userStorage, int port) + : sessionManager_(sessionManager), userStorage_(userStorage), nextFreeClient_(0), port_(port), + running_(false) + { + } + + void begin() + { + server_.listen(port_); + } + + void iteration(); + +private: + void reportSessionUpdate(); + + void sendMessageOnConnection(websockets::WebsocketsClient &client); + void sendSessionStartMessages(); + void sendSessionStopMessages(); + void sendNewDataMessages(); + + void sendUserList(websockets::WebsocketsClient &client); + void sendSessionList(websockets::WebsocketsClient &client, const String &userId); + + SessionManager &sessionManager_; + UserStorage &userStorage_; + + int nextFreeClient_; + int port_; + + size_t sentMessageCount_; + websockets::WebsocketsServer server_; + websockets::WebsocketsClient clients_[MAX_WEBSOCKET_CONNECTIONS]; + + // previous session state + size_t numSentMeasurements_[MAX_WEBSOCKET_CONNECTIONS]; + bool running_; +}; + +using websockets::WebsocketsClient; + +// ------------------------------------- Message types & classes --------------------------- + +enum MessageType +{ + // from swim tracker device to frontend + INITIAL_INFO = 1, + SESSION_STARTED = 2, + SESSION_STOPPED = 3, + SESSION_NEW_DATA = 4, + ANSWER_USER_LIST = 5, + ANSWER_SESSION_LIST = 6, + + // from frontend to device + START_SESSION = 7, + STOP_SESSION = 8, + TARE = 9, + QUERY_USER_LIST = 10, + QUERY_SESSION_LIST = 11, +}; + +#pragma pack(push, 1) +class SessionStartedMessage +{ +public: + SessionStartedMessage(uint32_t id) : messageType_(SESSION_STARTED), sessionId_(id) {} + void send(WebsocketsClient &c) const + { + c.sendBinary((const char *)(this), sizeof(*this)); + } + +private: + uint8_t messageType_; + uint32_t sessionId_; +}; + +class SessionStoppedMessage +{ +public: + SessionStoppedMessage() : messageType_(SESSION_STOPPED) {} + void send(WebsocketsClient &c) const + { + c.sendBinary((const char *)(this), sizeof(*this)); + } + +private: + uint8_t messageType_; +}; + +template +class SessionNewDataMessage +{ +public: + // typically a message contains NUM_DATA_CHUNK_SIZE measurements + // if some measurements are skipped, because loop() takes too long + // there might actually be more measurements, to be safe there is an + // additional factor here + static constexpr size_t MAX_MEASUREMENTS = 4 * NUM_DATA_CHUNK_SIZE; + + SessionNewDataMessage(MeasurementT *ptr, size_t numMeasurements) + : messageType_(SESSION_NEW_DATA), numMeasurements_(min(numMeasurements, MAX_MEASUREMENTS)) + { + memcpy(measurements_, ptr, sizeof(MeasurementT) * numMeasurements_); + } + + void send(WebsocketsClient &c) const + { + c.sendBinary((const char *)(this), numBytes()); + } + + size_t numMeasurements() const + { + return numMeasurements_; + } + +private: + size_t numBytes() const { return sizeof(uint8_t) + numMeasurements() * sizeof(MeasurementT); } + + // data to be sent + uint8_t messageType_; + MeasurementT measurements_[MAX_MEASUREMENTS]; + + // book-keeping + size_t numMeasurements_; +}; + +#pragma pack(pop) + +// ------------------------------------- WebsocketServer members --------------------------- + +template +void WebsocketServer::iteration() +{ + using namespace websockets; + + auto onMessage = [this](WebsocketsClient &client, WebsocketsMessage message) + { + if (message.isPing()) + client.pong(); + else if (message.isBinary()) + { + const char *data = message.c_str(); + const size_t length = message.length(); + if (length < 1) + { + client.close(CloseReason_UnsupportedData); + return; + } + + uint8_t opCode = uint8_t(data[0]); + switch (opCode) + { + case START_SESSION: + this->sessionManager_.startMeasurements(); + break; + case STOP_SESSION: + this->sessionManager_.stopMeasurements(); + break; + case TARE: + this->sessionManager_.tare(); + break; + case QUERY_USER_LIST: + this->sendUserList(client); + break; + case QUERY_SESSION_LIST: + { + StaticJsonDocument doc; + deserializeMsgPack(doc, data, length); + String userId = doc.as(); + if (userId.length() > 0) + this->sendSessionList(client, userId); + } + break; + default: + client.close(CloseReason_UnsupportedData); + return; + } + } + }; + + if (server_.poll()) + { + clients_[nextFreeClient_] = server_.accept(); + clients_[nextFreeClient_].onMessage(onMessage); + Serial.println("new websocket connection"); + sendMessageOnConnection(clients_[nextFreeClient_]); + numSentMeasurements_[nextFreeClient_] = sessionManager_.session().numMeasurements(); + nextFreeClient_ = (nextFreeClient_ + 1) % MAX_WEBSOCKET_CONNECTIONS; + } + + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + clients_[i].poll(); + + reportSessionUpdate(); +} + +template +void WebsocketServer::reportSessionUpdate() +{ + if (!running_ && sessionManager_.isMeasuring()) + { + sendSessionStartMessages(); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + numSentMeasurements_[i] = 0; + } + else if (running_ && !sessionManager_.isMeasuring()) + { + sendSessionStopMessages(); + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + numSentMeasurements_[i] = 0; + } + sendNewDataMessages(); +} + +template +void WebsocketServer::sendUserList(websockets::WebsocketsClient &client) +{ + const auto numUsers = userStorage_.numUsers(); + constexpr size_t constantSlack = 64; + DynamicJsonDocument result(JSON_ARRAY_SIZE(numUsers) + numUsers * (USER_STRING_ID_MAX_LEN + 2) + constantSlack); + JsonArray arr = result.to(); + for (auto userIt = userStorage_.beginWithoutUnassigned(); userIt != userStorage_.end(); ++userIt) + arr.add(userIt->stringId()); + + char buffer[MAX_USERS * (USER_STRING_ID_MAX_LEN + 1) + constantSlack]; + size_t bytesWritten = serializeMsgPack(result, buffer, sizeof(buffer)); + client.sendBinary(buffer, bytesWritten); +} + +template +void WebsocketServer::sendSessionList(websockets::WebsocketsClient &client, const String &userId) +{ + + User *user = userStorage_.getUserInfo(userId); + if (user != nullptr) + { + DynamicJsonDocument result(JSON_ARRAY_SIZE(user->numSessions()) + user->numSessions() * (sizeof(SessionIdType) + 8)); + JsonArray arr = result.to(); + for (SessionIdType *sIt = user->sessionBegin(); sIt != user->sessionEnd(); ++sIt) + arr.add(*sIt); + + size_t bytesToWrite = measureMsgPack(result); + char *buffer = (char *)malloc(bytesToWrite); + size_t bytesWritten = serializeMsgPack(result, buffer, bytesToWrite); + assert(bytesWritten <= bytesToWrite); + client.sendBinary(buffer, bytesWritten); + free(buffer); + } + else + { + DynamicJsonDocument result(JSON_ARRAY_SIZE(1) + 8); + result.to(); + char buffer[32]; + size_t bytesWritten = serializeMsgPack(result, buffer, sizeof(buffer)); + client.sendBinary(buffer, bytesWritten); + } +} + +template +void WebsocketServer::sendSessionStartMessages() +{ + SessionStartedMessage msg(sessionManager_.session().getStartTime()); + for (auto &c : clients_) + if (c.available()) + msg.send(c); + running_ = sessionManager_.isMeasuring(); +} + +template +void WebsocketServer::sendSessionStopMessages() +{ + SessionStoppedMessage msg; + for (auto &c : clients_) + if (c.available()) + msg.send(c); + running_ = sessionManager_.isMeasuring(); +} + +template +void WebsocketServer::sendNewDataMessages() +{ + using MeasurementT = typename SessionT::MeasurementType; + auto &session = sessionManager_.session(); + + for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i) + { + auto &c = clients_[i]; + if (c.available()) + { + MeasurementT *dataToSend = session.getDataPointer() + numSentMeasurements_[i]; + int32_t numMeasurementsToSend = int32_t(session.numMeasurements()) - int32_t(numSentMeasurements_[i]); + if (numMeasurementsToSend > 0) + { + SessionNewDataMessage msg(dataToSend, numMeasurementsToSend); + msg.send(c); + numSentMeasurements_[i] += msg.numMeasurements(); + } + } + } +} + +template +void WebsocketServer::sendMessageOnConnection(WebsocketsClient &client) +{ + using MeasurementT = typename SessionT::MeasurementType; + + // Message format: + // - uint8_t messageType + // - uint8_t running + // - uint32_t sessionId + // - MeasurementT [] measurements (if running) + + auto &session = sessionManager_.session(); + const auto numMeasurements = session.numMeasurements(); + const auto sessionId = session.getStartTime(); + + const size_t msgSize = sizeof(uint8_t) + sizeof(uint8_t) + sizeof(sessionId) + sizeof(MeasurementT) * numMeasurements; + char *msg = (char *)heap_caps_malloc(msgSize, MALLOC_CAP_SPIRAM); + + char *writeHead = msg; + + *writeHead = INITIAL_INFO; + writeHead += sizeof(uint8_t); + + *writeHead = sessionManager_.isMeasuring(); + writeHead += sizeof(uint8_t); + + *((uint32_t *)writeHead) = sessionManager_.isMeasuring() ? sessionId : 0; + writeHead += sizeof(uint32_t); + + assert(writeHead - msg == msgSize - sizeof(MeasurementT) * numMeasurements); + + memcpy(writeHead, session.getDataPointer(), sizeof(MeasurementT) * numMeasurements); + client.sendBinary(msg, msgSize); + + free(msg); +} diff --git a/firmware/src/WifiAPI.cpp b/firmware/src/WifiAPI.cpp index eecdad2..7a5e20e 100644 --- a/firmware/src/WifiAPI.cpp +++ b/firmware/src/WifiAPI.cpp @@ -3,13 +3,23 @@ #include "WebsocketServer.h" +void WifiAPI::sendWifiState(websockets::WebsocketsClient &client) +{ + StaticJsonDocument<128> data; + data["state"] = wifiManager_.stateStr(); + sendToClient<64>(client, MessageCode::WIFI_STATE_RESPONSE, data); +} + +void WifiAPI::onClientConnect(websockets::WebsocketsClient &client) +{ + sendWifiState(client); +} + bool WifiAPI::handleMessage(websockets::WebsocketsClient &client, MessageCode code, const char *payload, size_t size) { if (code == MessageCode::WIFI_STATE_GET) { - StaticJsonDocument<128> data; - data["state"] = wifiManager_.stateStr(); - sendToClient<64>(client, MessageCode::WIFI_STATE_GET, data); + sendWifiState(client); return true; } else if (code == MessageCode::WIFI_STATE_SET) @@ -42,5 +52,10 @@ bool WifiAPI::handleMessage(websockets::WebsocketsClient &client, MessageCode co return true; } } + else if (code == MessageCode::WIFI_TRIGGER_SCAN) + { + WiFi.scanNetworks(true); + return true; + } return false; } diff --git a/firmware/src/WifiAPI.h b/firmware/src/WifiAPI.h index a7a983a..eafd30b 100644 --- a/firmware/src/WifiAPI.h +++ b/firmware/src/WifiAPI.h @@ -15,20 +15,73 @@ public: { } - void onClientConnect(websockets::WebsocketsClient &client) {} + void onClientConnect(websockets::WebsocketsClient &client); bool handleMessage(websockets::WebsocketsClient &client, MessageCode code, const char *payload, size_t size); template - void iteration(TServer &server) - { - if (restartScheduled_) - { - Serial.print("Restart triggered by WifiAPI"); - ESP.restart(); - } - } + void iteration(TServer &server); private: + void sendWifiState(websockets::WebsocketsClient &client); + + template + void reportScanResultIfAvailable(TServer &server); + WifiManager &wifiManager_; bool restartScheduled_; }; + +template +void WifiAPI::iteration(TServer &server) +{ + if (restartScheduled_) + { + Serial.print("Restart triggered by WifiAPI"); + ESP.restart(); + } + reportScanResultIfAvailable(server); +} + +template +void WifiAPI::reportScanResultIfAvailable(TServer &server) +{ + auto numNetworks = WiFi.scanComplete(); + + if (numNetworks >= 0) + { + DynamicJsonDocument response(192 * numNetworks); + for (uint16_t i = 0; i < numNetworks; ++i) + { + JsonObject wifiObj = response.createNestedObject(); + wifiObj["ssid"] = WiFi.SSID(i); + wifiObj["rssi"] = WiFi.RSSI(i); + wifiObj["channel"] = WiFi.channel(i); + + switch (WiFi.encryptionType(i)) + { + case WIFI_AUTH_OPEN: + wifiObj["sec"] = "open"; + break; + case WIFI_AUTH_WEP: + wifiObj["sec"] = "WEP"; + break; + case WIFI_AUTH_WPA_PSK: + wifiObj["sec"] = "WPA_PSK"; + break; + case WIFI_AUTH_WPA2_PSK: + wifiObj["sec"] = "WPA2_PSK"; + break; + case WIFI_AUTH_WPA_WPA2_PSK: + wifiObj["sec"] = "WPA_WPA2_PSK"; + break; + case WIFI_AUTH_WPA2_ENTERPRISE: + wifiObj["sec"] = "WPA2_ENTP"; + break; + default: + wifiObj["sec"] = "?"; + } + } + server.sendToAll(MessageCode::WIFI_SCAN_RESPONSE, response); + WiFi.scanDelete(); + } +} \ No newline at end of file diff --git a/firmware/src/firmware_main.cpp b/firmware/src/firmware_main.cpp index de58b18..5ae5c0a 100644 --- a/firmware/src/firmware_main.cpp +++ b/firmware/src/firmware_main.cpp @@ -20,19 +20,28 @@ #include "SimpleMeasurementSession.h" #include "EspHttp.h" #include "WebDAV.h" -#include "WebsocketServer.h" #include "UserDB.h" +// Api +#include "WebsocketServer.h" +#include "SessionAPI.h" +#include "WifiAPI.h" + using Session_T = SimpleMeasurementSession; SessionManager sessionManager; UserStorage userStorage; EspHttp espHttpServer; -WebsocketServer webSocketServer(sessionManager, userStorage, 81); WifiManager wifiManager; + +auto apiTuple = std::make_tuple(SessionAPI(sessionManager), WifiAPI(wifiManager)); +WebsocketServer websocketServer(81, apiTuple); + +//WebsocketServer webSocketServer(sessionManager, userStorage, 81); + extern const uint8_t certificate_pem[] asm("_binary_certificate_pem_start"); bool firmwareUpdate() @@ -86,32 +95,38 @@ void sessionManagerSetup() template void httpSetup(SessionManager *sessionManager, WifiManager *wifiManager) { - auto cbStartSession = [sessionManager](httpd_req_t *req) { + auto cbStartSession = [sessionManager](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); httpd_resp_send(req, "Session started", -1); sessionManager->startMeasurements(); Serial.println("Started session"); }; - auto cbStopSession = [sessionManager](httpd_req_t *req) { + auto cbStopSession = [sessionManager](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); httpd_resp_send(req, "Session stopped", -1); sessionManager->stopMeasurements(); Serial.println("Stopped session"); }; - auto cbRestart = [](httpd_req_t *req) { + auto cbRestart = [](httpd_req_t *req) + { Serial.println("Restarted requested"); ESP.restart(); }; - auto cbTare = [sessionManager](httpd_req_t *req) { + auto cbTare = [sessionManager](httpd_req_t *req) + { Serial.println("Tare"); sessionManager->tare(); }; - auto cbFirmwareUpdate = [](httpd_req_t *req) { + auto cbFirmwareUpdate = [](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); httpd_resp_send(req, "OK", -1); firmwareUpdate(); }; - auto cbStatus = [sessionManager](httpd_req_t *req) { + auto cbStatus = [sessionManager](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); httpd_resp_set_hdr(req, "Content-Type", "application/json"); @@ -167,7 +182,8 @@ void httpSetup(SessionManager *sessionManager, WifiManager *wifiManage auto bytesWritten = serializeJson(json, jsonText); httpd_resp_send(req, jsonText, bytesWritten); }; - auto cbGetData = [sessionManager](httpd_req_t *req) { + auto cbGetData = [sessionManager](httpd_req_t *req) + { auto sessionId = sessionManager->session().getStartTime(); uint32_t startIdx = getUrlQueryParameter(req, "startIdx", 0); //Serial.printf("Data request, start index: %d\n", startIdx); @@ -196,7 +212,8 @@ void httpSetup(SessionManager *sessionManager, WifiManager *wifiManage httpd_resp_send(req, buf, totalSize); free(buf); }; - auto cbWifiGet = [wifiManager](httpd_req_t *req) { + auto cbWifiGet = [wifiManager](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); StaticJsonDocument<128> json; json["state"] = wifiManager->stateStr(); @@ -204,7 +221,8 @@ void httpSetup(SessionManager *sessionManager, WifiManager *wifiManage auto bytesWritten = serializeJson(json, jsonText); httpd_resp_send(req, jsonText, bytesWritten); }; - auto cbWifiPost = [wifiManager](httpd_req_t *req) { + auto cbWifiPost = [wifiManager](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); StaticJsonDocument<1024> json; char content[512]; @@ -248,7 +266,8 @@ void httpSetup(SessionManager *sessionManager, WifiManager *wifiManage httpd_resp_set_status(req, "400 Bad Request"); httpd_resp_send(req, "Invalid keys in JSON", -1); }; - auto cbSettingsGet = [](httpd_req_t *req) { + auto cbSettingsGet = [](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); httpd_resp_set_hdr(req, "Content-Type", "application/json"); @@ -269,7 +288,8 @@ void httpSetup(SessionManager *sessionManager, WifiManager *wifiManage auto bytesWritten = serializeJson(json, jsonText); httpd_resp_send(req, jsonText, bytesWritten); }; - auto cbSettingsPost = [](httpd_req_t *req) { + auto cbSettingsPost = [](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); StaticJsonDocument<1024> json; char content[512]; @@ -313,7 +333,8 @@ void httpSetup(SessionManager *sessionManager, WifiManager *wifiManage sessionManagerSetup(); httpd_resp_send(req, "OK", -1); }; - auto cbSettingsDelete = [](httpd_req_t *req) { + auto cbSettingsDelete = [](httpd_req_t *req) + { httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*"); Preferences prefs; prefs.begin("st_prefs"); @@ -419,13 +440,12 @@ void setup() // HTTP & Websocket server httpSetup(&sessionManager, &wifiManager); - if (!wifiManager.inProvisioningMode()) - webSocketServer.begin(); + websocketServer.begin(); } void loop() { sessionManager.iteration(); - webSocketServer.iteration(); - wifiManager.wifiWatchdog(); + wifiManager.iteration(); + websocketServer.iteration(); } diff --git a/firmware/websocket_test.py b/firmware/websocket_test.py index d37fd32..38b2464 100644 --- a/firmware/websocket_test.py +++ b/firmware/websocket_test.py @@ -1,36 +1,158 @@ import asyncio + import websockets import struct import numpy as np - -INITIAL_INFO = 1 -SESSION_STARTED = 2 -SESSION_STOPPED = 3 -SESSION_NEW_DATA = 4 +from pprint import pprint +import datetime +import msgpack +import aiomonitor -async def hello(): - uri = "ws://192.168.178.110:81" +class MsgManager: + def __init__(self): + self.msg_history = [] + + def add_msg(self, msg): + pprint(msg) + self.msg_history.append(msg) + + +send_functions = [] + + +class MsgCode: + ERROR = 1 + + # device to frontend + INITIAL_INFO = 2 + SESSION_STARTED = 3 + SESSION_STOPPED = 4 + SESSION_NEW_DATA = 5 + ANSWER_USER_LIST = 6 + ANSWER_SESSION_LIST = 7 + WIFI_STATE_RESPONSE = 8 + WIFI_SCAN_RESPONSE = 9 + + # from frontend to device + START_SESSION = 128 + STOP_SESSION = 129 + TARE = 130 + QUERY_USER_LIST = 131 + QUERY_SESSION_LIST = 132 + WIFI_STATE_SET = 133 + WIFI_STATE_GET = 134 + WIFI_TRIGGER_SCAN = 135 + + +async def send_message(websocket, msg_type, payload=None): + payload = struct.pack("