swimtracker-firmware/firmware/src/WebsocketServer.h

177 lines
5.9 KiB
C++

#pragma once
#include "Dtypes.h"
#include "UserDB.h"
#include "MessageCodes.h"
#include "SwimTrackerConfig.h"
#include <ArduinoWebsockets.h>
#include <ArduinoJson.h>
#include <type_traits>
// Forward declaration only; SessionManager is defined in another header.
template <class T>
class SessionManager;

// Fixed at 1. NOTE(review): not referenced in this header; presumably the
// number of data items per chunk - confirm against its users.
static constexpr int NUM_DATA_CHUNK_SIZE{1};
// WebSocket server holding a fixed pool of MAX_WEBSOCKET_CONNECTIONS client
// slots and forwarding traffic to a tuple of API managers.
//
// Binary frames are expected to start with one MessageCode byte followed by an
// optional payload. Each frame is offered to the managers in tuple order until
// one of them reports it as handled. Ping frames are answered with a pong. Any
// other frame type closes the connection with CloseReason_UnsupportedData.
//
// Each element of ApiManagerTuple must provide:
//   void iteration(WebsocketServer &)
//   bool handleMessage(WebsocketsClient &, MessageCode, const char *, size_t)
//   void onClientConnect(WebsocketsClient &)
template <typename ApiManagerTuple>
class WebsocketServer
{
public:
    // port:  TCP port used by begin().
    // tuple: managers; copied into the server.
    WebsocketServer(int port, const ApiManagerTuple &tuple)
        : port_(port), nextFreeClient_(0), apiManagers_(tuple)
    {
    }

    // Starts listening on the configured port.
    void begin()
    {
        server_.listen(port_);
    }

    // One main-loop step. It accepts at most one pending connection, polls
    // every client slot (which fires the message callback), and then runs
    // each manager's iteration() in tuple order.
    void iteration()
    {
        using namespace websockets;
        const auto onMessage = [this](WebsocketsClient &client, WebsocketsMessage message)
        {
            if (message.isPing()) // websocket ping
                client.pong();
            else if (message.isBinary())
            {
                const char *data = message.c_str();
                const size_t length = message.length();
                // A frame must contain at least the message-code byte. For an
                // empty frame, data[0] would be out of bounds and length - 1
                // would wrap around to SIZE_MAX.
                if (length == 0)
                    return;
                MessageCode msgCode = MessageCode((uint8_t)(data[0]));
                this->handleMessageImpl(client, msgCode, data + 1, length - 1);
            }
            else
                client.close(CloseReason_UnsupportedData);
        };
        if (server_.poll())
        {
            const int slot = findSlotForNewClient();
            LOG_INFO("new websocket connection, storing at pos %d - occupancy: ", slot);
            clients_[slot] = server_.accept();
            clients_[slot].onMessage(onMessage);
            this->onClientConnectImpl(clients_[slot]);
            nextFreeClient_ = (slot + 1) % MAX_WEBSOCKET_CONNECTIONS;
            for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
                LOG_INFO((clients_[i].available()) ? "x" : "o");
        }
        for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
            clients_[i].poll();
        this->iterationImpl<>();
    }

    // Sends one binary frame to every available client. The frame is the
    // message-code byte followed by the MessagePack encoding of `content`. It
    // is built in a stack buffer of bufferSize bytes, header included.
    template <size_t bufferSize>
    void sendToAll(MessageCode msgCode, const JsonDocument &content)
    {
        static_assert(bufferSize > sizeof(MessageCode),
                      "bufferSize must leave room for a payload after the message code");
        char buffer[bufferSize];
        buffer[0] = (char)(msgCode);
        size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), bufferSize - sizeof(msgCode));
        // bytesWritten counts only the payload, so add the header byte back.
        // Otherwise the last payload byte would be cut off.
        const size_t frameSize = bytesWritten + sizeof(msgCode);
        for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
            if (clients_[i].available())
                clients_[i].sendBinary(buffer, frameSize);
    }

    // Same frame layout as above. The buffer is sized exactly via
    // measureMsgPack and allocated on the heap. If the allocation fails, the
    // message is dropped and a log line is emitted.
    void sendToAll(MessageCode msgCode, const JsonDocument &content)
    {
        const size_t expectedSize = measureMsgPack(content);
        char *buffer = (char *)malloc(expectedSize + sizeof(msgCode));
        if (buffer == nullptr)
        {
            LOG_INFO("sendToAll: out of memory, message dropped");
            return;
        }
        buffer[0] = (char)(msgCode);
        size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), expectedSize);
        for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
            if (clients_[i].available())
                clients_[i].sendBinary(buffer, bytesWritten + sizeof(msgCode));
        free(buffer);
    }

    // Sends a frame consisting only of the message code to every available client.
    void sendToAll(MessageCode msgCode)
    {
        for (int i = 0; i < MAX_WEBSOCKET_CONNECTIONS; ++i)
            if (clients_[i].available())
                clients_[i].sendBinary((const char *)&msgCode, sizeof(MessageCode));
    }

    // Direct access to client slot i (no bounds check).
    websockets::WebsocketsClient &client(size_t i) { return clients_[i]; }

