Compare commits

...

6 Commits

Author SHA1 Message Date
Owen Schwartz
92f992728f Merge pull request #3074 from fosrl/dev
Optimize building aliases in jit mode
2026-05-14 12:25:44 -07:00
Owen
78ad2d17c7 Optimize building aliases in jit mode 2026-05-14 12:25:05 -07:00
Owen Schwartz
b29bb7384d Merge pull request #3073 from fosrl/dev
Further optimizations
2026-05-14 12:00:25 -07:00
Owen
5a8de8210b Further optimizations 2026-05-14 11:59:59 -07:00
Owen Schwartz
d5181454f4 Merge pull request #3072 from fosrl/dev
Optimize this
2026-05-14 11:34:56 -07:00
Owen
0e0666cacf Optimize this 2026-05-14 11:34:09 -07:00
4 changed files with 239 additions and 163 deletions

View File

@@ -11,7 +11,7 @@ import {
ExitNode ExitNode
} from "@server/db"; } from "@server/db";
import { db } from "@server/db"; import { db } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
@@ -202,24 +202,29 @@ export async function updateAndGenerateEndpointDestinations(
) )
); );
// Update clientSites for each site on this exit node // Format the endpoint properly for both IPv4 and IPv6
const formattedEndpoint = formatEndpoint(ip, port);
// Determine which rows actually need updating and whether the endpoint
// (as opposed to only the publicKey) changed for any of them.
const siteIdsToUpdate: number[] = [];
let endpointChanged = false;
for (const site of sitesOnExitNode) { for (const site of sitesOnExitNode) {
// logger.debug(
// `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
// );
// Format the endpoint properly for both IPv4 and IPv6
const formattedEndpoint = formatEndpoint(ip, port);
// if the public key or endpoint has changed, update it otherwise continue
if ( if (
site.endpoint === formattedEndpoint && site.endpoint === formattedEndpoint &&
site.publicKey === publicKey site.publicKey === publicKey
) { ) {
continue; continue;
} }
siteIdsToUpdate.push(site.siteId);
if (site.endpoint !== formattedEndpoint) {
endpointChanged = true;
}
}
const [updatedClientSitesAssociationsCache] = await db if (siteIdsToUpdate.length > 0) {
// Single bulk update for all affected rows for this client on this exit node
await db
.update(clientSitesAssociationsCache) .update(clientSitesAssociationsCache)
.set({ .set({
endpoint: formattedEndpoint, endpoint: formattedEndpoint,
@@ -228,24 +233,22 @@ export async function updateAndGenerateEndpointDestinations(
.where( .where(
and( and(
eq(clientSitesAssociationsCache.clientId, olm.clientId), eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId) inArray(
clientSitesAssociationsCache.siteId,
siteIdsToUpdate
)
) )
) );
.returning();
if ( // Only trigger downstream peer updates once per hole punch: the
updatedClientSitesAssociationsCache.endpoint !== // endpoint is the same for every site on this exit node, and
site.endpoint && // this is the endpoint from the join table not the site // handleClientEndpointChange already fans out to all connected
updatedClient.pubKey === publicKey // only trigger if the client's public key matches the current public key which means it has registered so we dont prematurely send the update // sites for this client.
) { if (endpointChanged && updatedClient.pubKey === publicKey) {
logger.info( logger.info(
`ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}` `ClientSitesAssociationsCache for client ${olm.clientId} endpoint changed to ${formattedEndpoint} for ${siteIdsToUpdate.length} site(s) on exit node ${exitNode.exitNodeId}`
);
// Handle any additional logic for endpoint change
handleClientEndpointChange(
olm.clientId,
updatedClientSitesAssociationsCache.endpoint!
); );
handleClientEndpointChange(olm.clientId, formattedEndpoint);
} }
} }
@@ -408,12 +411,14 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
return; return;
} }
// Get all non-relayed clients connected to this site // Get all non-relayed and not jit clients connected to this site
const connectedClients = await db const connectedClients = await db
.select({ .select({
online: clients.online,
clientId: clients.clientId, clientId: clients.clientId,
olmId: olms.olmId, olmId: olms.olmId,
isRelayed: clientSitesAssociationsCache.isRelayed isRelayed: clientSitesAssociationsCache.isRelayed,
isJitMode: clientSitesAssociationsCache.isJitMode
}) })
.from(clientSitesAssociationsCache) .from(clientSitesAssociationsCache)
.innerJoin( .innerJoin(
@@ -423,32 +428,36 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
.innerJoin(olms, eq(olms.clientId, clients.clientId)) .innerJoin(olms, eq(olms.clientId, clients.clientId))
.where( .where(
and( and(
eq(clients.online, true), // the client has to be online or it does not matter...
eq(clientSitesAssociationsCache.siteId, siteId), eq(clientSitesAssociationsCache.siteId, siteId),
eq(clientSitesAssociationsCache.isRelayed, false) eq(clientSitesAssociationsCache.isRelayed, false),
eq(clientSitesAssociationsCache.isJitMode, false)
) )
); );
// Update each non-relayed client with the new site endpoint // Update each non-relayed client with the new site endpoint (in parallel)
for (const client of connectedClients) { await Promise.allSettled(
try { connectedClients.map(async (client) => {
await updateOlmPeer( try {
client.clientId, await updateOlmPeer(
{ client.clientId,
siteId: siteId, {
publicKey: site.publicKey, siteId: siteId,
endpoint: newEndpoint publicKey: site.publicKey!,
}, endpoint: newEndpoint
client.olmId },
); client.olmId
logger.debug( );
`Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}` logger.debug(
); `Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}`
} catch (error) { );
logger.error( } catch (error) {
`Failed to update client ${client.clientId} with new site endpoint: ${error}` logger.error(
); `Failed to update client ${client.clientId} with new site endpoint: ${error}`
} );
} }
})
);
} catch (error) { } catch (error) {
logger.error( logger.error(
`Error handling site endpoint change for site ${siteId}: ${error}` `Error handling site endpoint change for site ${siteId}: ${error}`
@@ -456,11 +465,11 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
} }
} }
async function handleClientEndpointChange( async function handleClientEndpointChange( // TODO: I THINK WE DONT NEED TO HIT EVERY SITE HERE BECAUSE WE ONLY NEED TO UPDATE THE SITES CONNECTED TO THIS NODE WHICH WE ALREADY HAVE FROM ABOVE
clientId: number, clientId: number,
newEndpoint: string newEndpoint: string
) { ) {
// Alert all sites connected to this client that the endpoint has changed (only if NOT relayed) // Alert all sites connected to this client that the endpoint has changed (only if NOT relayed and NOT JIT MODE)
try { try {
// Get client details // Get client details
const [client] = await db const [client] = await db
@@ -480,6 +489,7 @@ async function handleClientEndpointChange(
siteId: sites.siteId, siteId: sites.siteId,
newtId: newts.newtId, newtId: newts.newtId,
isRelayed: clientSitesAssociationsCache.isRelayed, isRelayed: clientSitesAssociationsCache.isRelayed,
isJitMode: clientSitesAssociationsCache.isJitMode,
subnet: clients.subnet subnet: clients.subnet
}) })
.from(clientSitesAssociationsCache) .from(clientSitesAssociationsCache)
@@ -494,38 +504,49 @@ async function handleClientEndpointChange(
) )
.where( .where(
and( and(
eq(sites.online, true), // the site has to be online or it does not matter...
eq(clientSitesAssociationsCache.clientId, clientId), eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.isRelayed, false) eq(clientSitesAssociationsCache.isRelayed, false),
eq(clientSitesAssociationsCache.isJitMode, false)
) )
); );
// Update each non-relayed site with the new client endpoint if (connectedSites.length > 250) {
for (const siteData of connectedSites) { logger.warn(
try { `Client ${clientId} has ${connectedSites.length} connected sites so the client will be in jit mode anyway, skipping endpoint updates`
if (!siteData.subnet) { );
return;
}
// Update each non-relayed site with the new client endpoint (in parallel)
await Promise.allSettled(
connectedSites.map(async (siteData) => {
if (!siteData.subnet || !client.pubKey) {
logger.warn( logger.warn(
`Client ${clientId} has no subnet, skipping update for site ${siteData.siteId}` `Client ${clientId} has no subnet or public key, skipping update for site ${siteData.siteId}`
); );
continue; return;
} }
await updateNewtPeer( try {
siteData.siteId, await updateNewtPeer(
client.pubKey, siteData.siteId,
{ client.pubKey,
endpoint: newEndpoint {
}, endpoint: newEndpoint
siteData.newtId },
); siteData.newtId
logger.debug( );
`Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}` logger.debug(
); `Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}`
} catch (error) { );
logger.error( } catch (error) {
`Failed to update site ${siteData.siteId} with new client endpoint: ${error}` logger.error(
); `Failed to update site ${siteData.siteId} with new client endpoint: ${error}`
} );
} }
})
);
} catch (error) { } catch (error) {
logger.error( logger.error(
`Error handling client endpoint change for client ${clientId}: ${error}` `Error handling client endpoint change for client ${clientId}: ${error}`

View File

@@ -5,6 +5,7 @@ import {
db, db,
exitNodes, exitNodes,
networks, networks,
SiteResource,
siteNetworks, siteNetworks,
siteResources, siteResources,
sites sites
@@ -15,7 +16,7 @@ import {
generateRemoteSubnets generateRemoteSubnets
} from "@server/lib/ip"; } from "@server/lib/ip";
import logger from "@server/logger"; import logger from "@server/logger";
import { and, eq } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers"; import { addPeer, deletePeer } from "../newt/peers";
import config from "@server/lib/config"; import config from "@server/lib/config";
@@ -46,49 +47,79 @@ export async function buildSiteConfigurationForOlmClient(
) )
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
if (sitesData.length === 0) {
return siteConfigurations;
}
// Batch-fetch every site resource this client has access to across ALL sites
// in a single query, then group by siteId in memory. This avoids issuing one
// query per site (which would be N round-trips for N sites).
const allClientSiteResources = await db
.select({
siteResource: siteResources,
siteId: siteNetworks.siteId
})
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.innerJoin(networks, eq(siteResources.networkId, networks.networkId))
.innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
.where(
eq(clientSiteResourcesAssociationsCache.clientId, client.clientId)
);
const siteResourcesBySiteId = new Map<number, SiteResource[]>();
for (const row of allClientSiteResources) {
const arr = siteResourcesBySiteId.get(row.siteId);
if (arr) {
arr.push(row.siteResource);
} else {
siteResourcesBySiteId.set(row.siteId, [row.siteResource]);
}
}
// Batch-fetch exit nodes for all sites in one query (only needed in relay mode).
const exitNodesById = new Map<number, typeof exitNodes.$inferSelect>();
if (!jitMode && relay) {
const exitNodeIds = Array.from(
new Set(
sitesData
.map(({ sites: s }) => s.exitNodeId)
.filter((id): id is number => id != null)
)
);
if (exitNodeIds.length > 0) {
const nodes = await db
.select()
.from(exitNodes)
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
for (const n of nodes) {
exitNodesById.set(n.exitNodeId, n);
}
}
}
const clientsStartPort = config.getRawConfig().gerbil.clients_start_port;
const peerOps: Promise<unknown>[] = [];
// Process each site // Process each site
for (const { for (const {
sites: site, sites: site,
clientSitesAssociationsCache: association clientSitesAssociationsCache: association
} of sitesData) { } of sitesData) {
const allSiteResources = await db // only get the site resources that this client has access to const allSiteResources = siteResourcesBySiteId.get(site.siteId) ?? [];
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
eq(networks.networkId, siteNetworks.networkId)
)
.where(
and(
eq(siteNetworks.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (jitMode) { if (jitMode) {
// Add site configuration to the array // Add site configuration to the array
siteConfigurations.push({ siteConfigurations.push({
siteId: site.siteId, siteId: site.siteId,
// remoteSubnets: generateRemoteSubnets( // remoteSubnets: generateRemoteSubnets(allSiteResources),
// allSiteResources.map(({ siteResources }) => siteResources) aliases: generateAliasConfig(allSiteResources)
// ),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
}); });
continue; continue;
} }
@@ -126,7 +157,7 @@ export async function buildSiteConfigurationForOlmClient(
logger.info( logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...` `Public key mismatch. Deleting old peer from site ${site.siteId}...`
); );
await deletePeer(site.siteId, client.pubKey!); peerOps.push(deletePeer(site.siteId, client.pubKey!));
} }
if (!site.subnet) { if (!site.subnet) {
@@ -134,27 +165,19 @@ export async function buildSiteConfigurationForOlmClient(
continue; continue;
} }
const [clientSite] = await db // Add the peer to the exit node for this site. The endpoint comes from
.select() // the already-joined association row above, so no extra query needed.
.from(clientSitesAssociationsCache) if (association.endpoint && publicKey) {
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.limit(1);
// Add the peer to the exit node for this site
if (clientSite.endpoint && publicKey) {
logger.info( logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}` `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${association.endpoint}`
);
peerOps.push(
addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: relay ? "" : association.endpoint
})
); );
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: relay ? "" : clientSite.endpoint
});
} else { } else {
logger.warn( logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition` `Client ${client.clientId} has no endpoint, skipping peer addition`
@@ -163,16 +186,12 @@ export async function buildSiteConfigurationForOlmClient(
let relayEndpoint: string | undefined = undefined; let relayEndpoint: string | undefined = undefined;
if (relay) { if (relay) {
const [exitNode] = await db const exitNode = exitNodesById.get(site.exitNodeId);
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (!exitNode) { if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`); logger.warn(`Exit node not found for site ${site.siteId}`);
continue; continue;
} }
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; relayEndpoint = `${exitNode.endpoint}:${clientsStartPort}`;
} }
// Add site configuration to the array // Add site configuration to the array
@@ -184,12 +203,16 @@ export async function buildSiteConfigurationForOlmClient(
publicKey: site.publicKey, publicKey: site.publicKey,
serverIP: site.address, serverIP: site.address,
serverPort: site.listenPort, serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets( remoteSubnets: generateRemoteSubnets(allSiteResources),
allSiteResources.map(({ siteResources }) => siteResources) aliases: generateAliasConfig(allSiteResources)
), });
aliases: generateAliasConfig( }
allSiteResources.map(({ siteResources }) => siteResources)
) // Run all peer add/delete operations concurrently rather than serially per
// site, so total time is bounded by the slowest call instead of the sum.
if (peerOps.length > 0) {
Promise.allSettled(peerOps).catch((err) => {
logger.error("Error processing peer operations: ", err);
}); });
} }

