swimtracker-firmware/firmware/src/SessionAPI.h

192 lines
5.5 KiB
C++

#include "Dtypes.h"
#include "MessageCodes.h"
#include "SwimTrackerConfig.h"
#include <ArduinoWebsockets.h>
#include <ArduinoJson.h>
template <typename T>
class SessionManager;
template <typename SessionT>
class SessionAPI
{
public:
SessionAPI(SessionManager<SessionT> &sessionManager)
: sessionManager_(sessionManager)
{
for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
numSentMeasurements_[i] = 0;
}
void onClientConnect(websockets::WebsocketsClient &client);
bool handleMessage(websockets::WebsocketsClient &client, MessageCode code, const char *payload, size_t size);
template <typename TServer>
void iteration(TServer &server);
private:
template <typename TServer>
void sendSessionStartMessages(TServer &server);
template <typename TServer>
void sendSessionStopMessages(TServer &server);
template <typename TServer>
void sendNewDataMessages(TServer &server);
SessionManager<SessionT> &sessionManager_;
size_t numSentMeasurements_[MAX_WEBSOCKET_CONNECTIONS];
bool running_;
};
// Sends a message describing the current session to a newly connected client.
template <typename T>
void SessionAPI<T>::onClientConnect(websockets::WebsocketsClient &client)
{
    // TODO write msgpack instead for consistency?
    using MeasurementT = typename T::MeasurementType;
    // Message format (raw binary, native byte order):
    // - uint8_t  messageType  (MessageCode::INITIAL_INFO)
    // - uint8_t  running      (1 while measuring, else 0)
    // - uint32_t sessionId    (session start time while measuring, else 0)
    // - MeasurementT [] measurements currently stored in the session
    //   (copied regardless of the running flag)
    auto &session = sessionManager_.session();

    // Query the state once so the running flag and the session id always agree.
    const bool measuring = sessionManager_.isMeasuring();
    const size_t numMeasurements = session.numMeasurements();
    const uint32_t sessionId = measuring ? static_cast<uint32_t>(session.getStartTime()) : 0;

    // The header size is fixed by the wire format, independent of the start-time type.
    constexpr size_t headerSize = sizeof(uint8_t) + sizeof(uint8_t) + sizeof(uint32_t);
    const size_t dataSize = sizeof(MeasurementT) * numMeasurements;
    const size_t msgSize = headerSize + dataSize;

    char *msg = static_cast<char *>(heap_caps_malloc(msgSize, MALLOC_CAP_SPIRAM));
    if (msg == nullptr)
    {
        Serial.println("SessionAPI: could not allocate initial info message");
        return;
    }

    char *writeHead = msg;
    *writeHead++ = static_cast<char>(MessageCode::INITIAL_INFO);
    *writeHead++ = measuring ? 1 : 0;
    // memcpy instead of a uint32_t* store: offset 2 is not 4-byte aligned and
    // the cast would also violate strict aliasing.
    memcpy(writeHead, &sessionId, sizeof(sessionId));
    writeHead += sizeof(sessionId);
    assert(static_cast<size_t>(writeHead - msg) == headerSize);

    // Skip the copy for an empty session (data pointer may be null then).
    if (dataSize > 0)
        memcpy(writeHead, session.getDataPointer(), dataSize);

    client.sendBinary(msg, msgSize);
    free(msg);
}
/// Dispatches session commands received over the websocket.
/// @param client   connection the message arrived on (unused here)
/// @param code     message type of the incoming command
/// @param payload  raw payload bytes (unused here)
/// @param size     payload length in bytes (unused here)
/// @return true when @p code is START_SESSION, STOP_SESSION or TARE, false otherwise
template <typename T>
bool SessionAPI<T>::handleMessage(websockets::WebsocketsClient &client, MessageCode code,
                                  const char *payload, size_t size)
{
    auto &manager = this->sessionManager_;
    switch (code)
    {
    case MessageCode::START_SESSION:
        Serial.println("SessionAPI: starting measurement session");
        manager.startMeasurements();
        break;
    case MessageCode::STOP_SESSION:
        Serial.println("SessionAPI: stopping measurement session");
        manager.stopMeasurements();
        break;
    case MessageCode::TARE:
        manager.tare();
        break;
    default:
        // not a session command - not handled by this API
        return false;
    }
    return true;
}
/// Periodic update, called with the websocket server.
/// On a change of the measuring state it announces start/stop to all clients
/// and resets the per-client sent counters. While a session runs it streams
/// new measurements; otherwise it sends an APP_LAYER_PING at most once per
/// CONFIG_PING_INTERVAL_MS.
template <typename T>
template <typename TServer>
void SessionAPI<T>::iteration(TServer &server)
{
    const bool measuring = sessionManager_.isMeasuring();
    if (measuring != running_)
    {
        if (measuring)
            sendSessionStartMessages(server);
        else
            sendSessionStopMessages(server);

        for (auto &sentCount : numSentMeasurements_)
            sentCount = 0;
        running_ = measuring;
    }

    if (running_)
    {
        sendNewDataMessages(server);
        return;
    }

    // Idle: keep clients alive with an application-level ping.
    // NOTE(review): function-local static is shared by all instances with the same template arguments.
    static unsigned long lastPing = 0;
    const auto now = millis();
    if (now - lastPing > CONFIG_PING_INTERVAL_MS)
    {
        server.sendToAll(MessageCode::APP_LAYER_PING);
        lastPing = now;
    }
}
/// Broadcasts SESSION_STARTED to all clients with payload {"sessionId": <session start time>}.
template <typename T>
template <typename TServer>
void SessionAPI<T>::sendSessionStartMessages(TServer &server)
{
    auto &session = sessionManager_.session();
    StaticJsonDocument<128> payload;
    payload["sessionId"] = session.getStartTime();
    // NOTE(review): template argument 32 presumably sizes the serialization buffer - confirm against sendToAll.
    server.template sendToAll<32>(MessageCode::SESSION_STARTED, payload);
}
/// Broadcasts SESSION_STOPPED (no payload) to all connected clients.
template <typename T>
template <typename TServer>
void SessionAPI<T>::sendSessionStopMessages(TServer &server)
{
    // Plain member call: the `template` disambiguator is only valid when it is
    // followed by an explicit template argument list, which is not the case here.
    server.sendToAll(MessageCode::SESSION_STOPPED);
}
/// Streams measurements not yet sent to each connected client.
/// numSentMeasurements_[i] tracks how many measurements client slot i has
/// received. At most MAX_MEASUREMENTS_PER_MSG are sent per client and call.
///
/// Message format (raw binary, not msgpack, since these messages are sent
/// multiple times a second):
/// - uint8_t messageType (MessageCode::SESSION_NEW_DATA)
/// - MeasurementT [] measurements
template <typename T>
template <typename TServer>
void SessionAPI<T>::sendNewDataMessages(TServer &server)
{
    constexpr size_t MAX_MEASUREMENTS_PER_MSG = 15;
    constexpr size_t WAIT_UNTIL_AT_LEAST_NUM_MEASUREMENTS = 1;
    constexpr size_t headerSize = 1;
    using MeasurementT = typename T::MeasurementType;

    auto &session = sessionManager_.session();
    char buffer[headerSize + MAX_MEASUREMENTS_PER_MSG * sizeof(MeasurementT)];
    buffer[0] = static_cast<char>(MessageCode::SESSION_NEW_DATA);

    for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
    {
        auto &c = server.client(i);
        if (!c.available())
            continue;

        const size_t numAvailable = static_cast<size_t>(session.numMeasurements());
        const size_t numAlreadySent = numSentMeasurements_[i];
        // The session may hold fewer measurements than already sent (e.g. it was
        // restarted between two iterations). The old signed/unsigned comparison
        // let that negative difference through and read past the end of the data.
        if (numAvailable <= numAlreadySent)
            continue;

        size_t numToSend = numAvailable - numAlreadySent;
        if (numToSend < WAIT_UNTIL_AT_LEAST_NUM_MEASUREMENTS)
            continue;
        if (numToSend > MAX_MEASUREMENTS_PER_MSG)
            numToSend = MAX_MEASUREMENTS_PER_MSG;

        const size_t payloadSize = sizeof(MeasurementT) * numToSend;
        memcpy(buffer + headerSize, session.getDataPointer() + numAlreadySent, payloadSize);
        c.sendBinary(buffer, headerSize + payloadSize);
        numSentMeasurements_[i] += numToSend;
    }
}