#pragma once #include "Dtypes.h" #include "UserDB.h" #include template class SessionManager; static constexpr int MAX_CONNECTIONS = 3; static constexpr int NUM_DATA_CHUNK_SIZE = 1; template class WebsocketServer { public: WebsocketServer(SessionManager &sessionManager, UserStorage &userStorage, int port) : sessionManager_(sessionManager), userStorage_(userStorage), nextFreeClient_(0), port_(port), running_(false) { } void begin() { server_.listen(port_); } void iteration(); private: void reportSessionUpdate(); void sendMessageOnConnection(websockets::WebsocketsClient &client); void sendSessionStartMessages(); void sendSessionStopMessages(); void sendNewDataMessages(); void sendUserList(websockets::WebsocketsClient &client); void sendSessionList(websockets::WebsocketsClient &client, const String &userId); SessionManager &sessionManager_; UserStorage &userStorage_; int nextFreeClient_; int port_; size_t sentMessageCount_; websockets::WebsocketsServer server_; websockets::WebsocketsClient clients_[MAX_CONNECTIONS]; // previous session state size_t numSentMeasurements_[MAX_CONNECTIONS]; bool running_; }; using websockets::WebsocketsClient; // ------------------------------------- Message types & classes --------------------------- enum MessageType { // from swim tracker device to frontend INITIAL_INFO = 1, SESSION_STARTED = 2, SESSION_STOPPED = 3, SESSION_NEW_DATA = 4, ANSWER_USER_LIST = 5, ANSWER_SESSION_LIST = 6, // from frontend to device START_SESSION = 7, STOP_SESSION = 8, TARE = 9, QUERY_USER_LIST = 10, QUERY_SESSION_LIST = 11, }; #pragma pack(push, 1) class SessionStartedMessage { public: SessionStartedMessage(uint32_t id) : messageType_(SESSION_STARTED), sessionId_(id) {} void send(WebsocketsClient &c) const { c.sendBinary((const char *)(this), sizeof(*this)); } private: uint8_t messageType_; uint32_t sessionId_; }; class SessionStoppedMessage { public: SessionStoppedMessage() : messageType_(SESSION_STOPPED) {} void send(WebsocketsClient &c) const { c.sendBinary((const char *)(this), sizeof(*this)); } private: uint8_t messageType_; }; template class SessionNewDataMessage { public: // typically a message contains NUM_DATA_CHUNK_SIZE measurements // if some measurements are skipped, because loop() takes too long // there might actually be more measurements, to be safe there is an // additional factor here static constexpr size_t MAX_MEASUREMENTS = 4 * NUM_DATA_CHUNK_SIZE; SessionNewDataMessage(MeasurementT *ptr, size_t numMeasurements) : messageType_(SESSION_NEW_DATA), numMeasurements_(min(numMeasurements, MAX_MEASUREMENTS)) { memcpy(measurements_, ptr, sizeof(MeasurementT) * numMeasurements_); } void send(WebsocketsClient &c) const { c.sendBinary((const char *)(this), numBytes()); } size_t numMeasurements() const { return numMeasurements_; } private: size_t numBytes() const { return sizeof(uint8_t) + numMeasurements() * sizeof(MeasurementT); } // data to be sent uint8_t messageType_; MeasurementT measurements_[MAX_MEASUREMENTS]; // book-keeping size_t numMeasurements_; }; #pragma pack(pop) // ------------------------------------- WebsocketServer members --------------------------- template void WebsocketServer::iteration() { using namespace websockets; auto onMessage = [this](WebsocketsClient &client, WebsocketsMessage message) { if (message.isPing()) client.pong(); else if (message.isBinary()) { const char *data = message.c_str(); const size_t length = message.length(); if (length < 1) { client.close(CloseReason_UnsupportedData); return; } uint8_t opCode = uint8_t(data[0]); switch (opCode) { case START_SESSION: this->sessionManager_.startMeasurements(); break; case STOP_SESSION: this->sessionManager_.stopMeasurements(); break; case TARE: this->sessionManager_.tare(); break; case QUERY_USER_LIST: this->sendUserList(client); break; case QUERY_SESSION_LIST: { StaticJsonDocument doc; deserializeMsgPack(doc, data, length); String userId = doc.as(); if (userId.length() > 0) this->sendSessionList(client, userId); } break; default: client.close(CloseReason_UnsupportedData); return; } } }; if (server_.poll()) { clients_[nextFreeClient_] = server_.accept(); clients_[nextFreeClient_].onMessage(onMessage); Serial.println("new websocket connection"); sendMessageOnConnection(clients_[nextFreeClient_]); numSentMeasurements_[nextFreeClient_] = sessionManager_.session().numMeasurements(); nextFreeClient_ = (nextFreeClient_ + 1) % MAX_CONNECTIONS; } for (int i = 0; i < MAX_CONNECTIONS; ++i) clients_[i].poll(); reportSessionUpdate(); } template void WebsocketServer::reportSessionUpdate() { if (!running_ && sessionManager_.isMeasuring()) { sendSessionStartMessages(); for (int i = 0; i < MAX_CONNECTIONS; ++i) numSentMeasurements_[i] = 0; } else if (running_ && !sessionManager_.isMeasuring()) { sendSessionStopMessages(); for (int i = 0; i < MAX_CONNECTIONS; ++i) numSentMeasurements_[i] = 0; } sendNewDataMessages(); } template void WebsocketServer::sendUserList(websockets::WebsocketsClient &client) { const auto numUsers = userStorage_.numUsers(); constexpr size_t constantSlack = 64; DynamicJsonDocument result(JSON_ARRAY_SIZE(numUsers) + numUsers * (USER_STRING_ID_MAX_LEN + 2) + constantSlack); JsonArray arr = result.to(); for (auto userIt = userStorage_.beginWithoutUnassigned(); userIt != userStorage_.end(); ++userIt) arr.add(userIt->stringId()); char buffer[MAX_USERS * (USER_STRING_ID_MAX_LEN + 1) + constantSlack]; size_t bytesWritten = serializeMsgPack(result, buffer, sizeof(buffer)); client.sendBinary(buffer, bytesWritten); } template void WebsocketServer::sendSessionList(websockets::WebsocketsClient &client, const String &userId) { User *user = userStorage_.getUserInfo(userId); if (user != nullptr) { DynamicJsonDocument result(JSON_ARRAY_SIZE(user->numSessions()) + user->numSessions() * (sizeof(SessionIdType) + 8)); JsonArray arr = result.to(); for (SessionIdType *sIt = user->sessionBegin(); sIt != user->sessionEnd(); ++sIt) arr.add(*sIt); size_t bytesToWrite = measureMsgPack(result); char *buffer = (char *)malloc(bytesToWrite); size_t bytesWritten = serializeMsgPack(result, buffer, bytesToWrite); assert(bytesWritten <= bytesToWrite); client.sendBinary(buffer, bytesWritten); free(buffer); } else { DynamicJsonDocument result(JSON_ARRAY_SIZE(1) + 8); result.to(); char buffer[32]; size_t bytesWritten = serializeMsgPack(result, buffer, sizeof(buffer)); client.sendBinary(buffer, bytesWritten); } } template void WebsocketServer::sendSessionStartMessages() { SessionStartedMessage msg(sessionManager_.session().getStartTime()); for (auto &c : clients_) if (c.available()) msg.send(c); running_ = sessionManager_.isMeasuring(); } template void WebsocketServer::sendSessionStopMessages() { SessionStoppedMessage msg; for (auto &c : clients_) if (c.available()) msg.send(c); running_ = sessionManager_.isMeasuring(); } template void WebsocketServer::sendNewDataMessages() { using MeasurementT = typename SessionT::MeasurementType; auto &session = sessionManager_.session(); for (int i = 0; i < MAX_CONNECTIONS; ++i) { auto &c = clients_[i]; if (c.available()) { MeasurementT *dataToSend = session.getDataPointer() + numSentMeasurements_[i]; int32_t numMeasurementsToSend = int32_t(session.numMeasurements()) - int32_t(numSentMeasurements_[i]); if (numMeasurementsToSend > 0) { SessionNewDataMessage msg(dataToSend, numMeasurementsToSend); msg.send(c); numSentMeasurements_[i] += msg.numMeasurements(); } } } } template void WebsocketServer::sendMessageOnConnection(WebsocketsClient &client) { using MeasurementT = typename SessionT::MeasurementType; // Message format: // - uint8_t messageType // - uint8_t running // - uint32_t sessionId // - MeasurementT [] measurements (if running) auto &session = sessionManager_.session(); const auto numMeasurements = session.numMeasurements(); const auto sessionId = session.getStartTime(); const size_t msgSize = sizeof(uint8_t) + sizeof(uint8_t) + sizeof(sessionId) + sizeof(MeasurementT) * numMeasurements; char *msg = (char *)heap_caps_malloc(msgSize, MALLOC_CAP_SPIRAM); char *writeHead = msg; *writeHead = INITIAL_INFO; writeHead += sizeof(uint8_t); *writeHead = sessionManager_.isMeasuring(); writeHead += sizeof(uint8_t); *((uint32_t *)writeHead) = sessionManager_.isMeasuring() ? sessionId : 0; writeHead += sizeof(uint32_t); assert(writeHead - msg == msgSize - sizeof(MeasurementT) * numMeasurements); memcpy(writeHead, session.getDataPointer(), sizeof(MeasurementT) * numMeasurements); client.sendBinary(msg, msgSize); free(msg); }