diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 031cd23e..e0564fd6 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -150,7 +150,7 @@ export async function updateAndGenerateEndpointDestinations( throw new Error("Olm not found"); } - const [client] = await db + const [updatedClient] = await db .update(clients) .set({ lastHolePunch: timestamp @@ -158,10 +158,16 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(clients.clientId, olm.clientId)) .returning(); - if (await checkExitNodeOrg(exitNode.exitNodeId, client.orgId) && checkOrg) { + if ( + (await checkExitNodeOrg( + exitNode.exitNodeId, + updatedClient.orgId + )) && + checkOrg + ) { // not allowed logger.warn( - `Exit node ${exitNode.exitNodeId} is not allowed for org ${client.orgId}` + `Exit node ${exitNode.exitNodeId} is not allowed for org ${updatedClient.orgId}` ); throw new Error("Exit node not allowed"); } @@ -171,10 +177,14 @@ export async function updateAndGenerateEndpointDestinations( .select({ siteId: sites.siteId, subnet: sites.subnet, - listenPort: sites.listenPort + listenPort: sites.listenPort, + endpoint: clientSitesAssociationsCache.endpoint }) .from(sites) - .innerJoin(clientSitesAssociationsCache, eq(sites.siteId, clientSitesAssociationsCache.siteId)) + .innerJoin( + clientSitesAssociationsCache, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) .where( and( eq(sites.exitNodeId, exitNode.exitNodeId), @@ -188,7 +198,7 @@ export async function updateAndGenerateEndpointDestinations( `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` ); - await db + const [updatedClientSitesAssociationsCache] = await db .update(clientSitesAssociationsCache) .set({ endpoint: `${ip}:${port}` @@ -198,13 +208,27 @@ export async function updateAndGenerateEndpointDestinations( eq(clientSitesAssociationsCache.clientId, olm.clientId), eq(clientSitesAssociationsCache.siteId, site.siteId) ) + ) + .returning(); + + if ( + updatedClientSitesAssociationsCache.endpoint !== site.endpoint // this is the endpoint from the join table not the site + ) { + logger.info( + `ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}` ); + // Handle any additional logic for endpoint change + handleClientEndpointChange( + olm.clientId, + updatedClientSitesAssociationsCache.endpoint! + ); + } } logger.debug( `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` ); - if (!client) { + if (!updatedClient) { logger.warn(`Client not found for olm: ${olmId}`); throw new Error("Client not found"); } @@ -253,7 +277,10 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .limit(1); - if (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId) && checkOrg) { + if ( + (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId)) && + checkOrg + ) { // not allowed logger.warn( `Exit node ${exitNode.exitNodeId} is not allowed for org ${site.orgId}` @@ -273,6 +300,14 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .returning(); + if (updatedSite.endpoint != site.endpoint) { + logger.info( + `Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}` + ); + // Handle any additional logic for endpoint change + handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!); + } + if (!updatedSite || !updatedSite.subnet) { logger.warn(`Site not found: ${newt.siteId}`); throw new Error("Site not found"); @@ -326,3 +361,12 @@ export async function updateAndGenerateEndpointDestinations( } return destinations; } + +function handleSiteEndpointChange(siteId: number, newEndpoint: string) { + // just alert all of the clients connected to this site that the endpoint has changed but only if they are NOT relayed +} + +function handleClientEndpointChange(clientId: number, newEndpoint: string) { + // just alert all of the sites connected to this client that the endpoint has changed but only if they are NOT relayed + +}