From 322f3bfb1d7d9d96c71b72b1a374e618a0518073 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 19 Dec 2025 16:44:57 -0500 Subject: [PATCH] Add version and send it down --- server/private/lib/redis.ts | 14 ++++ server/private/routers/ws/ws.ts | 126 ++++++++++++++++++++++++++++---- server/routers/ws/types.ts | 13 +++- server/routers/ws/ws.ts | 70 +++++++++++++++--- 4 files changed, 194 insertions(+), 29 deletions(-) diff --git a/server/private/lib/redis.ts b/server/private/lib/redis.ts index 6b7826ea..49cd4c61 100644 --- a/server/private/lib/redis.ts +++ b/server/private/lib/redis.ts @@ -573,6 +573,20 @@ class RedisManager { } } + public async incr(key: string): Promise { + if (!this.isRedisEnabled() || !this.writeClient) return 0; + + try { + return await this.executeWithRetry( + () => this.writeClient!.incr(key), + "Redis INCR" + ); + } catch (error) { + logger.error("Redis INCR error:", error); + return 0; + } + } + public async sadd(key: string, member: string): Promise { if (!this.isRedisEnabled() || !this.writeClient) return false; diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts index 784c3d51..4b9a3295 100644 --- a/server/private/routers/ws/ws.ts +++ b/server/private/routers/ws/ws.ts @@ -43,7 +43,8 @@ import { WSMessage, TokenPayload, WebSocketRequest, - RedisMessage + RedisMessage, + SendMessageOptions } from "@server/routers/ws"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -172,6 +173,9 @@ const REDIS_CHANNEL = "websocket_messages"; // Client tracking map (local to this node) const connectedClients: Map = new Map(); +// Config version tracking map (local to this node, resets on server restart) +const clientConfigVersions: Map = new Map(); + // Recovery tracking let isRedisRecoveryInProgress = false; @@ -182,6 +186,7 @@ const getClientMapKey = (clientId: string) => clientId; const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`; const getNodeConnectionsKey = (nodeId: string, clientId: string) => `ws:node:${nodeId}:${clientId}`; +const getConfigVersionKey = (clientId: string) => `ws:configVersion:${clientId}`; // Initialize Redis subscription for cross-node messaging const initializeRedisSubscription = async (): Promise => { @@ -377,17 +382,76 @@ const removeClient = async ( } }; +// Helper to get the current config version for a client +const getClientConfigVersion = async (clientId: string): Promise => { + // Try Redis first if available + if (redisManager.isRedisEnabled()) { + try { + const redisVersion = await redisManager.get(getConfigVersionKey(clientId)); + if (redisVersion !== null) { + const version = parseInt(redisVersion, 10); + // Sync local cache with Redis + clientConfigVersions.set(clientId, version); + return version; + } + } catch (error) { + logger.error("Failed to get config version from Redis:", error); + } + } + + // Fall back to local cache + return clientConfigVersions.get(clientId) || 0; +}; + +// Helper to increment and get the new config version for a client +const incrementClientConfigVersion = async (clientId: string): Promise => { + let newVersion: number; + + if (redisManager.isRedisEnabled()) { + try { + // Use Redis INCR for atomic increment across nodes + newVersion = await redisManager.incr(getConfigVersionKey(clientId)); + // Sync local cache + clientConfigVersions.set(clientId, newVersion); + return newVersion; + } catch (error) { + logger.error("Failed to increment config version in Redis:", error); + // Fall through to local increment + } + } + + // Local increment + const currentVersion = clientConfigVersions.get(clientId) || 0; + newVersion = currentVersion + 1; + clientConfigVersions.set(clientId, newVersion); + return newVersion; +}; + // Local message sending (within this node) const sendToClientLocal = async ( clientId: string, - message: WSMessage + message: WSMessage, + options: SendMessageOptions = {} ): Promise => { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); if (!clients || clients.length === 0) { return false; } - const messageString = JSON.stringify(message); + + // Handle config version + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = await incrementClientConfigVersion(clientId); + } + + // Add config version to message + const messageWithVersion = { + ...message, + configVersion + }; + + const messageString = JSON.stringify(messageWithVersion); clients.forEach((client) => { if (client.readyState === WebSocket.OPEN) { client.send(messageString); @@ -395,43 +459,69 @@ const sendToClientLocal = async ( }); logger.debug( - `sendToClient: Message type ${message.type} sent to clientId ${clientId}` + `sendToClient: Message type ${message.type} sent to clientId ${clientId} (configVersion: ${configVersion})` ); + return true; }; const broadcastToAllExceptLocal = async ( message: WSMessage, - excludeClientId?: string + excludeClientId?: string, + options: SendMessageOptions = {} ): Promise => { - connectedClients.forEach((clients, mapKey) => { + for (const [mapKey, clients] of connectedClients.entries()) { const [type, id] = mapKey.split(":"); - if (!(excludeClientId && id === excludeClientId)) { + const clientId = mapKey; // mapKey is the clientId + if (!(excludeClientId && clientId === excludeClientId)) { + // Handle config version per client + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = await incrementClientConfigVersion(clientId); + } + + // Add config version to message + const messageWithVersion = { + ...message, + configVersion + }; + clients.forEach((client) => { if (client.readyState === WebSocket.OPEN) { - client.send(JSON.stringify(message)); + client.send(JSON.stringify(messageWithVersion)); } }); } - }); + } }; // Cross-node message sending (via Redis) const sendToClient = async ( clientId: string, - message: WSMessage + message: WSMessage, + options: SendMessageOptions = {} ): Promise => { // Try to send locally first - const localSent = await sendToClientLocal(clientId, message); + const localSent = await sendToClientLocal(clientId, message, options); // Only send via Redis if the client is not connected locally and Redis is enabled if (!localSent && redisManager.isRedisEnabled()) { try { + // If we need to increment config version, do it before sending via Redis + // so remote nodes send the correct version + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = await incrementClientConfigVersion(clientId); + } + const redisMessage: RedisMessage = { type: "direct", targetClientId: clientId, - message, + message: { + ...message, + configVersion + }, fromNodeId: NODE_ID }; @@ -458,19 +548,22 @@ const sendToClient = async ( const broadcastToAllExcept = async ( message: WSMessage, - excludeClientId?: string + excludeClientId?: string, + options: SendMessageOptions = {} ): Promise => { // Broadcast locally - await broadcastToAllExceptLocal(message, excludeClientId); + await broadcastToAllExceptLocal(message, excludeClientId, options); // If Redis is enabled, also broadcast via Redis pub/sub to other nodes + // Note: For broadcasts, we include the options so remote nodes can handle versioning if (redisManager.isRedisEnabled()) { try { const redisMessage: RedisMessage = { type: "broadcast", excludeClientId, message, - fromNodeId: NODE_ID + fromNodeId: NODE_ID, + options }; await redisManager.publish( @@ -936,5 +1029,6 @@ export { getActiveNodes, disconnectClient, NODE_ID, - cleanup + cleanup, + getClientConfigVersion }; diff --git a/server/routers/ws/types.ts b/server/routers/ws/types.ts index b4ec690b..5cca3c09 100644 --- a/server/routers/ws/types.ts +++ b/server/routers/ws/types.ts @@ -25,6 +25,7 @@ export interface AuthenticatedWebSocket extends WebSocket { connectionId?: string; isFullyConnected?: boolean; pendingMessages?: Buffer[]; + configVersion?: number; } export interface TokenPayload { @@ -36,6 +37,7 @@ export interface TokenPayload { export interface WSMessage { type: string; data: any; + configVersion?: number; } export interface HandlerResponse { @@ -50,10 +52,11 @@ export interface HandlerContext { senderWs: WebSocket; client: Newt | Olm | RemoteExitNode | undefined; clientType: ClientType; - sendToClient: (clientId: string, message: WSMessage) => Promise; + sendToClient: (clientId: string, message: WSMessage, options?: SendMessageOptions) => Promise; broadcastToAllExcept: ( message: WSMessage, - excludeClientId?: string + excludeClientId?: string, + options?: SendMessageOptions ) => Promise; connectedClients: Map; } @@ -62,6 +65,11 @@ export type MessageHandler = ( context: HandlerContext ) => Promise; +// Options for sending messages with config version tracking +export interface SendMessageOptions { + incrementConfigVersion?: boolean; +} + // Redis message type for cross-node communication export interface RedisMessage { type: "direct" | "broadcast"; @@ -69,4 +77,5 @@ export interface RedisMessage { excludeClientId?: string; message: WSMessage; fromNodeId: string; + options?: SendMessageOptions; } diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index 0544af9d..f707848c 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -15,7 +15,8 @@ import { TokenPayload, WebSocketRequest, WSMessage, - AuthenticatedWebSocket + AuthenticatedWebSocket, + SendMessageOptions } from "./types"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -34,6 +35,8 @@ const NODE_ID = uuidv4(); // Client tracking map (local to this node) const connectedClients: Map = new Map(); +// Config version tracking map (clientId -> version) +const clientConfigVersions: Map = new Map(); // Helper to get map key const getClientMapKey = (clientId: string) => clientId; @@ -84,14 +87,34 @@ const removeClient = async ( // Local message sending (within this node) const sendToClientLocal = async ( clientId: string, - message: WSMessage + message: WSMessage, + options: SendMessageOptions = {} ): Promise => { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); if (!clients || clients.length === 0) { return false; } - const messageString = JSON.stringify(message); + + // Increment config version if requested + if (options.incrementConfigVersion) { + const currentVersion = clientConfigVersions.get(clientId) || 0; + const newVersion = currentVersion + 1; + clientConfigVersions.set(clientId, newVersion); + // Update version on all client connections + clients.forEach((client) => { + client.configVersion = newVersion; + }); + } + + // Include config version in message + const configVersion = clientConfigVersions.get(clientId) || 0; + const messageWithVersion = { + ...message, + configVersion + }; + + const messageString = JSON.stringify(messageWithVersion); clients.forEach((client) => { if (client.readyState === WebSocket.OPEN) { client.send(messageString); @@ -102,14 +125,31 @@ const sendToClientLocal = async ( const broadcastToAllExceptLocal = async ( message: WSMessage, - excludeClientId?: string + excludeClientId?: string, + options: SendMessageOptions = {} ): Promise => { connectedClients.forEach((clients, mapKey) => { const [type, id] = mapKey.split(":"); - if (!(excludeClientId && id === excludeClientId)) { + const clientId = mapKey; // mapKey is the clientId + if (!(excludeClientId && clientId === excludeClientId)) { + // Handle config version per client + if (options.incrementConfigVersion) { + const currentVersion = clientConfigVersions.get(clientId) || 0; + const newVersion = currentVersion + 1; + clientConfigVersions.set(clientId, newVersion); + clients.forEach((client) => { + client.configVersion = newVersion; + }); + } + // Include config version in message for this client + const configVersion = clientConfigVersions.get(clientId) || 0; + const messageWithVersion = { + ...message, + configVersion + }; clients.forEach((client) => { if (client.readyState === WebSocket.OPEN) { - client.send(JSON.stringify(message)); + client.send(JSON.stringify(messageWithVersion)); } }); } @@ -119,10 +159,11 @@ const broadcastToAllExceptLocal = async ( // Cross-node message sending const sendToClient = async ( clientId: string, - message: WSMessage + message: WSMessage, + options: SendMessageOptions = {} ): Promise => { // Try to send locally first - const localSent = await sendToClientLocal(clientId, message); + const localSent = await sendToClientLocal(clientId, message, options); logger.debug( `sendToClient: Message type ${message.type} sent to clientId ${clientId}` @@ -133,10 +174,11 @@ const sendToClient = async ( const broadcastToAllExcept = async ( message: WSMessage, - excludeClientId?: string + excludeClientId?: string, + options: SendMessageOptions = {} ): Promise => { // Broadcast locally - await broadcastToAllExceptLocal(message, excludeClientId); + await broadcastToAllExceptLocal(message, excludeClientId, options); }; // Check if a client has active connections across all nodes @@ -146,6 +188,11 @@ const hasActiveConnections = async (clientId: string): Promise => { return !!(clients && clients.length > 0); }; +// Get the current config version for a client +const getClientConfigVersion = (clientId: string): number => { + return clientConfigVersions.get(clientId) || 0; +}; + // Get all active nodes for a client const getActiveNodes = async ( clientType: ClientType, @@ -434,5 +481,6 @@ export { getActiveNodes, disconnectClient, NODE_ID, - cleanup + cleanup, + getClientConfigVersion };