private:
    // Chooses the slot for a newly accepted connection. It returns the first
    // slot at or after nextFreeClient_ (wrapping around) whose client is not
    // available(). Only when every slot is in use does it fall back to
    // nextFreeClient_, which replaces that connection.
    int findSlotForNewClient()
    {
        for (int offset = 0; offset < MAX_WEBSOCKET_CONNECTIONS; ++offset)
        {
            const int candidate = (nextFreeClient_ + offset) % MAX_WEBSOCKET_CONNECTIONS;
            if (!clients_[candidate].available())
                return candidate;
        }
        return nextFreeClient_;
    }

    // -- Tuple calls
    // Compile-time recursion over the manager tuple. The first overload of each
    // pair handles index managerIdx; the second terminates at tuple_size.
    template <size_t managerIdx = 0, typename std::enable_if<(managerIdx < std::tuple_size<ApiManagerTuple>::value), bool>::type = true>
    void iterationImpl()
    {
        std::get<managerIdx>(apiManagers_).iteration(*this);
        iterationImpl<managerIdx + 1>();
    }
    template <size_t managerIdx, typename std::enable_if<managerIdx == std::tuple_size<ApiManagerTuple>::value, bool>::type = true>
    void iterationImpl() {}

    // Offers the message to managers in order and stops at the first one that
    // returns true. Returns false if no manager handled it.
    template <size_t managerIdx = 0, typename std::enable_if<(managerIdx < std::tuple_size<ApiManagerTuple>::value), bool>::type = true>
    bool handleMessageImpl(websockets::WebsocketsClient &client, MessageCode code, const char *payload, size_t size)
    {
        bool handled = std::get<managerIdx>(apiManagers_).handleMessage(client, code, payload, size);
        if (handled)
            return true;
        else
            return handleMessageImpl<managerIdx + 1>(client, code, payload, size);
    }
    template <size_t managerIdx, typename std::enable_if<managerIdx == std::tuple_size<ApiManagerTuple>::value, bool>::type = true>
    bool handleMessageImpl(websockets::WebsocketsClient &, MessageCode, const char *, size_t) { return false; }

    // Notifies every manager, in tuple order, about a new connection.
    template <size_t managerIdx = 0, typename std::enable_if<(managerIdx < std::tuple_size<ApiManagerTuple>::value), bool>::type = true>
    void onClientConnectImpl(websockets::WebsocketsClient &client)
    {
        std::get<managerIdx>(apiManagers_).onClientConnect(client);
        onClientConnectImpl<managerIdx + 1>(client);
    }
    template <size_t managerIdx, typename std::enable_if<managerIdx == std::tuple_size<ApiManagerTuple>::value, bool>::type = true>
    void onClientConnectImpl(websockets::WebsocketsClient &) {}

    // -- Members
    int port_;
    int nextFreeClient_; // slot where the search for a free slot starts
    ApiManagerTuple apiManagers_;
    websockets::WebsocketsServer server_;
    websockets::WebsocketsClient clients_[MAX_WEBSOCKET_CONNECTIONS];
};
// Factory helper. It packs the given managers (taken by value) into a
// std::tuple and builds a WebsocketServer over it, so callers never have to
// spell out the tuple type themselves.
template <typename... ApiManagers>
inline WebsocketServer<std::tuple<ApiManagers...>> makeWebsocketServer(int port, ApiManagers... managers)
{
    using ManagerBundle = std::tuple<ApiManagers...>;
    ManagerBundle bundle(managers...);
    return WebsocketServer<ManagerBundle>(port, bundle);
}
// Sends `content` to a single client as one binary frame. The frame is the
// message-code byte followed by the MessagePack encoding of `content`. It is
// built in a stack buffer of bufferSize bytes, header included.
template <size_t bufferSize>
inline void sendToClient(websockets::WebsocketsClient &client, MessageCode msgCode, const JsonDocument &content)
{
    static_assert(bufferSize > sizeof(MessageCode),
                  "bufferSize must leave room for a payload after the message code");
    char buffer[bufferSize];
    buffer[0] = (char)(msgCode);
    size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), bufferSize - sizeof(msgCode));
    // bytesWritten covers only the payload. The header length is added using
    // the same sizeof(msgCode) used as the payload offset above.
    client.sendBinary(buffer, bytesWritten + sizeof(msgCode));
}
// Sends `content` to a single client as one binary frame. The frame is the
// message-code byte followed by the MessagePack encoding of `content`. The
// buffer is sized exactly via measureMsgPack and allocated on the heap. If the
// allocation fails, the message is dropped and a log line is emitted, rather
// than writing through a null pointer.
inline void sendToClient(websockets::WebsocketsClient &client, MessageCode msgCode, const JsonDocument &content)
{
    const size_t expectedSize = measureMsgPack(content);
    char *buffer = (char *)malloc(expectedSize + sizeof(msgCode));
    if (buffer == nullptr)
    {
        LOG_INFO("sendToClient: out of memory, message dropped");
        return;
    }
    buffer[0] = static_cast<char>(msgCode);
    size_t bytesWritten = serializeMsgPack(content, buffer + sizeof(msgCode), expectedSize);
    client.sendBinary(buffer, bytesWritten + sizeof(msgCode));
    free(buffer);
}
// Sends an ERROR frame to `client` whose payload is the object {"msg": msg}.
// A null `msg` is treated as an empty string, because strlen(nullptr) is
// undefined behavior.
inline void sendErrorToClient(websockets::WebsocketsClient &client, const char *msg)
{
    if (msg == nullptr)
        msg = "";
    // Capacity: the message length plus fixed headroom for the object/key.
    DynamicJsonDocument doc(strlen(msg) + 64);
    doc["msg"] = msg;
    sendToClient(client, MessageCode::ERROR, doc);
}