Message syncing works

This commit is contained in:
Owen
2026-01-15 21:26:13 -08:00
parent d52bd65d21
commit 65e8bfc93e
6 changed files with 208 additions and 87 deletions

View File

@@ -319,6 +319,45 @@ const addClient = async (
existingClients.push(ws); existingClients.push(ws);
connectedClients.set(mapKey, existingClients); connectedClients.set(mapKey, existingClients);
// Get or initialize config version
let configVersion = 0;
// Check Redis first if enabled
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10);
// Sync to local cache
clientConfigVersions.set(clientId, configVersion);
} else if (!clientConfigVersions.has(clientId)) {
// No version in Redis or local cache, initialize to 0
await redisManager.set(getConfigVersionKey(clientId), "0");
clientConfigVersions.set(clientId, 0);
} else {
// Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
}
} catch (error) {
logger.error("Failed to get/set config version in Redis:", error);
// Fall back to local cache
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
configVersion = clientConfigVersions.get(clientId) || 0;
}
} else {
// Redis not enabled, use local cache only
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
configVersion = clientConfigVersions.get(clientId) || 0;
}
// Set config version on websocket
ws.configVersion = configVersion;
// Add to Redis tracking if enabled // Add to Redis tracking if enabled
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
@@ -337,7 +376,7 @@ const addClient = async (
} }
logger.info( logger.info(
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}` `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}, Config version: ${configVersion}`
); );
}; };
@@ -393,7 +432,7 @@ const removeClient = async (
}; };
// Helper to get the current config version for a client // Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number> => { const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
// Try Redis first if available // Try Redis first if available
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
@@ -412,7 +451,7 @@ const getClientConfigVersion = async (clientId: string): Promise<number> => {
} }
// Fall back to local cache // Fall back to local cache
return clientConfigVersions.get(clientId) || 0; return clientConfigVersions.get(clientId);
}; };
// Helper to increment and get the new config version for a client // Helper to increment and get the new config version for a client
@@ -455,9 +494,6 @@ const sendToClientLocal = async (
// Handle config version // Handle config version
let configVersion = await getClientConfigVersion(clientId); let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
// Add config version to message // Add config version to message
const messageWithVersion = { const messageWithVersion = {
@@ -472,10 +508,6 @@ const sendToClientLocal = async (
} }
}); });
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId} (configVersion: ${configVersion})`
);
return true; return true;
}; };
@@ -515,19 +547,21 @@ const sendToClient = async (
message: WSMessage, message: WSMessage,
options: SendMessageOptions = {} options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId} (new configVersion: ${configVersion})`
);
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal(clientId, message, options); const localSent = await sendToClientLocal(clientId, message, options);
// Only send via Redis if the client is not connected locally and Redis is enabled // Only send via Redis if the client is not connected locally and Redis is enabled
if (!localSent && redisManager.isRedisEnabled()) { if (!localSent && redisManager.isRedisEnabled()) {
try { try {
// If we need to increment config version, do it before sending via Redis
// so remote nodes send the correct version
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
const redisMessage: RedisMessage = { const redisMessage: RedisMessage = {
type: "direct", type: "direct",
targetClientId: clientId, targetClientId: clientId,

View File

@@ -1,6 +1,6 @@
import { db, sites } from "@server/db"; import { db, sites } from "@server/db";
import { disconnectClient } from "#dynamic/routers/ws"; import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { getClientConfigVersion, MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, Newt } from "@server/db"; import { clients, Newt } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm"; import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";

View File

@@ -1,6 +1,6 @@
import { db } from "@server/db"; import { db } from "@server/db";
import { disconnectClient } from "#dynamic/routers/ws"; import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { getClientConfigVersion, MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, olms, Olm } from "@server/db"; import { clients, olms, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm"; import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
@@ -171,11 +171,17 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
} }
// get the version // get the version
logger.debug(`++++++++++++++++++++++++++++handleOlmPingMessage: About to get config version for olmId: ${olm.olmId}`);
const configVersion = await getClientConfigVersion(olm.olmId); const configVersion = await getClientConfigVersion(olm.olmId);
logger.debug(`++++++++++++++++++++++++++++handleOlmPingMessage: Got config version: ${configVersion} (type: ${typeof configVersion})`);
if (message.configVersion && configVersion != message.configVersion) { if (configVersion == null || configVersion === undefined) {
logger.warn( logger.debug(`++++++++++++++++++++++++++++handleOlmPingMessage: could not get config version from server for olmId: ${olm.olmId}`)
`Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})` }
if (message.configVersion != null && configVersion != null && configVersion != message.configVersion) {
logger.debug(
`++++++++++++++++++++++++++++handleOlmPingMessage: Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
); );
await sendOlmSyncMessage(olm, client); await sendOlmSyncMessage(olm, client);
} }