View File

@@ -8,7 +8,7 @@ import {
ExitNode, ExitNode,
exitNodes, exitNodes,
sites, sites,
clientSitesAssociationsCache, clientSitesAssociationsCache
} from "@server/db"; } from "@server/db";
import { olms } from "@server/db"; import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -28,6 +28,7 @@ import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger"; import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { APP_VERSION } from "@server/lib/consts"; import { APP_VERSION } from "@server/lib/consts";
import { build } from "@server/build";
export const olmGetTokenBodySchema = z.object({ export const olmGetTokenBodySchema = z.object({
olmId: z.string(), olmId: z.string(),
@@ -220,6 +221,22 @@ export async function getOlmToken(
) )
.where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!)); .where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!));
if (clientSites.length > 250 && build == "saas") {
// set all of the cache rows isJitMode to true
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: true })
.where(
and(
eq(
clientSitesAssociationsCache.clientId,
clientIdToUse!
),
eq(clientSitesAssociationsCache.isJitMode, false)
)
);
}
// Extract unique exit node IDs // Extract unique exit node IDs
const exitNodeIds = Array.from( const exitNodeIds = Array.from(
new Set( new Set(

View File

@@ -1,4 +1,4 @@
import { db, orgs } from "@server/db"; import { db, orgs, primaryDb } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { import {
clients, clients,
@@ -7,7 +7,7 @@ import {
olms, olms,
sites sites
} from "@server/db"; } from "@server/db";
import { count, eq } from "drizzle-orm"; import { and, count, eq, ne, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -81,7 +81,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.where(eq(olms.olmId, olm.olmId)); .where(eq(olms.olmId, olm.olmId));
} }
const [client] = await db const [client] = await primaryDb // read from the primary here so there is no latency with the last update on the holepunch
.select() .select()
.from(clients) .from(clients)
.where(eq(clients.clientId, olm.clientId)) .where(eq(clients.clientId, olm.clientId))
@@ -98,7 +98,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.blocked) { if (client.blocked) {
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`, `[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
return; return;
@@ -107,7 +107,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.approvalState == "pending") { if (client.approvalState == "pending") {
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. Ignoring register.`, `[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. Ignoring register.`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
return; return;
@@ -136,7 +136,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (!org) { if (!org) {
logger.warn("[handleOlmRegisterMessage] Org not found", { logger.warn("[handleOlmRegisterMessage] Org not found", {
orgId: client.orgId orgId: client.orgId,
clientId: client.clientId
}); });
sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
return; return;
@@ -145,7 +146,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (orgId) { if (orgId) {
if (!olm.userId) { if (!olm.userId) {
logger.warn("[handleOlmRegisterMessage] Olm has no user ID", { logger.warn("[handleOlmRegisterMessage] Olm has no user ID", {
orgId: client.orgId orgId: client.orgId,
clientId: client.clientId
}); });
sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
return; return;
@@ -156,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (!userSession || !user) { if (!userSession || !user) {
logger.warn( logger.warn(
"[handleOlmRegisterMessage] Invalid user session for olm register", "[handleOlmRegisterMessage] Invalid user session for olm register",
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId); sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
return; return;
@@ -164,7 +166,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (user.userId !== olm.userId) { if (user.userId !== olm.userId) {
logger.warn( logger.warn(
"[handleOlmRegisterMessage] User ID mismatch for olm register", "[handleOlmRegisterMessage] User ID mismatch for olm register",
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
return; return;
@@ -182,13 +184,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.debug("[handleOlmRegisterMessage] Policy check result", { logger.debug("[handleOlmRegisterMessage] Policy check result", {
orgId: client.orgId, orgId: client.orgId,
clientId: client.clientId,
policyCheck policyCheck
}); });
if (policyCheck?.error) { if (policyCheck?.error) {
logger.error( logger.error(
`[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`, `[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -197,7 +200,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (policyCheck.policies?.passwordAge?.compliant === false) { if (policyCheck.policies?.passwordAge?.compliant === false) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`, `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
@@ -209,7 +212,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
) { ) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`, `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
@@ -219,7 +222,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} else if (policyCheck.policies?.requiredTwoFactor === false) { } else if (policyCheck.policies?.requiredTwoFactor === false) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`, `[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED, OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
@@ -229,7 +232,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} else if (!policyCheck.allowed) { } else if (!policyCheck.allowed) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`, `[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -253,7 +256,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Prepare an array to store site configurations // Prepare an array to store site configurations
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`, `[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
let jitMode = false; let jitMode = false;
@@ -263,19 +266,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info( logger.info(
`[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`, `[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
jitMode = true; jitMode = true;
} }
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`, `[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
if (!publicKey) { if (!publicKey) {
logger.warn("[handleOlmRegisterMessage] Public key not provided", { logger.warn("[handleOlmRegisterMessage] Public key not provided", {
orgId: client.orgId orgId: client.orgId,
clientId: client.clientId
}); });
return; return;
} }
@@ -283,7 +287,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.pubKey !== publicKey || client.archived) { if (client.pubKey !== publicKey || client.archived) {
logger.info( logger.info(
"[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...", "[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...",
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
// Update the client's public key // Update the client's public key
await db await db
@@ -301,7 +305,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
isRelayed: relay == true, isRelayed: relay == true,
isJitMode: jitMode isJitMode: jitMode
}) })
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
or(
ne(
clientSitesAssociationsCache.isRelayed,
relay == true
),
ne(clientSitesAssociationsCache.isJitMode, jitMode)
)
)
);
} }
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.
@@ -310,7 +325,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) { if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`, `[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`,
{ orgId: client.orgId } { orgId: client.orgId, clientId: client.clientId }
); );
return; return;
} }