Files
pangolin/server/private/routers/ws/ws.ts
2026-01-19 17:57:55 -08:00

1082 lines
36 KiB
TypeScript

/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Router, Request, Response } from "express";
import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net";
import {
Newt,
newts,
NewtSession,
olms,
Olm,
OlmSession,
RemoteExitNode,
RemoteExitNodeSession,
remoteExitNodes
} from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
import logger from "@server/logger";
import redisManager from "#private/lib/redis";
import { v4 as uuidv4 } from "uuid";
import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remoteExitNode";
import { rateLimitService } from "#private/lib/rateLimit";
import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import {
AuthenticatedWebSocket,
ClientType,
WSMessage,
TokenPayload,
WebSocketRequest,
RedisMessage,
SendMessageOptions
} from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app";
// Merge public and private message handlers
Object.assign(messageHandlers, privateMessageHandlers);
const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection setup
// Helper function to process a single message
const processMessage = async (
ws: AuthenticatedWebSocket,
data: Buffer,
clientId: string,
clientType: ClientType
): Promise<void> => {
try {
const message: WSMessage = JSON.parse(data.toString());
// logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
// );
if (!message.type || typeof message.type !== "string") {
throw new Error("Invalid message format: missing or invalid type");
}
// Check rate limiting with message type awareness
const rateLimitResult = await rateLimitService.checkRateLimit(
clientId,
message.type, // Pass message type for granular limiting
100, // max requests per window
20, // max requests per message type per window
60 * 1000 // window in milliseconds
);
if (rateLimitResult.isLimited) {
const reason =
rateLimitResult.reason === "global"
? "too many messages"
: `too many '${message.type}' messages`;
logger.debug(
`Rate limit exceeded for ${clientType.toUpperCase()} ID: ${clientId} - ${reason}, ignoring message`
);
// Send rate limit error to client
// ws.send(JSON.stringify({
// type: "rate_limit_error",
// data: {
// message: `Rate limit exceeded: ${reason}`,
// messageType: message.type,
// reason: rateLimitResult.reason
// }
// }));
return;
}
const handler = messageHandlers[message.type];
if (!handler) {
throw new Error(`Unsupported message type: ${message.type}`);
}
const response = await handler({
message,
senderWs: ws,
client: ws.client,
clientType: ws.clientType!,
sendToClient,
broadcastToAllExcept,
connectedClients
});
if (response) {
if (response.broadcast) {
await broadcastToAllExcept(
response.message,
response.excludeSender ? clientId : undefined,
response.options
);
} else if (response.targetClientId) {
await sendToClient(
response.targetClientId,
response.message,
response.options
);
} else {
await sendToClient(
clientId,
response.message,
response.options
);
}
}
} catch (error) {
logger.error("Message handling error:", error);
// ws.send(JSON.stringify({
// type: "error",
// data: {
// message: error instanceof Error ? error.message : "Unknown error occurred",
// originalMessage: data.toString()
// }
// }));
}
};
// Helper function to process pending messages
const processPendingMessages = async (
ws: AuthenticatedWebSocket,
clientId: string,
clientType: ClientType
): Promise<void> => {
if (!ws.pendingMessages || ws.pendingMessages.length === 0) {
return;
}
logger.info(
`Processing ${ws.pendingMessages.length} pending messages for ${clientType.toUpperCase()} ID: ${clientId}`
);
const jobs = [];
for (const messageData of ws.pendingMessages) {
jobs.push(processMessage(ws, messageData, clientId, clientType));
}
await Promise.all(jobs);
ws.pendingMessages = []; // Clear pending messages to prevent reprocessing
};
const router: Router = Router();
const wss: WebSocketServer = new WebSocketServer({ noServer: true });
// Generate unique node ID for this instance
const NODE_ID = uuidv4();
const REDIS_CHANNEL = "websocket_messages";
// Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (local to this node, resets on server restart)
const clientConfigVersions: Map<string, number> = new Map();
// Recovery tracking
let isRedisRecoveryInProgress = false;
// Helper to get map key
const getClientMapKey = (clientId: string) => clientId;
// Redis keys (generalized)
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
`ws:node:${nodeId}:${clientId}`;
const getConfigVersionKey = (clientId: string) =>
`ws:configVersion:${clientId}`;
// Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
await redisManager.subscribe(
REDIS_CHANNEL,
async (channel: string, message: string) => {
try {
const redisMessage: RedisMessage = JSON.parse(message);
// Ignore messages from this node
if (redisMessage.fromNodeId === NODE_ID) return;
if (
redisMessage.type === "direct" &&
redisMessage.targetClientId
) {
// Send to specific client on this node
await sendToClientLocal(
redisMessage.targetClientId,
redisMessage.message
);
} else if (redisMessage.type === "broadcast") {
// Broadcast to all clients on this node except excluded
await broadcastToAllExceptLocal(
redisMessage.message,
redisMessage.excludeClientId
);
}
} catch (error) {
logger.error("Error processing Redis message:", error);
}
}
);
};
// Simple self-healing recovery function
// Each node is responsible for restoring its own connection state to Redis
// This approach is more efficient than cross-node coordination because:
// 1. Each node knows its own connections (source of truth)
// 2. No network overhead from broadcasting state between nodes
// 3. No race conditions from simultaneous updates
// 4. Redis becomes eventually consistent as each node restores independently
// 5. Simpler logic with better fault tolerance
const recoverConnectionState = async (): Promise<void> => {
if (isRedisRecoveryInProgress) {
logger.debug("Redis recovery already in progress, skipping");
return;
}
isRedisRecoveryInProgress = true;
logger.info("Starting Redis connection state recovery...");
try {
// Each node simply restores its own local connections to Redis
// This is the source of truth - no need for cross-node coordination
await restoreLocalConnectionsToRedis();
logger.info(
"Redis connection state recovery completed - restored local state"
);
} catch (error) {
logger.error("Error during Redis recovery:", error);
} finally {
isRedisRecoveryInProgress = false;
}
};
const restoreLocalConnectionsToRedis = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
logger.info("Restoring local connections to Redis...");
let restoredCount = 0;
try {
// Restore all current local connections to Redis
for (const [clientId, clients] of connectedClients.entries()) {
const validClients = clients.filter(
(client) => client.readyState === WebSocket.OPEN
);
if (validClients.length > 0) {
// Add this node to the client's connection list
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
// Store individual connection details
for (const client of validClients) {
if (client.connectionId) {
await redisManager.hset(
getNodeConnectionsKey(NODE_ID, clientId),
client.connectionId,
Date.now().toString()
);
}
}
restoredCount++;
}
}
logger.info(`Restored ${restoredCount} client connections to Redis`);
} catch (error) {
logger.error("Failed to restore local connections to Redis:", error);
}
};
// Helper functions for client management
const addClient = async (
clientType: ClientType,
clientId: string,
ws: AuthenticatedWebSocket
): Promise<void> => {
// Generate unique connection ID
const connectionId = uuidv4();
ws.connectionId = connectionId;
// Add to local tracking
const mapKey = getClientMapKey(clientId);
const existingClients = connectedClients.get(mapKey) || [];
existingClients.push(ws);
connectedClients.set(mapKey, existingClients);
// Get or initialize config version
let configVersion = 0;
// Check Redis first if enabled
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10);
// Sync to local cache
clientConfigVersions.set(clientId, configVersion);
} else if (!clientConfigVersions.has(clientId)) {
// No version in Redis or local cache, initialize to 0
await redisManager.set(getConfigVersionKey(clientId), "0");
clientConfigVersions.set(clientId, 0);
} else {
// Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
}
} catch (error) {
logger.error("Failed to get/set config version in Redis:", error);
// Fall back to local cache
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
configVersion = clientConfigVersions.get(clientId) || 0;
}
} else {
// Redis not enabled, use local cache only
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
configVersion = clientConfigVersions.get(clientId) || 0;
}
// Set config version on websocket
ws.configVersion = configVersion;
// Add to Redis tracking if enabled
if (redisManager.isRedisEnabled()) {
try {
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
await redisManager.hset(
getNodeConnectionsKey(NODE_ID, clientId),
connectionId,
Date.now().toString()
);
} catch (error) {
logger.error(
"Failed to add client to Redis tracking (connection still functional locally):",
error
);
}
}
logger.info(
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}, Config version: ${configVersion}`
);
};
const removeClient = async (
clientType: ClientType,
clientId: string,
ws: AuthenticatedWebSocket
): Promise<void> => {
const mapKey = getClientMapKey(clientId);
const existingClients = connectedClients.get(mapKey) || [];
const updatedClients = existingClients.filter((client) => client !== ws);
if (updatedClients.length === 0) {
connectedClients.delete(mapKey);
if (redisManager.isRedisEnabled()) {
try {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(
getNodeConnectionsKey(NODE_ID, clientId)
);
} catch (error) {
logger.error(
"Failed to remove client from Redis tracking (cleanup will occur on recovery):",
error
);
}
}
logger.info(
`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`
);
} else {
connectedClients.set(mapKey, updatedClients);
if (redisManager.isRedisEnabled() && ws.connectionId) {
try {
await redisManager.hdel(
getNodeConnectionsKey(NODE_ID, clientId),
ws.connectionId
);
} catch (error) {
logger.error(
"Failed to remove specific connection from Redis tracking:",
error
);
}
}
logger.info(
`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`
);
}
};
// Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
// Try Redis first if available
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) {
const version = parseInt(redisVersion, 10);
// Sync local cache with Redis
clientConfigVersions.set(clientId, version);
return version;
}
} catch (error) {
logger.error("Failed to get config version from Redis:", error);
}
}
// Fall back to local cache
return clientConfigVersions.get(clientId);
};
// Helper to increment and get the new config version for a client
const incrementClientConfigVersion = async (
clientId: string
): Promise<number> => {
let newVersion: number;
if (redisManager.isRedisEnabled()) {
try {
// Use Redis INCR for atomic increment across nodes
newVersion = await redisManager.incr(getConfigVersionKey(clientId));
// Sync local cache
clientConfigVersions.set(clientId, newVersion);
return newVersion;
} catch (error) {
logger.error("Failed to increment config version in Redis:", error);
// Fall through to local increment
}
}
// Local increment
const currentVersion = clientConfigVersions.get(clientId) || 0;
newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
return newVersion;
};
// Local message sending (within this node)
const sendToClientLocal = async (
clientId: string,
message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
return false;
}
// Handle config version
const configVersion = await getClientConfigVersion(clientId);
// Add config version to message
const messageWithVersion = {
...message,
configVersion
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
return true;
};
const broadcastToAllExceptLocal = async (
message: WSMessage,
excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => {
for (const [mapKey, clients] of connectedClients.entries()) {
const [type, id] = mapKey.split(":");
const clientId = mapKey; // mapKey is the clientId
if (!(excludeClientId && clientId === excludeClientId)) {
// Handle config version per client
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
// Add config version to message
const messageWithVersion = {
...message,
configVersion
};
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
}
}
};
// Cross-node message sending (via Redis)
const sendToClient = async (
clientId: string,
message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => {
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId} (new configVersion: ${configVersion})`
);
// Try to send locally first
const localSent = await sendToClientLocal(clientId, message, options);
// Only send via Redis if the client is not connected locally and Redis is enabled
if (!localSent && redisManager.isRedisEnabled()) {
try {
const redisMessage: RedisMessage = {
type: "direct",
targetClientId: clientId,
message: {
...message,
configVersion
},
fromNodeId: NODE_ID
};
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error(
"Failed to send message via Redis, message may be lost:",
error
);
// Continue execution - local delivery already attempted
}
} else if (!localSent && !redisManager.isRedisEnabled()) {
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
logger.debug(
`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`
);
}
return localSent;
};
const broadcastToAllExcept = async (
message: WSMessage,
excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => {
// Broadcast locally
await broadcastToAllExceptLocal(message, excludeClientId, options);
// If Redis is enabled, also broadcast via Redis pub/sub to other nodes
// Note: For broadcasts, we include the options so remote nodes can handle versioning
if (redisManager.isRedisEnabled()) {
try {
const redisMessage: RedisMessage = {
type: "broadcast",
excludeClientId,
message,
fromNodeId: NODE_ID,
options
};
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error(
"Failed to broadcast message via Redis, remote nodes may not receive it:",
error
);
// Continue execution - local broadcast already completed
}
} else {
logger.debug(
"Redis unavailable - broadcast limited to local node only"
);
}
};
// Check if a client has active connections across all nodes
const hasActiveConnections = async (clientId: string): Promise<boolean> => {
if (!redisManager.isRedisEnabled()) {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
return !!(clients && clients.length > 0);
}
const activeNodes = await redisManager.smembers(
getConnectionsKey(clientId)
);
return activeNodes.length > 0;
};
// Get all active nodes for a client
const getActiveNodes = async (
clientType: ClientType,
clientId: string
): Promise<string[]> => {
if (!redisManager.isRedisEnabled()) {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
return clients && clients.length > 0 ? [NODE_ID] : [];
}
return await redisManager.smembers(getConnectionsKey(clientId));
};
// Token verification middleware
const verifyToken = async (
token: string,
clientType: ClientType,
userToken: string
): Promise<TokenPayload | null> => {
try {
if (clientType === "newt") {
const { session, newt } = await validateNewtSessionToken(token);
if (!session || !newt) {
return null;
}
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { client: existingNewt[0], session, clientType };
} else if (clientType === "olm") {
const { session, olm } = await validateOlmSessionToken(token);
if (!session || !olm) {
return null;
}
const existingOlm = await db
.select()
.from(olms)
.where(eq(olms.olmId, olm.olmId));
if (!existingOlm || !existingOlm[0]) {
return null;
}
if (olm.userId) {
// this is a user device and we need to check the user token
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
return null;
}
if (user.userId !== olm.userId) {
return null;
}
}
return { client: existingOlm[0], session, clientType };
} else if (clientType === "remoteExitNode") {
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (!session || !remoteExitNode) {
return null;
}
const existingRemoteExitNode = await db
.select()
.from(remoteExitNodes)
.where(
eq(
remoteExitNodes.remoteExitNodeId,
remoteExitNode.remoteExitNodeId
)
);
if (!existingRemoteExitNode || !existingRemoteExitNode[0]) {
return null;
}
return { client: existingRemoteExitNode[0], session, clientType };
}
return null;
} catch (error) {
logger.error("Token verification failed:", error);
return null;
}
};
const setupConnection = async (
ws: AuthenticatedWebSocket,
client: Newt | Olm | RemoteExitNode,
clientType: ClientType
): Promise<void> => {
logger.info("Establishing websocket connection");
if (!client) {
logger.error("Connection attempt without client");
return ws.terminate();
}
ws.client = client;
ws.clientType = clientType;
ws.isFullyConnected = false;
ws.pendingMessages = [];
// Get client ID first
let clientId: string;
if (clientType === "newt") {
clientId = (client as Newt).newtId;
} else if (clientType === "olm") {
clientId = (client as Olm).olmId;
} else if (clientType === "remoteExitNode") {
clientId = (client as RemoteExitNode).remoteExitNodeId;
} else {
throw new Error(`Unknown client type: ${clientType}`);
}
// Set up message handler FIRST to prevent race condition
ws.on("message", async (data) => {
if (!ws.isFullyConnected) {
// Queue message for later processing with limits
ws.pendingMessages = ws.pendingMessages || [];
if (ws.pendingMessages.length >= MAX_PENDING_MESSAGES) {
logger.warn(
`Too many pending messages for ${clientType.toUpperCase()} ID: ${clientId}, dropping oldest message`
);
ws.pendingMessages.shift(); // Remove oldest message
}
logger.debug(
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
);
ws.pendingMessages.push(data as Buffer);
return;
}
await processMessage(ws, data as Buffer, clientId, clientType);
});
// Set up other event handlers before async operations
ws.on("close", async () => {
// Clear any pending messages to prevent memory leaks
if (ws.pendingMessages) {
ws.pendingMessages = [];
}
await removeClient(clientType, clientId, ws);
logger.info(
`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`
);
});
ws.on("error", (error: Error) => {
logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
error
);
});
try {
await addClient(clientType, clientId, ws);
// Mark connection as fully established
ws.isFullyConnected = true;
logger.info(
`WebSocket connection fully established and ready - ${clientType.toUpperCase()} ID: ${clientId}`
);
// Process any messages that were queued while connection was being established
await processPendingMessages(ws, clientId, clientType);
} catch (error) {
logger.error(
`Failed to fully establish connection for ${clientType.toUpperCase()} ID: ${clientId}:`,
error
);
// ws.send(JSON.stringify({
// type: "connection_error",
// data: {
// message: "Failed to establish connection"
// }
// }));
ws.terminate();
return;
}
};
// Router endpoint
router.get("/ws", (req: Request, res: Response) => {
res.status(200).send("WebSocket endpoint");
});
// WebSocket upgrade handler
const handleWSUpgrade = (server: HttpServer): void => {
server.on(
"upgrade",
async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
try {
const url = new URL(
request.url || "",
`http://${request.headers.host}`
);
const token =
url.searchParams.get("token") ||
request.headers["sec-websocket-protocol"] ||
"";
const userToken = url.searchParams.get("userToken") || "";
let clientType = url.searchParams.get(
"clientType"
) as ClientType;
if (!clientType) {
clientType = "newt";
}
if (
!token ||
!clientType ||
!["newt", "olm", "remoteExitNode"].includes(clientType)
) {
logger.warn(
"Unauthorized connection attempt: invalid token or client type..."
);
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
socket.destroy();
return;
}
const tokenPayload = await verifyToken(
token,
clientType,
userToken
);
if (!tokenPayload) {
logger.debug(
"Unauthorized connection attempt: invalid token..."
);
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
socket.destroy();
return;
}
wss.handleUpgrade(
request,
socket,
head,
(ws: AuthenticatedWebSocket) => {
setupConnection(
ws,
tokenPayload.client,
tokenPayload.clientType
);
}
);
} catch (error) {
logger.error("WebSocket upgrade error:", error);
socket.write("HTTP/1.1 500 Internal Server Error\r\n\r\n");
socket.destroy();
}
}
);
};
// Add periodic connection state sync to handle Redis disconnections/reconnections
const startPeriodicStateSync = (): void => {
// Lightweight sync every 5 minutes - just restore our own state
setInterval(
async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error(
"Error during periodic connection state sync:",
error
);
}
}
},
5 * 60 * 1000
); // 5 minutes
// Cleanup stale connections every 15 minutes
setInterval(
async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error(
"Error during periodic connection cleanup:",
error
);
}
}
},
15 * 60 * 1000
); // 15 minutes
};
const cleanupStaleConnections = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
try {
const nodeKeys =
(await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`)) ||
[];
for (const nodeKey of nodeKeys) {
const connections = await redisManager.hgetall(nodeKey);
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, "");
const localClients = connectedClients.get(clientId) || [];
const localConnectionIds = localClients
.filter((client) => client.readyState === WebSocket.OPEN)
.map((client) => client.connectionId)
.filter(Boolean);
// Remove Redis entries for connections that no longer exist locally
for (const [connectionId, timestamp] of Object.entries(
connections
)) {
if (!localConnectionIds.includes(connectionId)) {
await redisManager.hdel(nodeKey, connectionId);
logger.debug(
`Cleaned up stale connection: ${connectionId} for client: ${clientId}`
);
}
}
// If no connections remain for this client, remove from Redis entirely
const remainingConnections = await redisManager.hgetall(nodeKey);
if (Object.keys(remainingConnections).length === 0) {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(nodeKey);
logger.debug(
`Cleaned up empty connection tracking for client: ${clientId}`
);
}
}
} catch (error) {
logger.error("Error cleaning up stale connections:", error);
}
};
// Initialize Redis subscription when the module is loaded
if (redisManager.isRedisEnabled()) {
initializeRedisSubscription().catch((error) => {
logger.error("Failed to initialize Redis subscription:", error);
});
// Register recovery callback with Redis manager
// When Redis reconnects, each node simply restores its own local state
redisManager.onReconnection(async () => {
logger.info("Redis reconnected, starting WebSocket state recovery...");
await recoverConnectionState();
});
// Start periodic state synchronization
startPeriodicStateSync();
logger.info(
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
);
} else {
logger.debug("WebSocket handler initialized in local mode");
}
// Disconnect a specific client and force them to reconnect
const disconnectClient = async (clientId: string): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
logger.debug(`No connections found for client ID: ${clientId}`);
return false;
}
logger.info(
`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`
);
// Close all connections for this client
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.close(1000, "Disconnected by server");
}
});
return true;
};
// Cleanup function for graceful shutdown
const cleanup = async (): Promise<void> => {
try {
// Close all WebSocket connections
connectedClients.forEach((clients) => {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.terminate();
}
});
});
// Clean up Redis tracking for this node
if (redisManager.isRedisEnabled()) {
const keys =
(await redisManager
.getClient()
?.keys(`ws:node:${NODE_ID}:*`)) || [];
if (keys.length > 0) {
await Promise.all(keys.map((key) => redisManager.del(key)));
}
}
logger.info("WebSocket cleanup completed");
} catch (error) {
logger.error("Error during WebSocket cleanup:", error);
}
};
export {
router,
handleWSUpgrade,
sendToClient,
broadcastToAllExcept,
connectedClients,
hasActiveConnections,
getActiveNodes,
disconnectClient,
NODE_ID,
cleanup,
getClientConfigVersion
};