View File

@@ -32,20 +32,24 @@ export async function addPeer(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/add", olmId,
data: { {
siteId: peer.siteId, type: "olm/wg/peer/add",
name: peer.name, data: {
publicKey: peer.publicKey, siteId: peer.siteId,
endpoint: peer.endpoint, name: peer.name,
relayEndpoint: peer.relayEndpoint, publicKey: peer.publicKey,
serverIP: peer.serverIP, endpoint: peer.endpoint,
serverPort: peer.serverPort, relayEndpoint: peer.relayEndpoint,
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access serverIP: peer.serverIP,
aliases: peer.aliases serverPort: peer.serverPort,
} remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
}, { incrementConfigVersion: true }).catch((error) => { aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -70,13 +74,17 @@ export async function deletePeer(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/remove", olmId,
data: { {
publicKey, type: "olm/wg/peer/remove",
siteId: siteId data: {
} publicKey,
}, { incrementConfigVersion: true }).catch((error) => { siteId: siteId
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -109,19 +117,23 @@ export async function updatePeer(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/update", olmId,
data: { {
siteId: peer.siteId, type: "olm/wg/peer/update",
publicKey: peer.publicKey, data: {
endpoint: peer.endpoint, siteId: peer.siteId,
relayEndpoint: peer.relayEndpoint, publicKey: peer.publicKey,
serverIP: peer.serverIP, endpoint: peer.endpoint,
serverPort: peer.serverPort, relayEndpoint: peer.relayEndpoint,
remoteSubnets: peer.remoteSubnets, serverIP: peer.serverIP,
aliases: peer.aliases serverPort: peer.serverPort,
} remoteSubnets: peer.remoteSubnets,
}, { incrementConfigVersion: true }).catch((error) => { aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -151,19 +163,21 @@ export async function initPeerAddHandshake(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/holepunch/site/add", olmId,
data: { {
siteId: peer.siteId, type: "olm/wg/peer/holepunch/site/add",
exitNode: { data: {
publicKey: peer.exitNode.publicKey, siteId: peer.siteId,
relayPort: config.getRawConfig().gerbil.clients_start_port, exitNode: {
endpoint: peer.exitNode.endpoint publicKey: peer.exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint
}
} }
} },
// }, { incrementConfigVersion: true }).catch((error) => { { incrementConfigVersion: true }
// TODO: DOES THIS NEED TO BE A INCREMENT VERSION? I AM NOT SURE BECAUSE IT WOULD BE TRIGGERED BY THE SYNC? ).catch((error) => {
}).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });

View File

@@ -1,7 +1,9 @@
import { Client, Olm } from "@server/db"; import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration"; import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
export async function sendOlmSyncMessage(olm: Olm, client: Client) { export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT // NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
@@ -11,10 +13,66 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
false false
); );
// Get all exit nodes from sites where the client has peers
const clientSites = await db
.select()
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract unique exit node IDs
const exitNodeIds = Array.from(
new Set(
clientSites
.map(({ sites: site }) => site.exitNodeId)
.filter((id): id is number => id !== null)
)
);
let exitNodesData: {
publicKey: string;
relayPort: number;
endpoint: string;
siteIds: number[];
}[] = [];
if (exitNodeIds.length > 0) {
const allExitNodes = await db
.select()
.from(exitNodes)
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
// Map exitNodeId to siteIds
const exitNodeIdToSiteIds: Record<number, number[]> = {};
for (const { sites: site } of clientSites) {
if (site.exitNodeId !== null) {
if (!exitNodeIdToSiteIds[site.exitNodeId]) {
exitNodeIdToSiteIds[site.exitNodeId] = [];
}
exitNodeIdToSiteIds[site.exitNodeId].push(site.siteId);
}
}
exitNodesData = allExitNodes.map((exitNode) => {
return {
publicKey: exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: exitNode.endpoint,
siteIds: exitNodeIdToSiteIds[exitNode.exitNodeId] ?? []
};
});
}
logger.debug("++++++++++++++++++++++++++++sendOlmSyncMessage: sending sync message")
await sendToClient(olm.olmId, { await sendToClient(olm.olmId, {
type: "olm/sync", type: "olm/sync",
data: { data: {
sites: siteConfigurations sites: siteConfigurations,
exitNodes: exitNodesData
} }
}).catch((error) => { }).catch((error) => {
logger.warn(`Error sending olm sync message:`, error); logger.warn(`Error sending olm sync message:`, error);

View File

@@ -56,6 +56,13 @@ const addClient = async (
existingClients.push(ws); existingClients.push(ws);
connectedClients.set(mapKey, existingClients); connectedClients.set(mapKey, existingClients);
// Initialize config version to 0 if not already set, otherwise use existing
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
// Set the current config version on the websocket
ws.configVersion = clientConfigVersions.get(clientId) || 0;
logger.info( logger.info(
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}` `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
); );
@@ -96,19 +103,13 @@ const sendToClientLocal = async (
return false; return false;
} }
// Increment config version if requested
if (options.incrementConfigVersion) {
const currentVersion = clientConfigVersions.get(clientId) || 0;
const newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
// Update version on all client connections
clients.forEach((client) => {
client.configVersion = newVersion;
});
}
// Include config version in message // Include config version in message
const configVersion = clientConfigVersions.get(clientId) || 0; const configVersion = clientConfigVersions.get(clientId) || 0;
// Update version on all client connections
clients.forEach((client) => {
client.configVersion = configVersion;
});
const messageWithVersion = { const messageWithVersion = {
...message, ...message,
configVersion configVersion
@@ -129,7 +130,6 @@ const broadcastToAllExceptLocal = async (
options: SendMessageOptions = {} options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
connectedClients.forEach((clients, mapKey) => { connectedClients.forEach((clients, mapKey) => {
const [type, id] = mapKey.split(":");
const clientId = mapKey; // mapKey is the clientId const clientId = mapKey; // mapKey is the clientId
if (!(excludeClientId && clientId === excludeClientId)) { if (!(excludeClientId && clientId === excludeClientId)) {
// Handle config version per client // Handle config version per client
@@ -162,6 +162,13 @@ const sendToClient = async (
message: WSMessage, message: WSMessage,
options: SendMessageOptions = {} options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
// Increment config version if requested
if (options.incrementConfigVersion) {
const currentVersion = clientConfigVersions.get(clientId) || 0;
const newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
}
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal(clientId, message, options); const localSent = await sendToClientLocal(clientId, message, options);
@@ -189,8 +196,10 @@ const hasActiveConnections = async (clientId: string): Promise<boolean> => {
}; };
// Get the current config version for a client // Get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number> => { const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
return clientConfigVersions.get(clientId) || 0; const version = clientConfigVersions.get(clientId);
logger.debug(`getClientConfigVersion called for clientId: ${clientId}, returning: ${version} (type: ${typeof version})`);
return version;
}; };
// Get all active nodes for a client // Get all active nodes for a client