diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 8b601ae71..f6a94e7b7 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -21,10 +21,10 @@ import { } from "@server/db"; import { and, count, eq, inArray, ne } from "drizzle-orm"; -import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; +import { deletePeersBatch as newtDeletePeersBatch } from "@server/routers/newt/peers"; import { - initPeerAddHandshake, - deletePeer as olmDeletePeer + initPeerAddHandshakeBatch, + deletePeersBatch as olmDeletePeersBatch } from "@server/routers/olm/peers"; import { sendToExitNode } from "#dynamic/lib/exitNodes"; import logger from "@server/logger"; @@ -35,10 +35,10 @@ import { parseEndpoint } from "@server/lib/ip"; import { - addPeerData, - addTargets as addSubnetProxyTargets, - removePeerData, - removeTargets as removeSubnetProxyTargets + addPeerDataBatch, + addTargetsBatch as addSubnetProxyTargetsBatch, + removePeerDataBatch, + removeTargetsBatch as removeSubnetProxyTargetsBatch } from "@server/routers/client/targets"; import { lockManager } from "#dynamic/lib/lock"; import { rebuildQueue } from "#dynamic/lib/rebuildQueue"; @@ -559,6 +559,28 @@ async function handleMessagesForSiteClients( const newtJobs: Promise[] = []; const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const newtPeerDeletes: { + siteId: number; + publicKey: string; + newtId: string; + }[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + olmId: string; + }[] = []; + const olmPeerAddHandshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + }[] = []; // Combine all clients that need processing (those being added or removed) const clientsToProcess = new Map< @@ -638,15 +660,17 @@ async function handleMessagesForSiteClients( } if (isDelete) { - newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId)); - olmJobs.push( - olmDeletePeer( - client.clientId, - siteId, - site.publicKey, - olm.olmId - ) - ); + newtPeerDeletes.push({ + siteId, + publicKey: client.pubKey, + newtId: newt.newtId + }); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId, + publicKey: site.publicKey, + olmId: olm.olmId + }); } if (isAdd) { @@ -658,23 +682,34 @@ async function handleMessagesForSiteClients( continue; } - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { + olmPeerAddHandshakes.push({ + clientId: client.clientId, + peer: { siteId, exitNode: { publicKey: exitNode.publicKey, endpoint: exitNode.endpoint } }, - olm.olmId - ); + olmId: olm.olmId + }); } exitNodeJobs.push(updateClientSiteDestinations(client, trx)); } + if (newtPeerDeletes.length > 0) { + newtJobs.push(newtDeletePeersBatch(newtPeerDeletes)); + } + + if (olmPeerDeletes.length > 0) { + olmJobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + + if (olmPeerAddHandshakes.length > 0) { + olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes)); + } + Promise.all(exitNodeJobs).catch((error) => { logger.error( `rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`, @@ -867,24 +902,28 @@ async function handleSubnetProxyTargetUpdates( if (targetsToAdd) { proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targetsToAdd, - newt.version - ) + addSubnetProxyTargetsBatch([ + { + newtId: newt.newtId, + targets: targetsToAdd, + version: newt.version + } + ]) ); } - for (const client of addedClients) { - olmJobs.push( - addPeerData( - client.clientId, + olmJobs.push( + addPeerDataBatch( + addedClients.map((client) => ({ + clientId: client.clientId, siteId, - generateRemoteSubnets([siteResource]), - generateAliasConfig([siteResource]) - ) - ); - } + remoteSubnets: generateRemoteSubnets([ + siteResource + ]), + aliases: generateAliasConfig([siteResource]) + })) + ) + ); } } @@ -904,14 +943,23 @@ async function handleSubnetProxyTargetUpdates( if (targetsToRemove) { proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targetsToRemove, - newt.version - ) + removeSubnetProxyTargetsBatch([ + { + newtId: newt.newtId, + targets: targetsToRemove, + version: newt.version + } + ]) ); } + const peerDataRemovals: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const client of removedClients) { if (!siteResource.destination) { continue; @@ -959,14 +1007,16 @@ async function handleSubnetProxyTargetUpdates( ? [] : generateRemoteSubnets([siteResource]); - olmJobs.push( - removePeerData( - client.clientId, - siteId, - remoteSubnetsToRemove, - generateAliasConfig([siteResource]) - ) - ); + peerDataRemovals.push({ + clientId: client.clientId, + siteId, + remoteSubnets: remoteSubnetsToRemove, + aliases: generateAliasConfig([siteResource]) + }); + } + + if (peerDataRemovals.length > 0) { + olmJobs.push(removePeerDataBatch(peerDataRemovals)); } } } @@ -1277,6 +1327,28 @@ async function handleMessagesForClientSites( const newtJobs: Promise[] = []; const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const newtPeerDeletes: { + siteId: number; + publicKey: string; + newtId: string; + }[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + olmId: string; + }[] = []; + const olmPeerAddHandshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + }[] = []; const totalSitesOnClient = await trx .select({ count: count(clientSitesAssociationsCache.siteId) }) @@ -1308,19 +1380,19 @@ async function handleMessagesForClientSites( if (isRemove) { // Remove peer from newt - newtJobs.push( - newtDeletePeer(site.siteId, client.pubKey, newt.newtId) - ); + newtPeerDeletes.push({ + siteId: site.siteId, + publicKey: client.pubKey, + newtId: newt.newtId + }); try { // Remove peer from olm - olmJobs.push( - olmDeletePeer( - client.clientId, - site.siteId, - site.publicKey, - olmId - ) - ); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId: site.siteId, + publicKey: site.publicKey, + olmId + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1352,10 +1424,9 @@ async function handleMessagesForClientSites( continue; } - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { + olmPeerAddHandshakes.push({ + clientId: client.clientId, + peer: { siteId: site.siteId, exitNode: { publicKey: exitNode.publicKey, @@ -1363,7 +1434,7 @@ async function handleMessagesForClientSites( } }, olmId - ); + }); } // Update exit node destinations @@ -1379,6 +1450,18 @@ async function handleMessagesForClientSites( ); } + if (newtPeerDeletes.length > 0) { + newtJobs.push(newtDeletePeersBatch(newtPeerDeletes)); + } + + if (olmPeerDeletes.length > 0) { + olmJobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + + if (olmPeerAddHandshakes.length > 0) { + olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes)); + } + Promise.all(exitNodeJobs).catch((error) => { logger.error( `rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`, @@ -1477,6 +1560,20 @@ async function handleMessagesForClientResources( continue; } + const targetsToAddBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const peerDataAdds: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const resource of resources) { const targets = await generateSubnetProxyTargetV2(resource, [ { @@ -1487,25 +1584,21 @@ async function handleMessagesForClientResources( ]); if (targets) { - proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targets, - newt.version - ) - ); + targetsToAddBatch.push({ + newtId: newt.newtId, + targets, + version: newt.version + }); } try { // Add peer data to olm - olmJobs.push( - addPeerData( - client.clientId, - siteId, - generateRemoteSubnets([resource]), - generateAliasConfig([resource]) - ) - ); + peerDataAdds.push({ + clientId: client.clientId, + siteId, + remoteSubnets: generateRemoteSubnets([resource]), + aliases: generateAliasConfig([resource]) + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1520,6 +1613,14 @@ async function handleMessagesForClientResources( } } } + + if (targetsToAddBatch.length > 0) { + proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch)); + } + + if (peerDataAdds.length > 0) { + olmJobs.push(addPeerDataBatch(peerDataAdds)); + } } } @@ -1586,6 +1687,20 @@ async function handleMessagesForClientResources( continue; } + const targetsToRemoveBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const peerDataRemovals: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const resource of resources) { const targets = await generateSubnetProxyTargetV2(resource, [ { @@ -1596,13 +1711,11 @@ async function handleMessagesForClientResources( ]); if (targets) { - proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targets, - newt.version - ) - ); + targetsToRemoveBatch.push({ + newtId: newt.newtId, + targets, + version: newt.version + }); } try { @@ -1653,14 +1766,12 @@ async function handleMessagesForClientResources( : generateRemoteSubnets([resource]); // Remove peer data from olm - olmJobs.push( - removePeerData( - client.clientId, - siteId, - remoteSubnetsToRemove, - generateAliasConfig([resource]) - ) - ); + peerDataRemovals.push({ + clientId: client.clientId, + siteId, + remoteSubnets: remoteSubnetsToRemove, + aliases: generateAliasConfig([resource]) + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1675,6 +1786,16 @@ async function handleMessagesForClientResources( } } } + + if (targetsToRemoveBatch.length > 0) { + proxyJobs.push( + removeSubnetProxyTargetsBatch(targetsToRemoveBatch) + ); + } + + if (peerDataRemovals.length > 0) { + olmJobs.push(removePeerDataBatch(peerDataRemovals)); + } } } @@ -1928,7 +2049,15 @@ export async function cleanupSiteAssociations( for (const client of allClients) { // Tell each olm to drop the site's WireGuard peer. if (site.publicKey) { - jobs.push(olmDeletePeer(client.clientId, siteId, site.publicKey)); + jobs.push( + olmDeletePeersBatch([ + { + clientId: client.clientId, + siteId, + publicKey: site.publicKey + } + ]) + ); } // Recompute and push updated relay destinations (now excluding this site). diff --git a/server/private/lib/rebuildQueue.ts b/server/private/lib/rebuildQueue.ts index e5ee7e7cb..2cd1dadc0 100644 --- a/server/private/lib/rebuildQueue.ts +++ b/server/private/lib/rebuildQueue.ts @@ -29,6 +29,7 @@ export interface RebuildJobHandlers { // Redis list holding pending rebuild jobs (RPUSH to enqueue, LPOP to dequeue — FIFO order). const QUEUE_KEY = "rebuild-client-associations:queue"; +const QUEUED_SET_KEY = "rebuild-client-associations:queued"; // Distributed lock that serialises queue consumption to a single server instance // at a time. TTL is generous enough to cover a full batch of expensive rebuilds. @@ -54,11 +55,28 @@ class RedisRebuildQueue { } try { + const dedupeKey = `${job.type}:${job.id}`; + const added = await redis.sadd(QUEUED_SET_KEY, dedupeKey); + if (added === 0) { + logger.debug( + `Rebuild queue: skipped duplicate queued job ${job.type}:${job.id}` + ); + return; + } + await redis.rpush(QUEUE_KEY, JSON.stringify(job)); logger.debug( `Rebuild queue: enqueued ${job.type}:${job.id} (queue position: tail)` ); } catch (err) { + await redis + .srem(QUEUED_SET_KEY, `${job.type}:${job.id}`) + .catch((cleanupErr) => + logger.warn( + `Rebuild queue: failed to cleanup dedupe key for ${job.type}:${job.id} after enqueue failure:`, + cleanupErr + ) + ); logger.error( `Rebuild queue: failed to enqueue ${job.type}:${job.id}:`, err @@ -121,6 +139,17 @@ class RedisRebuildQueue { continue; } + // Remove from dedupe set once dequeued so the same job + // can be re-queued while this one is in progress. + await redis + .srem(QUEUED_SET_KEY, `${job.type}:${job.id}`) + .catch((cleanupErr) => + logger.warn( + `Rebuild queue: failed to remove dedupe key for ${job.type}:${job.id} on dequeue:`, + cleanupErr + ) + ); + logger.debug( `Rebuild queue: processing ${job.type}:${job.id}` ); diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts index a592927cc..8d222fd72 100644 --- a/server/private/routers/ws/ws.ts +++ b/server/private/routers/ws/ws.ts @@ -38,6 +38,7 @@ import { messageHandlers } from "@server/routers/ws/messageHandlers"; import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers"; import { AuthenticatedWebSocket, + BatchSendMessage, ClientType, WSMessage, TokenPayload, @@ -187,6 +188,8 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true }); // Generate unique node ID for this instance const NODE_ID = uuidv4(); const REDIS_CHANNEL = "websocket_messages"; +const REDIS_DIRECT_BATCH_SIZE = 250; +const REDIS_DIRECT_FLUSH_INTERVAL_MS = 10; // Client tracking map (local to this node) const connectedClients: Map = new Map(); @@ -197,6 +200,15 @@ const clientConfigVersions: Map = new Map(); // Recovery tracking let isRedisRecoveryInProgress = false; +interface RedisDirectBatchEntry { + targetClientId: string; + message: WSMessage; + resolve: () => void; +} + +let pendingRedisDirectMessages: RedisDirectBatchEntry[] = []; +let redisDirectFlushTimer: NodeJS.Timeout | null = null; + // Helper to get map key const getClientMapKey = (clientId: string) => clientId; @@ -207,6 +219,78 @@ const getNodeConnectionsKey = (nodeId: string, clientId: string) => const getConfigVersionKey = (clientId: string) => `ws:configVersion:${clientId}`; +const clearRedisDirectFlushTimer = (): void => { + if (redisDirectFlushTimer) { + clearTimeout(redisDirectFlushTimer); + redisDirectFlushTimer = null; + } +}; + +const publishDirectBatch = async ( + entries: RedisDirectBatchEntry[] +): Promise => { + const redisMessage: RedisMessage = { + type: "direct-batch", + messages: entries.map((entry) => ({ + targetClientId: entry.targetClientId, + message: entry.message + })), + fromNodeId: NODE_ID + }; + + await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); +}; + +const flushPendingRedisDirectMessages = async (): Promise => { + clearRedisDirectFlushTimer(); + + if (pendingRedisDirectMessages.length === 0) { + return; + } + + const entries = pendingRedisDirectMessages; + pendingRedisDirectMessages = []; + + if (!redisManager.isRedisEnabled()) { + entries.forEach((entry) => entry.resolve()); + return; + } + + for (let i = 0; i < entries.length; i += REDIS_DIRECT_BATCH_SIZE) { + const batch = entries.slice(i, i + REDIS_DIRECT_BATCH_SIZE); + try { + await publishDirectBatch(batch); + } catch (error) { + logger.error( + "Failed to send batched direct messages via Redis, messages may be lost:", + error + ); + } finally { + batch.forEach((entry) => entry.resolve()); + } + } +}; + +const enqueueRedisDirectMessage = async ( + targetClientId: string, + message: WSMessage +): Promise => { + await new Promise((resolve) => { + pendingRedisDirectMessages.push({ targetClientId, message, resolve }); + + if (pendingRedisDirectMessages.length >= REDIS_DIRECT_BATCH_SIZE) { + void flushPendingRedisDirectMessages(); + return; + } + + if (!redisDirectFlushTimer) { + redisDirectFlushTimer = setTimeout(() => { + void flushPendingRedisDirectMessages(); + }, REDIS_DIRECT_FLUSH_INTERVAL_MS); + } + }); +}; + // Initialize Redis subscription for cross-node messaging const initializeRedisSubscription = async (): Promise => { if (!redisManager.isRedisEnabled()) return; @@ -227,7 +311,16 @@ const initializeRedisSubscription = async (): Promise => { // Send to specific client on this node await sendToClientLocal( redisMessage.targetClientId, - redisMessage.message + redisMessage.message, + {}, + redisMessage.message.configVersion + ); + } else if ( + redisMessage.type === "direct-batch" && + redisMessage.messages + ) { + await sendRedisDirectBatchToLocalClients( + redisMessage.messages ); } else if (redisMessage.type === "broadcast") { // Broadcast to all clients on this node except excluded @@ -503,7 +596,8 @@ const incrementClientConfigVersion = async ( const sendToClientLocal = async ( clientId: string, message: WSMessage, - options: SendMessageOptions = {} + options: SendMessageOptions = {}, + preResolvedConfigVersion?: number ): Promise => { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); @@ -512,7 +606,8 @@ const sendToClientLocal = async ( } // Handle config version - const configVersion = await getClientConfigVersion(clientId); + const configVersion = + preResolvedConfigVersion ?? (await getClientConfigVersion(clientId)); // Add config version to message const messageWithVersion = { @@ -545,6 +640,20 @@ const sendToClientLocal = async ( return true; }; +const sendRedisDirectBatchToLocalClients = async ( + entries: { targetClientId: string; message: WSMessage }[] +): Promise => { + const jobs = entries.map((entry) => + sendToClientLocal( + entry.targetClientId, + entry.message, + {}, + entry.message.configVersion + ) + ); + await Promise.all(jobs); +}; + const broadcastToAllExceptLocal = async ( message: WSMessage, excludeClientId?: string, @@ -607,23 +716,13 @@ const sendToClient = async ( // Only send via Redis if the client is not connected locally and Redis is enabled if (!localSent && redisManager.isRedisEnabled()) { try { - const redisMessage: RedisMessage = { - type: "direct", - targetClientId: clientId, - message: { - ...message, - configVersion - }, - fromNodeId: NODE_ID - }; - - await redisManager.publish( - REDIS_CHANNEL, - JSON.stringify(redisMessage) - ); + await enqueueRedisDirectMessage(clientId, { + ...message, + configVersion + }); } catch (error) { logger.error( - "Failed to send message via Redis, message may be lost:", + "Failed to queue batched direct message for Redis delivery, message may be lost:", error ); // Continue execution - local delivery already attempted @@ -638,6 +737,76 @@ const sendToClient = async ( return localSent; }; +const sendToClientsBatch = async ( + entries: BatchSendMessage[] +): Promise => { + if (entries.length === 0) { + return; + } + + const remoteEntries: { targetClientId: string; message: WSMessage }[] = []; + + for (const entry of entries) { + const options = entry.options || {}; + const { clientId, message } = entry; + + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = await incrementClientConfigVersion(clientId); + } + + logger.debug( + `sendToClientsBatch: Message type ${message.type} queued for clientId ${clientId} (new configVersion: ${configVersion})` + ); + + const localSent = await sendToClientLocal( + clientId, + message, + options, + configVersion + ); + + if (!localSent && redisManager.isRedisEnabled()) { + remoteEntries.push({ + targetClientId: clientId, + message: { + ...message, + configVersion + } + }); + } else if (!localSent && !redisManager.isRedisEnabled()) { + logger.debug( + `Could not deliver batch message to ${clientId} - not connected locally and Redis unavailable` + ); + } + } + + if (!redisManager.isRedisEnabled() || remoteEntries.length === 0) { + return; + } + + for (let i = 0; i < remoteEntries.length; i += REDIS_DIRECT_BATCH_SIZE) { + const messages = remoteEntries.slice(i, i + REDIS_DIRECT_BATCH_SIZE); + try { + const redisMessage: RedisMessage = { + type: "direct-batch", + messages, + fromNodeId: NODE_ID + }; + + await redisManager.publish( + REDIS_CHANNEL, + JSON.stringify(redisMessage) + ); + } catch (error) { + logger.error( + "Failed to send explicit direct batch via Redis, messages may be lost:", + error + ); + } + } +}; + const broadcastToAllExcept = async ( message: WSMessage, excludeClientId?: string, @@ -1109,6 +1278,8 @@ const disconnectClient = async (clientId: string): Promise => { // Cleanup function for graceful shutdown const cleanup = async (): Promise => { try { + await flushPendingRedisDirectMessages(); + // Close all WebSocket connections connectedClients.forEach((clients) => { clients.forEach((client) => { @@ -1139,6 +1310,7 @@ export { router, handleWSUpgrade, sendToClient, + sendToClientsBatch, broadcastToAllExcept, connectedClients, hasActiveConnections, diff --git a/server/routers/client/targets.ts b/server/routers/client/targets.ts index c208acd88..c62a64ae0 100644 --- a/server/routers/client/targets.ts +++ b/server/routers/client/targets.ts @@ -1,4 +1,4 @@ -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { db, newts, olms } from "@server/db"; import { Alias, @@ -8,7 +8,7 @@ import { } from "@server/lib/ip"; import { canCompress } from "@server/lib/clientVersionChecks"; import logger from "@server/logger"; -import { eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import semver from "semver"; const NEWT_V2_TARGETS_VERSION = ">=1.10.3"; @@ -59,6 +59,42 @@ export async function addTargets( ); } +export async function addTargetsBatch( + entries: { + newtId: string; + targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolved = await Promise.all( + entries.map(async (entry) => ({ + ...entry, + targets: await convertTargetsIfNecessary( + entry.newtId, + entry.targets + ) + })) + ); + + await sendToClientsBatch( + resolved.map((entry) => ({ + clientId: entry.newtId, + message: { + type: `newt/wg/targets/add`, + data: entry.targets + }, + options: { + incrementConfigVersion: true, + compress: canCompress(entry.version, "newt") + } + })) + ); +} + export async function removeTargets( newtId: string, targets: SubnetProxyTarget[] | SubnetProxyTargetV2[], @@ -76,6 +112,42 @@ export async function removeTargets( ); } +export async function removeTargetsBatch( + entries: { + newtId: string; + targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolved = await Promise.all( + entries.map(async (entry) => ({ + ...entry, + targets: await convertTargetsIfNecessary( + entry.newtId, + entry.targets + ) + })) + ); + + await sendToClientsBatch( + resolved.map((entry) => ({ + clientId: entry.newtId, + message: { + type: `newt/wg/targets/remove`, + data: entry.targets + }, + options: { + incrementConfigVersion: true, + compress: canCompress(entry.version, "newt") + } + })) + ); +} + export async function updateTargets( newtId: string, targets: { @@ -201,6 +273,171 @@ export async function removePeerData( }); } +const resolveOlmTargets = async ( + entries: { + clientId: number; + olmId?: string; + version?: string | null; + }[] +) => { + const unresolvedClientIds = entries + .filter((entry) => !entry.olmId) + .map((entry) => entry.clientId); + + const olmMap = new Map(); + + if (unresolvedClientIds.length > 0) { + const olmRows = await db + .select({ + clientId: olms.clientId, + olmId: olms.olmId, + version: olms.version + }) + .from(olms) + .where(inArray(olms.clientId, unresolvedClientIds)); + + for (const row of olmRows) { + if (row.clientId !== null) { + olmMap.set(row.clientId, { + olmId: row.olmId, + version: row.version + }); + } + } + } + + return entries + .map((entry) => { + if (entry.olmId) { + return { + clientId: entry.clientId, + olmId: entry.olmId, + version: entry.version + }; + } + + const resolved = olmMap.get(entry.clientId); + if (!resolved) { + return null; + } + + return { + clientId: entry.clientId, + olmId: resolved.olmId, + version: entry.version ?? resolved.version + }; + }) + .filter((entry) => entry !== null); +}; + +export async function addPeerDataBatch( + entries: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: Alias[]; + olmId?: string; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolvedTargets = await resolveOlmTargets(entries); + + if (resolvedTargets.length === 0) { + return; + } + + const payloads = entries + .map((entry) => { + const resolved = resolvedTargets.find( + (target) => target.clientId === entry.clientId + ); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: `olm/wg/peer/data/add`, + data: { + siteId: entry.siteId, + remoteSubnets: entry.remoteSubnets, + aliases: entry.aliases + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress(resolved.version, "olm") + } + }; + }) + .filter((entry) => entry !== null); + + if (payloads.length === 0) { + return; + } + + await sendToClientsBatch(payloads); +} + +export async function removePeerDataBatch( + entries: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: Alias[]; + olmId?: string; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolvedTargets = await resolveOlmTargets(entries); + + if (resolvedTargets.length === 0) { + return; + } + + const payloads = entries + .map((entry) => { + const resolved = resolvedTargets.find( + (target) => target.clientId === entry.clientId + ); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: `olm/wg/peer/data/remove`, + data: { + siteId: entry.siteId, + remoteSubnets: entry.remoteSubnets, + aliases: entry.aliases + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress(resolved.version, "olm") + } + }; + }) + .filter((entry) => entry !== null); + + if (payloads.length === 0) { + return; + } + + await sendToClientsBatch(payloads); +} + export async function updatePeerData( clientId: number, siteId: number, diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts index 4b74d863d..6c38671f3 100644 --- a/server/routers/newt/peers.ts +++ b/server/routers/newt/peers.ts @@ -1,7 +1,7 @@ import { db, Site } from "@server/db"; import { newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import logger from "@server/logger"; export async function addPeer( @@ -36,10 +36,14 @@ export async function addPeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/add", - data: peer - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/add", + data: peer + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -76,12 +80,16 @@ export async function deletePeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/remove", - data: { - publicKey - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/remove", + data: { + publicKey + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -90,6 +98,35 @@ export async function deletePeer( return site; } +export async function deletePeersBatch( + peers: { + siteId: number; + publicKey: string; + newtId: string; + }[] +) { + if (peers.length === 0) { + return; + } + + await sendToClientsBatch( + peers.map((peer) => ({ + clientId: peer.newtId, + message: { + type: "newt/wg/peer/remove", + data: { + publicKey: peer.publicKey + } + }, + options: { incrementConfigVersion: true } + })) + ).catch((error) => { + logger.warn(`Error sending batched newt peer removals:`, error); + }); + + logger.info(`Deleted ${peers.length} peer(s) from newts (batch)`); +} + export async function updatePeer( siteId: number, publicKey: string, @@ -122,13 +159,17 @@ export async function updatePeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/update", - data: { - publicKey, - ...peer - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/update", + data: { + publicKey, + ...peer + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 05e153fea..962d7367e 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -1,9 +1,9 @@ -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { clientSitesAssociationsCache, db, olms } from "@server/db"; import { canCompress } from "@server/lib/clientVersionChecks"; import config from "@server/lib/config"; import logger from "@server/logger"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { Alias } from "yaml"; export async function addPeer( @@ -205,3 +205,150 @@ export async function initPeerAddHandshake( `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` ); } + +export async function deletePeersBatch( + peers: { + clientId: number; + siteId: number; + publicKey: string; + olmId?: string; + version?: string | null; + }[] +) { + if (peers.length === 0) { + return; + } + + const unresolvedClientIds = peers + .filter((peer) => !peer.olmId) + .map((peer) => peer.clientId); + + const olmByClientId = new Map< + number, + { olmId: string; version: string | null } + >(); + + if (unresolvedClientIds.length > 0) { + const olmRows = await db + .select({ + clientId: olms.clientId, + olmId: olms.olmId, + version: olms.version + }) + .from(olms) + .where(inArray(olms.clientId, unresolvedClientIds)); + + for (const row of olmRows) { + if (row.clientId !== null) { + olmByClientId.set(row.clientId, { + olmId: row.olmId, + version: row.version + }); + } + } + } + + const batchPayloads = peers + .map((peer) => { + const resolved = peer.olmId + ? { olmId: peer.olmId, version: peer.version ?? null } + : olmByClientId.get(peer.clientId); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: "olm/wg/peer/remove", + data: { + publicKey: peer.publicKey, + siteId: peer.siteId + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress( + peer.version ?? resolved.version, + "olm" + ) + } + }; + }) + .filter((payload) => payload !== null); + + if (batchPayloads.length === 0) { + return; + } + + await sendToClientsBatch(batchPayloads).catch((error) => { + logger.warn(`Error sending batched olm peer removals:`, error); + }); + + logger.info(`Deleted ${batchPayloads.length} peer(s) from olms (batch)`); +} + +export async function initPeerAddHandshakeBatch( + handshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + chainId?: string; + }[] +) { + if (handshakes.length === 0) { + return; + } + + await sendToClientsBatch( + handshakes.map((item) => ({ + clientId: item.olmId, + message: { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: item.peer.siteId, + exitNode: { + publicKey: item.peer.exitNode.publicKey, + relayPort: + config.getRawConfig().gerbil.clients_start_port, + endpoint: item.peer.exitNode.endpoint + }, + chainId: item.chainId + } + }, + options: { incrementConfigVersion: true } + })) + ).catch((error) => { + logger.warn(`Error sending batched olm handshakes:`, error); + }); + + await Promise.all( + handshakes.map((item) => + db + .update(clientSitesAssociationsCache) + .set({ isJitMode: false }) + .where( + and( + eq( + clientSitesAssociationsCache.clientId, + item.clientId + ), + eq( + clientSitesAssociationsCache.siteId, + item.peer.siteId + ) + ) + ) + ) + ); + + logger.info( + `Initiated ${handshakes.length} peer add handshake(s) to olms (batch)` + ); +} diff --git a/server/routers/ws/types.ts b/server/routers/ws/types.ts index e539954ce..eeb272457 100644 --- a/server/routers/ws/types.ts +++ b/server/routers/ws/types.ts @@ -76,12 +76,32 @@ export interface SendMessageOptions { compress?: boolean; } -// Redis message type for cross-node communication -export interface RedisMessage { - type: "direct" | "broadcast"; - targetClientId?: string; - excludeClientId?: string; +export interface BatchSendMessage { + clientId: string; message: WSMessage; - fromNodeId: string; options?: SendMessageOptions; } + +// Redis message types for cross-node communication +export type RedisMessage = + | { + type: "direct"; + targetClientId: string; + message: WSMessage; + fromNodeId: string; + } + | { + type: "direct-batch"; + messages: { + targetClientId: string; + message: WSMessage; + }[]; + fromNodeId: string; + } + | { + type: "broadcast"; + excludeClientId?: string; + message: WSMessage; + fromNodeId: string; + options?: SendMessageOptions; + }; diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index e7dcfe9cb..4ce337a20 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -26,7 +26,8 @@ import { WebSocketRequest, WSMessage, AuthenticatedWebSocket, - SendMessageOptions + SendMessageOptions, + BatchSendMessage } from "./types"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -212,6 +213,20 @@ const sendToClient = async ( return localSent; }; +const sendToClientsBatch = async ( + entries: BatchSendMessage[] +): Promise => { + if (entries.length === 0) { + return; + } + + await Promise.all( + entries.map((entry) => + sendToClient(entry.clientId, entry.message, entry.options) + ) + ); +}; + const broadcastToAllExcept = async ( message: WSMessage, excludeClientId?: string, @@ -552,6 +567,7 @@ export { router, handleWSUpgrade, sendToClient, + sendToClientsBatch, broadcastToAllExcept, connectedClients, hasActiveConnections,