diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 8b601ae71..f6a94e7b7 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -21,10 +21,10 @@ import { } from "@server/db"; import { and, count, eq, inArray, ne } from "drizzle-orm"; -import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; +import { deletePeersBatch as newtDeletePeersBatch } from "@server/routers/newt/peers"; import { - initPeerAddHandshake, - deletePeer as olmDeletePeer + initPeerAddHandshakeBatch, + deletePeersBatch as olmDeletePeersBatch } from "@server/routers/olm/peers"; import { sendToExitNode } from "#dynamic/lib/exitNodes"; import logger from "@server/logger"; @@ -35,10 +35,10 @@ import { parseEndpoint } from "@server/lib/ip"; import { - addPeerData, - addTargets as addSubnetProxyTargets, - removePeerData, - removeTargets as removeSubnetProxyTargets + addPeerDataBatch, + addTargetsBatch as addSubnetProxyTargetsBatch, + removePeerDataBatch, + removeTargetsBatch as removeSubnetProxyTargetsBatch } from "@server/routers/client/targets"; import { lockManager } from "#dynamic/lib/lock"; import { rebuildQueue } from "#dynamic/lib/rebuildQueue"; @@ -559,6 +559,28 @@ async function handleMessagesForSiteClients( const newtJobs: Promise[] = []; const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const newtPeerDeletes: { + siteId: number; + publicKey: string; + newtId: string; + }[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + olmId: string; + }[] = []; + const olmPeerAddHandshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + }[] = []; // Combine all clients that need processing (those being added or removed) const clientsToProcess = new Map< @@ -638,15 +660,17 @@ async function handleMessagesForSiteClients( } if (isDelete) { - newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId)); - olmJobs.push( - olmDeletePeer( - client.clientId, - siteId, - site.publicKey, - olm.olmId - ) - ); + newtPeerDeletes.push({ + siteId, + publicKey: client.pubKey, + newtId: newt.newtId + }); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId, + publicKey: site.publicKey, + olmId: olm.olmId + }); } if (isAdd) { @@ -658,23 +682,34 @@ async function handleMessagesForSiteClients( continue; } - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { + olmPeerAddHandshakes.push({ + clientId: client.clientId, + peer: { siteId, exitNode: { publicKey: exitNode.publicKey, endpoint: exitNode.endpoint } }, - olm.olmId - ); + olmId: olm.olmId + }); } exitNodeJobs.push(updateClientSiteDestinations(client, trx)); } + if (newtPeerDeletes.length > 0) { + newtJobs.push(newtDeletePeersBatch(newtPeerDeletes)); + } + + if (olmPeerDeletes.length > 0) { + olmJobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + + if (olmPeerAddHandshakes.length > 0) { + olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes)); + } + Promise.all(exitNodeJobs).catch((error) => { logger.error( `rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`, @@ -867,24 +902,28 @@ async function handleSubnetProxyTargetUpdates( if (targetsToAdd) { proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targetsToAdd, - newt.version - ) + addSubnetProxyTargetsBatch([ + { + newtId: newt.newtId, + targets: targetsToAdd, + version: newt.version + } + ]) ); } - for (const client of addedClients) { - olmJobs.push( - addPeerData( - client.clientId, + olmJobs.push( + addPeerDataBatch( + addedClients.map((client) => ({ + clientId: client.clientId, siteId, - generateRemoteSubnets([siteResource]), - generateAliasConfig([siteResource]) - ) - ); - } + remoteSubnets: generateRemoteSubnets([ + siteResource + ]), + aliases: generateAliasConfig([siteResource]) + })) + ) + ); } } @@ -904,14 +943,23 @@ async function handleSubnetProxyTargetUpdates( if (targetsToRemove) { proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targetsToRemove, - newt.version - ) + removeSubnetProxyTargetsBatch([ + { + newtId: newt.newtId, + targets: targetsToRemove, + version: newt.version + } + ]) ); } + const peerDataRemovals: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const client of removedClients) { if (!siteResource.destination) { continue; @@ -959,14 +1007,16 @@ async function handleSubnetProxyTargetUpdates( ? [] : generateRemoteSubnets([siteResource]); - olmJobs.push( - removePeerData( - client.clientId, - siteId, - remoteSubnetsToRemove, - generateAliasConfig([siteResource]) - ) - ); + peerDataRemovals.push({ + clientId: client.clientId, + siteId, + remoteSubnets: remoteSubnetsToRemove, + aliases: generateAliasConfig([siteResource]) + }); + } + + if (peerDataRemovals.length > 0) { + olmJobs.push(removePeerDataBatch(peerDataRemovals)); } } } @@ -1277,6 +1327,28 @@ async function handleMessagesForClientSites( const newtJobs: Promise[] = []; const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const newtPeerDeletes: { + siteId: number; + publicKey: string; + newtId: string; + }[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + olmId: string; + }[] = []; + const olmPeerAddHandshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + }[] = []; const totalSitesOnClient = await trx .select({ count: count(clientSitesAssociationsCache.siteId) }) @@ -1308,19 +1380,19 @@ async function handleMessagesForClientSites( if (isRemove) { // Remove peer from newt - newtJobs.push( - newtDeletePeer(site.siteId, client.pubKey, newt.newtId) - ); + newtPeerDeletes.push({ + siteId: site.siteId, + publicKey: client.pubKey, + newtId: newt.newtId + }); try { // Remove peer from olm - olmJobs.push( - olmDeletePeer( - client.clientId, - site.siteId, - site.publicKey, - olmId - ) - ); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId: site.siteId, + publicKey: site.publicKey, + olmId + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1352,10 +1424,9 @@ async function handleMessagesForClientSites( continue; } - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { + olmPeerAddHandshakes.push({ + clientId: client.clientId, + peer: { siteId: site.siteId, exitNode: { publicKey: exitNode.publicKey, @@ -1363,7 +1434,7 @@ async function handleMessagesForClientSites( } }, olmId - ); + }); } // Update exit node destinations @@ -1379,6 +1450,18 @@ async function handleMessagesForClientSites( ); } + if (newtPeerDeletes.length > 0) { + newtJobs.push(newtDeletePeersBatch(newtPeerDeletes)); + } + + if (olmPeerDeletes.length > 0) { + olmJobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + + if (olmPeerAddHandshakes.length > 0) { + olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes)); + } + Promise.all(exitNodeJobs).catch((error) => { logger.error( `rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`, @@ -1477,6 +1560,20 @@ async function handleMessagesForClientResources( continue; } + const targetsToAddBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const peerDataAdds: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const resource of resources) { const targets = await generateSubnetProxyTargetV2(resource, [ { @@ -1487,25 +1584,21 @@ async function handleMessagesForClientResources( ]); if (targets) { - proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targets, - newt.version - ) - ); + targetsToAddBatch.push({ + newtId: newt.newtId, + targets, + version: newt.version + }); } try { // Add peer data to olm - olmJobs.push( - addPeerData( - client.clientId, - siteId, - generateRemoteSubnets([resource]), - generateAliasConfig([resource]) - ) - ); + peerDataAdds.push({ + clientId: client.clientId, + siteId, + remoteSubnets: generateRemoteSubnets([resource]), + aliases: generateAliasConfig([resource]) + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1520,6 +1613,14 @@ async function handleMessagesForClientResources( } } } + + if (targetsToAddBatch.length > 0) { + proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch)); + } + + if (peerDataAdds.length > 0) { + olmJobs.push(addPeerDataBatch(peerDataAdds)); + } } } @@ -1586,6 +1687,20 @@ async function handleMessagesForClientResources( continue; } + const targetsToRemoveBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const peerDataRemovals: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const resource of resources) { const targets = await generateSubnetProxyTargetV2(resource, [ { @@ -1596,13 +1711,11 @@ async function handleMessagesForClientResources( ]); if (targets) { - proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targets, - newt.version - ) - ); + targetsToRemoveBatch.push({ + newtId: newt.newtId, + targets, + version: newt.version + }); } try { @@ -1653,14 +1766,12 @@ async function handleMessagesForClientResources( : generateRemoteSubnets([resource]); // Remove peer data from olm - olmJobs.push( - removePeerData( - client.clientId, - siteId, - remoteSubnetsToRemove, - generateAliasConfig([resource]) - ) - ); + peerDataRemovals.push({ + clientId: client.clientId, + siteId, + remoteSubnets: remoteSubnetsToRemove, + aliases: generateAliasConfig([resource]) + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1675,6 +1786,16 @@ async function handleMessagesForClientResources( } } } + + if (targetsToRemoveBatch.length > 0) { + proxyJobs.push( + removeSubnetProxyTargetsBatch(targetsToRemoveBatch) + ); + } + + if (peerDataRemovals.length > 0) { + olmJobs.push(removePeerDataBatch(peerDataRemovals)); + } } } @@ -1928,7 +2049,15 @@ export async function cleanupSiteAssociations( for (const client of allClients) { // Tell each olm to drop the site's WireGuard peer. if (site.publicKey) { - jobs.push(olmDeletePeer(client.clientId, siteId, site.publicKey)); + jobs.push( + olmDeletePeersBatch([ + { + clientId: client.clientId, + siteId, + publicKey: site.publicKey + } + ]) + ); } // Recompute and push updated relay destinations (now excluding this site). diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts index 2db8f3140..8d222fd72 100644 --- a/server/private/routers/ws/ws.ts +++ b/server/private/routers/ws/ws.ts @@ -38,6 +38,7 @@ import { messageHandlers } from "@server/routers/ws/messageHandlers"; import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers"; import { AuthenticatedWebSocket, + BatchSendMessage, ClientType, WSMessage, TokenPayload, @@ -736,6 +737,76 @@ const sendToClient = async ( return localSent; }; +const sendToClientsBatch = async ( + entries: BatchSendMessage[] +): Promise => { + if (entries.length === 0) { + return; + } + + const remoteEntries: { targetClientId: string; message: WSMessage }[] = []; + + for (const entry of entries) { + const options = entry.options || {}; + const { clientId, message } = entry; + + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = await incrementClientConfigVersion(clientId); + } + + logger.debug( + `sendToClientsBatch: Message type ${message.type} queued for clientId ${clientId} (new configVersion: ${configVersion})` + ); + + const localSent = await sendToClientLocal( + clientId, + message, + options, + configVersion + ); + + if (!localSent && redisManager.isRedisEnabled()) { + remoteEntries.push({ + targetClientId: clientId, + message: { + ...message, + configVersion + } + }); + } else if (!localSent && !redisManager.isRedisEnabled()) { + logger.debug( + `Could not deliver batch message to ${clientId} - not connected locally and Redis unavailable` + ); + } + } + + if (!redisManager.isRedisEnabled() || remoteEntries.length === 0) { + return; + } + + for (let i = 0; i < remoteEntries.length; i += REDIS_DIRECT_BATCH_SIZE) { + const messages = remoteEntries.slice(i, i + REDIS_DIRECT_BATCH_SIZE); + try { + const redisMessage: RedisMessage = { + type: "direct-batch", + messages, + fromNodeId: NODE_ID + }; + + await redisManager.publish( + REDIS_CHANNEL, + JSON.stringify(redisMessage) + ); + } catch (error) { + logger.error( + "Failed to send explicit direct batch via Redis, messages may be lost:", + error + ); + } + } +}; + const broadcastToAllExcept = async ( message: WSMessage, excludeClientId?: string, @@ -1239,6 +1310,7 @@ export { router, handleWSUpgrade, sendToClient, + sendToClientsBatch, broadcastToAllExcept, connectedClients, hasActiveConnections, diff --git a/server/routers/client/targets.ts b/server/routers/client/targets.ts index c208acd88..c62a64ae0 100644 --- a/server/routers/client/targets.ts +++ b/server/routers/client/targets.ts @@ -1,4 +1,4 @@ -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { db, newts, olms } from "@server/db"; import { Alias, @@ -8,7 +8,7 @@ import { } from "@server/lib/ip"; import { canCompress } from "@server/lib/clientVersionChecks"; import logger from "@server/logger"; -import { eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import semver from "semver"; const NEWT_V2_TARGETS_VERSION = ">=1.10.3"; @@ -59,6 +59,42 @@ export async function addTargets( ); } +export async function addTargetsBatch( + entries: { + newtId: string; + targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolved = await Promise.all( + entries.map(async (entry) => ({ + ...entry, + targets: await convertTargetsIfNecessary( + entry.newtId, + entry.targets + ) + })) + ); + + await sendToClientsBatch( + resolved.map((entry) => ({ + clientId: entry.newtId, + message: { + type: `newt/wg/targets/add`, + data: entry.targets + }, + options: { + incrementConfigVersion: true, + compress: canCompress(entry.version, "newt") + } + })) + ); +} + export async function removeTargets( newtId: string, targets: SubnetProxyTarget[] | SubnetProxyTargetV2[], @@ -76,6 +112,42 @@ export async function removeTargets( ); } +export async function removeTargetsBatch( + entries: { + newtId: string; + targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolved = await Promise.all( + entries.map(async (entry) => ({ + ...entry, + targets: await convertTargetsIfNecessary( + entry.newtId, + entry.targets + ) + })) + ); + + await sendToClientsBatch( + resolved.map((entry) => ({ + clientId: entry.newtId, + message: { + type: `newt/wg/targets/remove`, + data: entry.targets + }, + options: { + incrementConfigVersion: true, + compress: canCompress(entry.version, "newt") + } + })) + ); +} + export async function updateTargets( newtId: string, targets: { @@ -201,6 +273,171 @@ export async function removePeerData( }); } +const resolveOlmTargets = async ( + entries: { + clientId: number; + olmId?: string; + version?: string | null; + }[] +) => { + const unresolvedClientIds = entries + .filter((entry) => !entry.olmId) + .map((entry) => entry.clientId); + + const olmMap = new Map(); + + if (unresolvedClientIds.length > 0) { + const olmRows = await db + .select({ + clientId: olms.clientId, + olmId: olms.olmId, + version: olms.version + }) + .from(olms) + .where(inArray(olms.clientId, unresolvedClientIds)); + + for (const row of olmRows) { + if (row.clientId !== null) { + olmMap.set(row.clientId, { + olmId: row.olmId, + version: row.version + }); + } + } + } + + return entries + .map((entry) => { + if (entry.olmId) { + return { + clientId: entry.clientId, + olmId: entry.olmId, + version: entry.version + }; + } + + const resolved = olmMap.get(entry.clientId); + if (!resolved) { + return null; + } + + return { + clientId: entry.clientId, + olmId: resolved.olmId, + version: entry.version ?? resolved.version + }; + }) + .filter((entry) => entry !== null); +}; + +export async function addPeerDataBatch( + entries: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: Alias[]; + olmId?: string; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolvedTargets = await resolveOlmTargets(entries); + + if (resolvedTargets.length === 0) { + return; + } + + const payloads = entries + .map((entry) => { + const resolved = resolvedTargets.find( + (target) => target.clientId === entry.clientId + ); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: `olm/wg/peer/data/add`, + data: { + siteId: entry.siteId, + remoteSubnets: entry.remoteSubnets, + aliases: entry.aliases + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress(resolved.version, "olm") + } + }; + }) + .filter((entry) => entry !== null); + + if (payloads.length === 0) { + return; + } + + await sendToClientsBatch(payloads); +} + +export async function removePeerDataBatch( + entries: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: Alias[]; + olmId?: string; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolvedTargets = await resolveOlmTargets(entries); + + if (resolvedTargets.length === 0) { + return; + } + + const payloads = entries + .map((entry) => { + const resolved = resolvedTargets.find( + (target) => target.clientId === entry.clientId + ); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: `olm/wg/peer/data/remove`, + data: { + siteId: entry.siteId, + remoteSubnets: entry.remoteSubnets, + aliases: entry.aliases + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress(resolved.version, "olm") + } + }; + }) + .filter((entry) => entry !== null); + + if (payloads.length === 0) { + return; + } + + await sendToClientsBatch(payloads); +} + export async function updatePeerData( clientId: number, siteId: number, diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts index 4b74d863d..6c38671f3 100644 --- a/server/routers/newt/peers.ts +++ b/server/routers/newt/peers.ts @@ -1,7 +1,7 @@ import { db, Site } from "@server/db"; import { newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import logger from "@server/logger"; export async function addPeer( @@ -36,10 +36,14 @@ export async function addPeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/add", - data: peer - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/add", + data: peer + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -76,12 +80,16 @@ export async function deletePeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/remove", - data: { - publicKey - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/remove", + data: { + publicKey + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -90,6 +98,35 @@ export async function deletePeer( return site; } +export async function deletePeersBatch( + peers: { + siteId: number; + publicKey: string; + newtId: string; + }[] +) { + if (peers.length === 0) { + return; + } + + await sendToClientsBatch( + peers.map((peer) => ({ + clientId: peer.newtId, + message: { + type: "newt/wg/peer/remove", + data: { + publicKey: peer.publicKey + } + }, + options: { incrementConfigVersion: true } + })) + ).catch((error) => { + logger.warn(`Error sending batched newt peer removals:`, error); + }); + + logger.info(`Deleted ${peers.length} peer(s) from newts (batch)`); +} + export async function updatePeer( siteId: number, publicKey: string, @@ -122,13 +159,17 @@ export async function updatePeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/update", - data: { - publicKey, - ...peer - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/update", + data: { + publicKey, + ...peer + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 05e153fea..962d7367e 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -1,9 +1,9 @@ -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { clientSitesAssociationsCache, db, olms } from "@server/db"; import { canCompress } from "@server/lib/clientVersionChecks"; import config from "@server/lib/config"; import logger from "@server/logger"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { Alias } from "yaml"; export async function addPeer( @@ -205,3 +205,150 @@ export async function initPeerAddHandshake( `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` ); } + +export async function deletePeersBatch( + peers: { + clientId: number; + siteId: number; + publicKey: string; + olmId?: string; + version?: string | null; + }[] +) { + if (peers.length === 0) { + return; + } + + const unresolvedClientIds = peers + .filter((peer) => !peer.olmId) + .map((peer) => peer.clientId); + + const olmByClientId = new Map< + number, + { olmId: string; version: string | null } + >(); + + if (unresolvedClientIds.length > 0) { + const olmRows = await db + .select({ + clientId: olms.clientId, + olmId: olms.olmId, + version: olms.version + }) + .from(olms) + .where(inArray(olms.clientId, unresolvedClientIds)); + + for (const row of olmRows) { + if (row.clientId !== null) { + olmByClientId.set(row.clientId, { + olmId: row.olmId, + version: row.version + }); + } + } + } + + const batchPayloads = peers + .map((peer) => { + const resolved = peer.olmId + ? { olmId: peer.olmId, version: peer.version ?? null } + : olmByClientId.get(peer.clientId); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: "olm/wg/peer/remove", + data: { + publicKey: peer.publicKey, + siteId: peer.siteId + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress( + peer.version ?? resolved.version, + "olm" + ) + } + }; + }) + .filter((payload) => payload !== null); + + if (batchPayloads.length === 0) { + return; + } + + await sendToClientsBatch(batchPayloads).catch((error) => { + logger.warn(`Error sending batched olm peer removals:`, error); + }); + + logger.info(`Deleted ${batchPayloads.length} peer(s) from olms (batch)`); +} + +export async function initPeerAddHandshakeBatch( + handshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + chainId?: string; + }[] +) { + if (handshakes.length === 0) { + return; + } + + await sendToClientsBatch( + handshakes.map((item) => ({ + clientId: item.olmId, + message: { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: item.peer.siteId, + exitNode: { + publicKey: item.peer.exitNode.publicKey, + relayPort: + config.getRawConfig().gerbil.clients_start_port, + endpoint: item.peer.exitNode.endpoint + }, + chainId: item.chainId + } + }, + options: { incrementConfigVersion: true } + })) + ).catch((error) => { + logger.warn(`Error sending batched olm handshakes:`, error); + }); + + await Promise.all( + handshakes.map((item) => + db + .update(clientSitesAssociationsCache) + .set({ isJitMode: false }) + .where( + and( + eq( + clientSitesAssociationsCache.clientId, + item.clientId + ), + eq( + clientSitesAssociationsCache.siteId, + item.peer.siteId + ) + ) + ) + ) + ); + + logger.info( + `Initiated ${handshakes.length} peer add handshake(s) to olms (batch)` + ); +} diff --git a/server/routers/ws/types.ts b/server/routers/ws/types.ts index d541d3276..eeb272457 100644 --- a/server/routers/ws/types.ts +++ b/server/routers/ws/types.ts @@ -76,6 +76,12 @@ export interface SendMessageOptions { compress?: boolean; } +export interface BatchSendMessage { + clientId: string; + message: WSMessage; + options?: SendMessageOptions; +} + // Redis message types for cross-node communication export type RedisMessage = | { diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index e7dcfe9cb..4ce337a20 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -26,7 +26,8 @@ import { WebSocketRequest, WSMessage, AuthenticatedWebSocket, - SendMessageOptions + SendMessageOptions, + BatchSendMessage } from "./types"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -212,6 +213,20 @@ const sendToClient = async ( return localSent; }; +const sendToClientsBatch = async ( + entries: BatchSendMessage[] +): Promise => { + if (entries.length === 0) { + return; + } + + await Promise.all( + entries.map((entry) => + sendToClient(entry.clientId, entry.message, entry.options) + ) + ); +}; + const broadcastToAllExcept = async ( message: WSMessage, excludeClientId?: string, @@ -552,6 +567,7 @@ export { router, handleWSUpgrade, sendToClient, + sendToClientsBatch, broadcastToAllExcept, connectedClients, hasActiveConnections,