diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts index 1c1f54f4..e1fe3f54 100644 --- a/server/private/routers/ws/ws.ts +++ b/server/private/routers/ws/ws.ts @@ -319,6 +319,45 @@ const addClient = async ( existingClients.push(ws); connectedClients.set(mapKey, existingClients); + // Get or initialize config version + let configVersion = 0; + + // Check Redis first if enabled + if (redisManager.isRedisEnabled()) { + try { + const redisVersion = await redisManager.get(getConfigVersionKey(clientId)); + if (redisVersion !== null) { + configVersion = parseInt(redisVersion, 10); + // Sync to local cache + clientConfigVersions.set(clientId, configVersion); + } else if (!clientConfigVersions.has(clientId)) { + // No version in Redis or local cache, initialize to 0 + await redisManager.set(getConfigVersionKey(clientId), "0"); + clientConfigVersions.set(clientId, 0); + } else { + // Use local cache version and sync to Redis + configVersion = clientConfigVersions.get(clientId) || 0; + await redisManager.set(getConfigVersionKey(clientId), configVersion.toString()); + } + } catch (error) { + logger.error("Failed to get/set config version in Redis:", error); + // Fall back to local cache + if (!clientConfigVersions.has(clientId)) { + clientConfigVersions.set(clientId, 0); + } + configVersion = clientConfigVersions.get(clientId) || 0; + } + } else { + // Redis not enabled, use local cache only + if (!clientConfigVersions.has(clientId)) { + clientConfigVersions.set(clientId, 0); + } + configVersion = clientConfigVersions.get(clientId) || 0; + } + + // Set config version on websocket + ws.configVersion = configVersion; + // Add to Redis tracking if enabled if (redisManager.isRedisEnabled()) { try { @@ -337,7 +376,7 @@ const addClient = async ( } logger.info( - `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}` + `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}, Config version: ${configVersion}` ); }; @@ -393,7 +432,7 @@ const removeClient = async ( }; // Helper to get the current config version for a client -const getClientConfigVersion = async (clientId: string): Promise => { +const getClientConfigVersion = async (clientId: string): Promise => { // Try Redis first if available if (redisManager.isRedisEnabled()) { try { @@ -412,7 +451,7 @@ const getClientConfigVersion = async (clientId: string): Promise => { } // Fall back to local cache - return clientConfigVersions.get(clientId) || 0; + return clientConfigVersions.get(clientId); }; // Helper to increment and get the new config version for a client @@ -455,9 +494,6 @@ const sendToClientLocal = async ( // Handle config version let configVersion = await getClientConfigVersion(clientId); - if (options.incrementConfigVersion) { - configVersion = await incrementClientConfigVersion(clientId); - } // Add config version to message const messageWithVersion = { @@ -472,10 +508,6 @@ const sendToClientLocal = async ( } }); - logger.debug( - `sendToClient: Message type ${message.type} sent to clientId ${clientId} (configVersion: ${configVersion})` - ); - return true; }; @@ -515,19 +547,21 @@ const sendToClient = async ( message: WSMessage, options: SendMessageOptions = {} ): Promise => { + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = await incrementClientConfigVersion(clientId); + } + + logger.debug( + `sendToClient: Message type ${message.type} sent to clientId ${clientId} (new configVersion: ${configVersion})` + ); + // Try to send locally first const localSent = await sendToClientLocal(clientId, message, options); // Only send via Redis if the client is not connected locally and Redis is enabled if (!localSent && redisManager.isRedisEnabled()) { try { - // If we need to increment config version, do it before sending via Redis - // so remote nodes send the correct version - let configVersion = await getClientConfigVersion(clientId); - if (options.incrementConfigVersion) { - configVersion = await incrementClientConfigVersion(clientId); - } - const redisMessage: RedisMessage = { type: "direct", targetClientId: clientId, diff --git a/server/routers/newt/handleNewtPingMessage.ts b/server/routers/newt/handleNewtPingMessage.ts index 8840c47e..a4af6872 100644 --- a/server/routers/newt/handleNewtPingMessage.ts +++ b/server/routers/newt/handleNewtPingMessage.ts @@ -1,6 +1,6 @@ import { db, sites } from "@server/db"; -import { disconnectClient } from "#dynamic/routers/ws"; -import { getClientConfigVersion, MessageHandler } from "@server/routers/ws"; +import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws"; +import { MessageHandler } from "@server/routers/ws"; import { clients, Newt } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index 543a9f7e..635342d8 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -1,6 +1,6 @@ import { db } from "@server/db"; -import { disconnectClient } from "#dynamic/routers/ws"; -import { getClientConfigVersion, MessageHandler } from "@server/routers/ws"; +import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws"; +import { MessageHandler } from "@server/routers/ws"; import { clients, olms, Olm } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; @@ -171,11 +171,17 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { } // get the version + logger.debug(`++++++++++++++++++++++++++++handleOlmPingMessage: About to get config version for olmId: ${olm.olmId}`); const configVersion = await getClientConfigVersion(olm.olmId); + logger.debug(`++++++++++++++++++++++++++++handleOlmPingMessage: Got config version: ${configVersion} (type: ${typeof configVersion})`); - if (message.configVersion && configVersion != message.configVersion) { - logger.warn( - `Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})` + if (configVersion == null || configVersion === undefined) { + logger.debug(`++++++++++++++++++++++++++++handleOlmPingMessage: could not get config version from server for olmId: ${olm.olmId}`) + } + + if (message.configVersion != null && configVersion != null && configVersion != message.configVersion) { + logger.debug( + `++++++++++++++++++++++++++++handleOlmPingMessage: Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})` ); await sendOlmSyncMessage(olm, client); } diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index d18e1760..4ffeff73 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -32,20 +32,24 @@ export async function addPeer( olmId = olm.olmId; } - await sendToClient(olmId, { - type: "olm/wg/peer/add", - data: { - siteId: peer.siteId, - name: peer.name, - publicKey: peer.publicKey, - endpoint: peer.endpoint, - relayEndpoint: peer.relayEndpoint, - serverIP: peer.serverIP, - serverPort: peer.serverPort, - remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access - aliases: peer.aliases - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + olmId, + { + type: "olm/wg/peer/add", + data: { + siteId: peer.siteId, + name: peer.name, + publicKey: peer.publicKey, + endpoint: peer.endpoint, + relayEndpoint: peer.relayEndpoint, + serverIP: peer.serverIP, + serverPort: peer.serverPort, + remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access + aliases: peer.aliases + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -70,13 +74,17 @@ export async function deletePeer( olmId = olm.olmId; } - await sendToClient(olmId, { - type: "olm/wg/peer/remove", - data: { - publicKey, - siteId: siteId - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + olmId, + { + type: "olm/wg/peer/remove", + data: { + publicKey, + siteId: siteId + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -109,19 +117,23 @@ export async function updatePeer( olmId = olm.olmId; } - await sendToClient(olmId, { - type: "olm/wg/peer/update", - data: { - siteId: peer.siteId, - publicKey: peer.publicKey, - endpoint: peer.endpoint, - relayEndpoint: peer.relayEndpoint, - serverIP: peer.serverIP, - serverPort: peer.serverPort, - remoteSubnets: peer.remoteSubnets, - aliases: peer.aliases - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + olmId, + { + type: "olm/wg/peer/update", + data: { + siteId: peer.siteId, + publicKey: peer.publicKey, + endpoint: peer.endpoint, + relayEndpoint: peer.relayEndpoint, + serverIP: peer.serverIP, + serverPort: peer.serverPort, + remoteSubnets: peer.remoteSubnets, + aliases: peer.aliases + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -151,19 +163,21 @@ export async function initPeerAddHandshake( olmId = olm.olmId; } - await sendToClient(olmId, { - type: "olm/wg/peer/holepunch/site/add", - data: { - siteId: peer.siteId, - exitNode: { - publicKey: peer.exitNode.publicKey, - relayPort: config.getRawConfig().gerbil.clients_start_port, - endpoint: peer.exitNode.endpoint + await sendToClient( + olmId, + { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: peer.siteId, + exitNode: { + publicKey: peer.exitNode.publicKey, + relayPort: config.getRawConfig().gerbil.clients_start_port, + endpoint: peer.exitNode.endpoint + } } - } - // }, { incrementConfigVersion: true }).catch((error) => { - // TODO: DOES THIS NEED TO BE A INCREMENT VERSION? I AM NOT SURE BECAUSE IT WOULD BE TRIGGERED BY THE SYNC? - }).catch((error) => { + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); diff --git a/server/routers/olm/sync.ts b/server/routers/olm/sync.ts index a7db4d04..293f11b7 100644 --- a/server/routers/olm/sync.ts +++ b/server/routers/olm/sync.ts @@ -1,7 +1,9 @@ -import { Client, Olm } from "@server/db"; +import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db"; import { buildSiteConfigurationForOlmClient } from "./buildConfiguration"; import { sendToClient } from "#dynamic/routers/ws"; import logger from "@server/logger"; +import { eq, inArray } from "drizzle-orm"; +import config from "@server/lib/config"; export async function sendOlmSyncMessage(olm: Olm, client: Client) { // NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT @@ -11,12 +13,68 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) { false ); + // Get all exit nodes from sites where the client has peers + const clientSites = await db + .select() + .from(clientSitesAssociationsCache) + .innerJoin( + sites, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + + // Extract unique exit node IDs + const exitNodeIds = Array.from( + new Set( + clientSites + .map(({ sites: site }) => site.exitNodeId) + .filter((id): id is number => id !== null) + ) + ); + + let exitNodesData: { + publicKey: string; + relayPort: number; + endpoint: string; + siteIds: number[]; + }[] = []; + + if (exitNodeIds.length > 0) { + const allExitNodes = await db + .select() + .from(exitNodes) + .where(inArray(exitNodes.exitNodeId, exitNodeIds)); + + // Map exitNodeId to siteIds + const exitNodeIdToSiteIds: Record = {}; + for (const { sites: site } of clientSites) { + if (site.exitNodeId !== null) { + if (!exitNodeIdToSiteIds[site.exitNodeId]) { + exitNodeIdToSiteIds[site.exitNodeId] = []; + } + exitNodeIdToSiteIds[site.exitNodeId].push(site.siteId); + } + } + + exitNodesData = allExitNodes.map((exitNode) => { + return { + publicKey: exitNode.publicKey, + relayPort: config.getRawConfig().gerbil.clients_start_port, + endpoint: exitNode.endpoint, + siteIds: exitNodeIdToSiteIds[exitNode.exitNodeId] ?? [] + }; + }); + } + + logger.debug("++++++++++++++++++++++++++++sendOlmSyncMessage: sending sync message") + await sendToClient(olm.olmId, { type: "olm/sync", data: { - sites: siteConfigurations + sites: siteConfigurations, + exitNodes: exitNodesData } }).catch((error) => { logger.warn(`Error sending olm sync message:`, error); }); -} +} \ No newline at end of file diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index 7f396ea7..32432d99 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -56,6 +56,13 @@ const addClient = async ( existingClients.push(ws); connectedClients.set(mapKey, existingClients); + // Initialize config version to 0 if not already set, otherwise use existing + if (!clientConfigVersions.has(clientId)) { + clientConfigVersions.set(clientId, 0); + } + // Set the current config version on the websocket + ws.configVersion = clientConfigVersions.get(clientId) || 0; + logger.info( `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}` ); @@ -96,19 +103,13 @@ const sendToClientLocal = async ( return false; } - // Increment config version if requested - if (options.incrementConfigVersion) { - const currentVersion = clientConfigVersions.get(clientId) || 0; - const newVersion = currentVersion + 1; - clientConfigVersions.set(clientId, newVersion); - // Update version on all client connections - clients.forEach((client) => { - client.configVersion = newVersion; - }); - } - // Include config version in message const configVersion = clientConfigVersions.get(clientId) || 0; + // Update version on all client connections + clients.forEach((client) => { + client.configVersion = configVersion; + }); + const messageWithVersion = { ...message, configVersion @@ -129,7 +130,6 @@ const broadcastToAllExceptLocal = async ( options: SendMessageOptions = {} ): Promise => { connectedClients.forEach((clients, mapKey) => { - const [type, id] = mapKey.split(":"); const clientId = mapKey; // mapKey is the clientId if (!(excludeClientId && clientId === excludeClientId)) { // Handle config version per client @@ -162,6 +162,13 @@ const sendToClient = async ( message: WSMessage, options: SendMessageOptions = {} ): Promise => { + // Increment config version if requested + if (options.incrementConfigVersion) { + const currentVersion = clientConfigVersions.get(clientId) || 0; + const newVersion = currentVersion + 1; + clientConfigVersions.set(clientId, newVersion); + } + // Try to send locally first const localSent = await sendToClientLocal(clientId, message, options); @@ -189,8 +196,10 @@ const hasActiveConnections = async (clientId: string): Promise => { }; // Get the current config version for a client -const getClientConfigVersion = async (clientId: string): Promise => { - return clientConfigVersions.get(clientId) || 0; +const getClientConfigVersion = async (clientId: string): Promise => { + const version = clientConfigVersions.get(clientId); + logger.debug(`getClientConfigVersion called for clientId: ${clientId}, returning: ${version} (type: ${typeof version})`); + return version; }; // Get all active nodes for a client