From 73b0411e1c864244c2354cf6cb9313b27bd481ff Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 24 Nov 2025 20:43:26 -0500 Subject: [PATCH 01/14] Add alias config --- server/db/pg/schema/schema.ts | 5 +- server/db/sqlite/schema/schema.ts | 4 +- server/lib/ip.ts | 84 ++++++++++++++++++- server/lib/rebuildClientAssociations.ts | 15 ++-- server/routers/client/targets.ts | 29 ++++--- server/routers/newt/handleGetConfigMessage.ts | 1 + .../routers/olm/handleOlmRegisterMessage.ts | 55 ++++++++---- .../siteResource/createSiteResource.ts | 8 +- .../siteResource/updateSiteResource.ts | 16 ++-- 9 files changed, 176 insertions(+), 41 deletions(-) diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 8ab1b24c..d15676a0 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -11,6 +11,7 @@ import { } from "drizzle-orm/pg-core"; import { InferSelectModel } from "drizzle-orm"; import { randomUUID } from "crypto"; +import { alias } from "yargs"; export const domains = pgTable("domains", { domainId: varchar("domainId").primaryKey(), @@ -40,6 +41,7 @@ export const orgs = pgTable("orgs", { orgId: varchar("orgId").primaryKey(), name: varchar("name").notNull(), subnet: varchar("subnet"), + utilitySubnet: varchar("utilitySubnet"), // this is the subnet for utility addresses createdAt: text("createdAt"), requireTwoFactor: boolean("requireTwoFactor"), maxSessionLengthHours: integer("maxSessionLengthHours"), @@ -209,7 +211,8 @@ export const siteResources = pgTable("siteResources", { destinationPort: integer("destinationPort"), // only for port mode destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode enabled: boolean("enabled").notNull().default(true), - alias: varchar("alias") + alias: varchar("alias"), + aliasAddress: varchar("aliasAddress") }); export const clientSiteResources = pgTable("clientSiteResources", { diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index cfffdba7..634afd36 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -32,6 +32,7 @@ export const orgs = sqliteTable("orgs", { orgId: text("orgId").primaryKey(), name: text("name").notNull(), subnet: text("subnet"), + utilitySubnet: text("utilitySubnet"), // this is the subnet for utility addresses createdAt: text("createdAt"), requireTwoFactor: integer("requireTwoFactor", { mode: "boolean" }), maxSessionLengthHours: integer("maxSessionLengthHours"), // hours @@ -230,7 +231,8 @@ export const siteResources = sqliteTable("siteResources", { destinationPort: integer("destinationPort"), // only for port mode destination: text("destination").notNull(), // ip, cidr, hostname enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), - alias: text("alias") + alias: text("alias"), + aliasAddress: text("aliasAddress") }); export const clientSiteResources = sqliteTable("clientSiteResources", { diff --git a/server/lib/ip.ts b/server/lib/ip.ts index d530e2f0..7835ad84 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -1,4 +1,10 @@ -import { clientSitesAssociationsCache, db, SiteResource, Transaction } from "@server/db"; +import { + clientSitesAssociationsCache, + db, + SiteResource, + siteResources, + Transaction +} from "@server/db"; import { clients, orgs, sites } from "@server/db"; import { and, eq, isNotNull } from "drizzle-orm"; import config from "@server/lib/config"; @@ -281,6 +287,56 @@ export async function getNextAvailableClientSubnet( return subnet; } +export async function getNextAvailableAliasAddress( + orgId: string +): Promise { + const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); + + if (!org) { + throw new Error(`Organization with ID ${orgId} not found`); + } + + if (!org.subnet) { + throw new Error(`Organization with ID ${orgId} has no subnet defined`); + } + + if (!org.utilitySubnet) { + throw new Error( + `Organization with ID ${orgId} has no utility subnet defined` + ); + } + + const existingAddresses = await db + .select({ + aliasAddress: siteResources.aliasAddress + }) + .from(siteResources) + .where( + and( + isNotNull(siteResources.aliasAddress), + eq(siteResources.orgId, orgId) + ) + ); + + const addresses = [ + ...existingAddresses.map( + (site) => `${site.aliasAddress?.split("/")[0]}/32` + ), + // reserve a /29 for the dns server and other stuff + `${org.utilitySubnet.split("/")[0]}/29` + ].filter((address) => address !== null) as string[]; + + let subnet = findNextAvailableCidr(addresses, 32, org.utilitySubnet); + if (!subnet) { + throw new Error("No available subnets remaining in space"); + } + + // remove the cidr + subnet = subnet.split("/")[0]; + + return subnet; +} + export async function getNextAvailableOrgSubnet(): Promise { const existingAddresses = await db .select({ @@ -303,7 +359,9 @@ export async function getNextAvailableOrgSubnet(): Promise { return subnet; } -export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[] { +export function generateRemoteSubnets( + allSiteResources: SiteResource[] +): string[] { let remoteSubnets = allSiteResources .filter((sr) => { if (sr.mode === "cidr") return true; @@ -327,6 +385,18 @@ export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[ return Array.from(new Set(remoteSubnets)); } +export type Alias = { alias: string | null; aliasAddress: string | null }; + +export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] { + let aliasConfigs = allSiteResources + .filter((sr) => sr.alias && sr.aliasAddress && sr.mode == "host") + .map((sr) => ({ + alias: sr.alias, + aliasAddress: sr.aliasAddress + })); + return aliasConfigs; +} + export type SubnetProxyTarget = { sourcePrefix: string; destPrefix: string; @@ -372,6 +442,14 @@ export function generateSubnetProxyTargets( destPrefix: `${siteResource.destination}/32` }); } + + if (siteResource.alias && siteResource.aliasAddress) { + // also push a match for the alias address + targets.push({ + sourcePrefix: clientPrefix, + destPrefix: `${siteResource.aliasAddress}/32` + }); + } } else if (siteResource.mode == "cidr") { targets.push({ sourcePrefix: clientPrefix, @@ -386,4 +464,4 @@ export function generateSubnetProxyTargets( ); return targets; -} \ No newline at end of file +} diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 2773e098..a1072196 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -31,14 +31,15 @@ import { import { sendToExitNode } from "#dynamic/lib/exitNodes"; import logger from "@server/logger"; import { + generateAliasConfig, generateRemoteSubnets, generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip"; import { - addRemoteSubnets, + addPeerData, addTargets as addSubnetProxyTargets, - removeRemoteSubnets, + removePeerData, removeTargets as removeSubnetProxyTargets } from "@server/routers/client/targets"; @@ -703,10 +704,11 @@ async function handleSubnetProxyTargetUpdates( for (const client of addedClients) { olmJobs.push( - addRemoteSubnets( + addPeerData( client.clientId, siteResource.siteId, - generateRemoteSubnets([siteResource]) + generateRemoteSubnets([siteResource]), + generateAliasConfig([siteResource]) ) ); } @@ -738,10 +740,11 @@ async function handleSubnetProxyTargetUpdates( for (const client of removedClients) { olmJobs.push( - removeRemoteSubnets( + removePeerData( client.clientId, siteResource.siteId, - generateRemoteSubnets([siteResource]) + generateRemoteSubnets([siteResource]), + generateAliasConfig([siteResource]) ) ); } diff --git a/server/routers/client/targets.ts b/server/routers/client/targets.ts index c94cb680..b5684436 100644 --- a/server/routers/client/targets.ts +++ b/server/routers/client/targets.ts @@ -1,6 +1,6 @@ import { sendToClient } from "#dynamic/routers/ws"; import { db, olms } from "@server/db"; -import { SubnetProxyTarget } from "@server/lib/ip"; +import { Alias, SubnetProxyTarget } from "@server/lib/ip"; import { eq } from "drizzle-orm"; export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) { @@ -33,10 +33,11 @@ export async function updateTargets( }); } -export async function addRemoteSubnets( +export async function addPeerData( clientId: number, siteId: number, remoteSubnets: string[], + aliases: Alias[], olmId?: string ) { if (!olmId) { @@ -52,18 +53,20 @@ export async function addRemoteSubnets( } await sendToClient(olmId, { - type: `olm/wg/peer/add-remote-subnets`, + type: `olm/wg/peer/data/add`, data: { siteId: siteId, - remoteSubnets: remoteSubnets + remoteSubnets: remoteSubnets, + aliases: aliases } }); } -export async function removeRemoteSubnets( +export async function removePeerData( clientId: number, siteId: number, remoteSubnets: string[], + aliases: Alias[], olmId?: string ) { if (!olmId) { @@ -79,21 +82,26 @@ export async function removeRemoteSubnets( } await sendToClient(olmId, { - type: `olm/wg/peer/remove-remote-subnets`, + type: `olm/wg/peer/data/remove`, data: { siteId: siteId, - remoteSubnets: remoteSubnets + remoteSubnets: remoteSubnets, + aliases: aliases } }); } -export async function updateRemoteSubnets( +export async function updatePeerData( clientId: number, siteId: number, remoteSubnets: { oldRemoteSubnets: string[], newRemoteSubnets: string[] }, + aliases: { + oldAliases: Alias[], + newAliases: Alias[] + }, olmId?: string ) { if (!olmId) { @@ -109,10 +117,11 @@ export async function updateRemoteSubnets( } await sendToClient(olmId, { - type: `olm/wg/peer/update-remote-subnets`, + type: `olm/wg/peer/data/update`, data: { siteId: siteId, - ...remoteSubnets + ...remoteSubnets, + ...aliases } }); } diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 68116686..fbbcb4fb 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -275,6 +275,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { resource, resourceClients ); + targetsToSend.push(...resourceTargets); } diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 5c438e4f..048e6baa 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -3,6 +3,7 @@ import { clientSiteResourcesAssociationsCache, db, ExitNode, + Org, orgs, roleClients, roles, @@ -25,7 +26,10 @@ import { and, eq, inArray, isNull } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; import { listExitNodes } from "#dynamic/lib/exitNodes"; -import { getNextAvailableClientSubnet } from "@server/lib/ip"; +import { + generateAliasConfig, + getNextAvailableClientSubnet +} from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { @@ -42,18 +46,24 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient } = message.data; - let client: Client; + + let client: Client | undefined; + let org: Org | undefined; if (orgId) { try { - client = await getOrCreateOrgClient( - orgId, - olm.userId, - olm.olmId, - olm.name || "User Device", - // doNotCreateNewClient ? true : false - true // for now never create a new client automatically because we create the users clients when they are added to the org - ); + const { client: clientRes, org: orgRes } = + await getOrCreateOrgClient( + orgId, + olm.userId, + olm.olmId, + olm.name || "User Device", + // doNotCreateNewClient ? true : false + true // for now never create a new client automatically because we create the users clients when they are added to the org + ); + + client = clientRes; + org = orgRes; } catch (err) { logger.error( `Error switching olm client ${olm.olmId} to org ${orgId}: ${err}` @@ -96,6 +106,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } + if (!org) { + logger.warn("Org not found"); + return; + } + logger.debug( `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` ); @@ -302,7 +317,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnets(allSiteResources.map(({ siteResources }) => siteResources)) + remoteSubnets: generateRemoteSubnets( + allSiteResources.map(({ siteResources }) => siteResources) + ), + aliases: generateAliasConfig( + allSiteResources.map(({ siteResources }) => siteResources) + ) }); } @@ -318,7 +338,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { type: "olm/wg/connect", data: { sites: siteConfigurations, - tunnelIP: client.subnet + tunnelIP: client.subnet, + utilitySubnet: org.utilitySubnet } }, broadcast: false, @@ -333,7 +354,10 @@ async function getOrCreateOrgClient( name: string, doNotCreateNewClient: boolean, trx: Transaction | typeof db = db -): Promise { +): Promise<{ + client: Client; + org: Org; +}> { // get the org const [org] = await trx .select() @@ -441,5 +465,8 @@ async function getOrCreateOrgClient( client = newClient; } - return client; + return { + client: client, + org: org + }; } diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index 2c7bf0fe..ecbb7768 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -18,6 +18,7 @@ import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import { getUniqueSiteResourceName } from "@server/db/names"; import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { getNextAvailableAliasAddress } from "@server/lib/ip"; const createSiteResourceParamsSchema = z.strictObject({ siteId: z.string().transform(Number).pipe(z.int().positive()), @@ -193,6 +194,10 @@ export async function createSiteResource( // } const niceId = await getUniqueSiteResourceName(orgId); + let aliasAddress: string | null = null; + if (mode == "host") { // we can only have an alias on a host + aliasAddress = await getNextAvailableAliasAddress(orgId); + } let newSiteResource: SiteResource | undefined; await db.transaction(async (trx) => { @@ -210,7 +215,8 @@ export async function createSiteResource( // destinationPort: mode === "port" ? destinationPort : null, destination, enabled, - alias: alias || null + alias, + aliasAddress }) .returning(); diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index 2e2c1592..d66d2cb8 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -17,11 +17,9 @@ import { eq, and, ne } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; +import { updatePeerData, updateTargets } from "@server/routers/client/targets"; import { - updateRemoteSubnets, - updateTargets -} from "@server/routers/client/targets"; -import { + generateAliasConfig, generateRemoteSubnets, generateSubnetProxyTargets } from "@server/lib/ip"; @@ -266,7 +264,7 @@ export async function updateSiteResource( for (const client of mergedAllClients) { // we also need to update the remote subnets on the olms for each client that has access to this site olmJobs.push( - updateRemoteSubnets( + updatePeerData( client.clientId, updatedSiteResource.siteId, { @@ -276,6 +274,14 @@ export async function updateSiteResource( newRemoteSubnets: generateRemoteSubnets([ updatedSiteResource ]) + }, + { + oldAliases: generateAliasConfig([ + existingSiteResource + ]), + newAliases: generateAliasConfig([ + updatedSiteResource + ]) } ) ); From ceae787cf5ca9fbf6f52645bf53dc37366ff6404 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 25 Nov 2025 18:20:02 -0500 Subject: [PATCH 02/14] Attempt to handle creating/deleting clients and role --- server/lib/calculateUserClientsForOrgs.ts | 54 +- server/lib/ip.ts | 8 +- server/lib/rebuildClientAssociations.ts | 551 +++++++++++++++++- server/routers/auth/securityKey.ts | 2 +- server/routers/client/createClient.ts | 36 +- server/routers/client/createUserClient.ts | 20 +- server/routers/client/deleteClient.ts | 36 +- server/routers/client/terminate.ts | 22 + server/routers/olm/deleteUserOlm.ts | 32 +- server/routers/olm/getUserOlm.ts | 2 +- .../routers/olm/handleOlmRegisterMessage.ts | 17 +- .../siteResource/addClientToSiteResource.ts | 4 +- .../siteResource/addRoleToSiteResource.ts | 4 +- .../siteResource/addUserToSiteResource.ts | 4 +- .../siteResource/createSiteResource.ts | 4 +- .../siteResource/deleteSiteResource.ts | 4 +- .../removeClientFromSiteResource.ts | 4 +- .../removeRoleFromSiteResource.ts | 4 +- .../removeUserFromSiteResource.ts | 4 +- .../siteResource/setSiteResourceClients.ts | 4 +- .../siteResource/setSiteResourceRoles.ts | 4 +- .../siteResource/setSiteResourceUsers.ts | 4 +- .../siteResource/updateSiteResource.ts | 4 +- server/routers/user/addUserRole.ts | 50 +- server/routers/user/adminRemoveUser.ts | 11 +- 25 files changed, 778 insertions(+), 111 deletions(-) create mode 100644 server/routers/client/terminate.ts diff --git a/server/lib/calculateUserClientsForOrgs.ts b/server/lib/calculateUserClientsForOrgs.ts index f66e3888..4cde8657 100644 --- a/server/lib/calculateUserClientsForOrgs.ts +++ b/server/lib/calculateUserClientsForOrgs.ts @@ -1,8 +1,20 @@ -import { clients, clientSitesAssociationsCache, db, olms, orgs, roleClients, roles, userClients, userOrgs, Transaction } from "@server/db"; +import { + clients, + db, + olms, + orgs, + roleClients, + roles, + userClients, + userOrgs, + Transaction +} from "@server/db"; import { eq, and, notInArray } from "drizzle-orm"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { getNextAvailableClientSubnet } from "@server/lib/ip"; import logger from "@server/logger"; +import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations"; +import { sendTerminateClient } from "@server/routers/client/terminate"; export async function calculateUserClientsForOrgs( userId: string, @@ -88,7 +100,10 @@ export async function calculateUserClientsForOrgs( .where( and( eq(roleClients.roleId, adminRole.roleId), - eq(roleClients.clientId, existingClient.clientId) + eq( + roleClients.clientId, + existingClient.clientId + ) ) ) .limit(1); @@ -110,7 +125,10 @@ export async function calculateUserClientsForOrgs( .where( and( eq(userClients.userId, userId), - eq(userClients.clientId, existingClient.clientId) + eq( + userClients.clientId, + existingClient.clientId + ) ) ) .limit(1); @@ -172,6 +190,11 @@ export async function calculateUserClientsForOrgs( }) .returning(); + await rebuildClientAssociationsFromClient( + newClient, + transaction + ); + // Grant admin role access to the client await transaction.insert(roleClients).values({ roleId: adminRole.roleId, @@ -225,15 +248,8 @@ async function cleanupOrphanedClients( : and(eq(clients.userId, userId)) ); - // Delete client-site associations first, then delete the clients - for (const client of clientsToDelete) { - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - } - if (clientsToDelete.length > 0) { - await trx + const deletedClients = await trx .delete(clients) .where( userOrgIds.length > 0 @@ -242,7 +258,20 @@ async function cleanupOrphanedClients( notInArray(clients.orgId, userOrgIds) ) : and(eq(clients.userId, userId)) - ); + ) + .returning(); + + // Rebuild associations for each deleted client to clean up related data + for (const deletedClient of deletedClients) { + await rebuildClientAssociationsFromClient(deletedClient, trx); + + if (deletedClient.olmId) { + await sendTerminateClient( + deletedClient.clientId, + deletedClient.olmId + ); + } + } if (userOrgIds.length === 0) { logger.debug( @@ -255,4 +284,3 @@ async function cleanupOrphanedClients( } } } - diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 7835ad84..4a02694a 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -398,8 +398,9 @@ export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] { } export type SubnetProxyTarget = { - sourcePrefix: string; - destPrefix: string; + sourcePrefix: string; // must be a cidr + destPrefix: string; // must be a cidr + rewriteTo?: string; // must be a cidr portRange?: { min: number; max: number; @@ -447,7 +448,8 @@ export function generateSubnetProxyTargets( // also push a match for the alias address targets.push({ sourcePrefix: clientPrefix, - destPrefix: `${siteResource.aliasAddress}/32` + destPrefix: `${siteResource.aliasAddress}/32`, + rewriteTo: `${siteResource.destination}/32` }); } } else if (siteResource.mode == "cidr") { diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index a1072196..f1cbea0c 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -129,7 +129,7 @@ export async function getClientSiteResourceAccess( }; } -export async function rebuildClientAssociations( +export async function rebuildClientAssociationsFromSiteResource( siteResource: SiteResource, trx: Transaction | typeof db = db ): Promise<{ @@ -753,3 +753,552 @@ async function handleSubnetProxyTargetUpdates( await Promise.all(proxyJobs); } + +export async function rebuildClientAssociationsFromClient( + client: Client, + trx: Transaction | typeof db = db +): Promise { + let newSiteResourceIds: number[] = []; + + // 1. Direct client associations + const directSiteResources = await trx + .select({ siteResourceId: clientSiteResources.siteResourceId }) + .from(clientSiteResources) + .where(eq(clientSiteResources.clientId, client.clientId)); + + newSiteResourceIds.push( + ...directSiteResources.map((r) => r.siteResourceId) + ); + + // 2. User-based and role-based access (if client has a userId) + if (client.userId) { + // Direct user associations + const userSiteResourceIds = await trx + .select({ siteResourceId: userSiteResources.siteResourceId }) + .from(userSiteResources) + .where(eq(userSiteResources.userId, client.userId)); + + newSiteResourceIds.push( + ...userSiteResourceIds.map((r) => r.siteResourceId) + ); + + // Role-based access + const roleIds = await trx + .select({ roleId: userOrgs.roleId }) + .from(userOrgs) + .where(eq(userOrgs.userId, client.userId)) + .then((rows) => rows.map((row) => row.roleId)); + + if (roleIds.length > 0) { + const roleSiteResourceIds = await trx + .select({ siteResourceId: roleSiteResources.siteResourceId }) + .from(roleSiteResources) + .where(inArray(roleSiteResources.roleId, roleIds)); + + newSiteResourceIds.push( + ...roleSiteResourceIds.map((r) => r.siteResourceId) + ); + } + } + + // Remove duplicates + newSiteResourceIds = Array.from(new Set(newSiteResourceIds)); + + // Get full siteResource details + const newSiteResources = + newSiteResourceIds.length > 0 + ? await trx + .select() + .from(siteResources) + .where( + inArray(siteResources.siteResourceId, newSiteResourceIds) + ) + : []; + + // Group by siteId for site-level associations + const newSiteIds = Array.from( + new Set(newSiteResources.map((sr) => sr.siteId)) + ); + + /////////// Process client-siteResource associations /////////// + + // Get existing resource associations + const existingResourceAssociations = await trx + .select({ + siteResourceId: clientSiteResourcesAssociationsCache.siteResourceId + }) + .from(clientSiteResourcesAssociationsCache) + .where( + eq(clientSiteResourcesAssociationsCache.clientId, client.clientId) + ); + + const existingSiteResourceIds = existingResourceAssociations.map( + (r) => r.siteResourceId + ); + + const resourcesToAdd = newSiteResourceIds.filter( + (id) => !existingSiteResourceIds.includes(id) + ); + + const resourcesToRemove = existingSiteResourceIds.filter( + (id) => !newSiteResourceIds.includes(id) + ); + + // Insert new associations + if (resourcesToAdd.length > 0) { + await trx.insert(clientSiteResourcesAssociationsCache).values( + resourcesToAdd.map((siteResourceId) => ({ + clientId: client.clientId, + siteResourceId + })) + ); + } + + // Remove old associations + if (resourcesToRemove.length > 0) { + await trx + .delete(clientSiteResourcesAssociationsCache) + .where( + and( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ), + inArray( + clientSiteResourcesAssociationsCache.siteResourceId, + resourcesToRemove + ) + ) + ); + } + + /////////// Process client-site associations /////////// + + // Get existing site associations + const existingSiteAssociations = await trx + .select({ siteId: clientSitesAssociationsCache.siteId }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + + const existingSiteIds = existingSiteAssociations.map((s) => s.siteId); + + const sitesToAdd = newSiteIds.filter((id) => !existingSiteIds.includes(id)); + const sitesToRemove = existingSiteIds.filter( + (id) => !newSiteIds.includes(id) + ); + + // Insert new site associations + if (sitesToAdd.length > 0) { + await trx.insert(clientSitesAssociationsCache).values( + sitesToAdd.map((siteId) => ({ + clientId: client.clientId, + siteId + })) + ); + } + + // Remove old site associations + if (sitesToRemove.length > 0) { + await trx + .delete(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + inArray(clientSitesAssociationsCache.siteId, sitesToRemove) + ) + ); + } + + /////////// Send messages /////////// + + // Get the olm for this client + const [olm] = await trx + .select({ olmId: olms.olmId }) + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + + if (!olm) { + logger.warn( + `Olm not found for client ${client.clientId}, skipping peer updates` + ); + return; + } + + // Handle messages for sites being added + await handleMessagesForClientSites( + client, + olm.olmId, + sitesToAdd, + sitesToRemove, + trx + ); + + // Handle subnet proxy target updates for resources + await handleMessagesForClientResources( + client, + newSiteResources, + resourcesToAdd, + resourcesToRemove, + trx + ); +} + +async function handleMessagesForClientSites( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + userId: string | null; + orgId: string; + }, + olmId: string, + sitesToAdd: number[], + sitesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + if (!client.subnet || !client.pubKey) { + logger.warn( + `Client ${client.clientId} missing subnet or pubKey, skipping peer updates` + ); + return; + } + + const allSiteIds = [...sitesToAdd, ...sitesToRemove]; + if (allSiteIds.length === 0) { + return; + } + + // Get site details for all affected sites + const sitesData = await trx + .select() + .from(sites) + .leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId)) + .leftJoin(newts, eq(sites.siteId, newts.siteId)) + .where(inArray(sites.siteId, allSiteIds)); + + let newtJobs: Promise[] = []; + let olmJobs: Promise[] = []; + let exitNodeJobs: Promise[] = []; + + for (const siteData of sitesData) { + const site = siteData.sites; + const exitNode = siteData.exitNodes; + const newt = siteData.newt; + + if (!site.publicKey) { + logger.warn( + `Site ${site.siteId} missing publicKey, skipping peer updates` + ); + continue; + } + + if (!newt) { + logger.warn( + `Newt not found for site ${site.siteId}, skipping peer updates` + ); + continue; + } + + const isAdd = sitesToAdd.includes(site.siteId); + const isRemove = sitesToRemove.includes(site.siteId); + + if (isRemove) { + // Remove peer from newt + newtJobs.push( + newtDeletePeer(site.siteId, client.pubKey, newt.newtId) + ); + try { + // Remove peer from olm + olmJobs.push( + olmDeletePeer( + client.clientId, + site.siteId, + site.publicKey, + olmId + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId}, skipping removal` + ); + } else { + throw error; + } + } + } + + if (isAdd) { + if (!exitNode) { + logger.warn( + `Exit node not found for site ${site.siteId}, skipping peer add` + ); + continue; + } + + // Add peer to newt + const isRelayed = true; // Default to relaying for new connections + newtJobs.push( + newtAddPeer( + site.siteId, + { + publicKey: client.pubKey, + allowedIps: [`${client.subnet.split("/")[0]}/32`], + endpoint: isRelayed ? "" : "" + }, + newt.newtId + ) + ); + + // Get all site resources for this site that the client has access to + const accessibleResources = await trx + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); + try { + // Add peer to olm + olmJobs.push( + olmAddPeer( + client.clientId, + { + siteId: site.siteId, + endpoint: + isRelayed || !site.endpoint + ? `${exitNode.endpoint}:21820` + : site.endpoint, + publicKey: site.publicKey, + serverIP: site.address || "", + serverPort: site.listenPort || 0, + remoteSubnets: generateRemoteSubnets( + accessibleResources.map( + ({ siteResources }) => siteResources + ) + ) + }, + olmId + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId}, skipping removal` + ); + } else { + throw error; + } + } + } + + // Update exit node destinations + exitNodeJobs.push( + updateClientSiteDestinations( + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + }, + trx + ) + ); + } + + await Promise.all(exitNodeJobs); + await Promise.all(newtJobs); + await Promise.all(olmJobs); +} + +async function handleMessagesForClientResources( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + userId: string | null; + orgId: string; + }, + allNewResources: SiteResource[], + resourcesToAdd: number[], + resourcesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + // Group resources by site + const resourcesBySite = new Map(); + + for (const resource of allNewResources) { + if (!resourcesBySite.has(resource.siteId)) { + resourcesBySite.set(resource.siteId, []); + } + resourcesBySite.get(resource.siteId)!.push(resource); + } + + let proxyJobs: Promise[] = []; + let olmJobs: Promise[] = []; + + // Handle additions + if (resourcesToAdd.length > 0) { + const addedResources = allNewResources.filter((r) => + resourcesToAdd.includes(r.siteResourceId) + ); + + // Group by site for proxy updates + const addedBySite = new Map(); + for (const resource of addedResources) { + if (!addedBySite.has(resource.siteId)) { + addedBySite.set(resource.siteId, []); + } + addedBySite.get(resource.siteId)!.push(resource); + } + + // Add subnet proxy targets for each site + for (const [siteId, resources] of addedBySite.entries()) { + const [newt] = await trx + .select({ newtId: newts.newtId }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping proxy updates` + ); + continue; + } + + for (const resource of resources) { + const targets = generateSubnetProxyTargets(resource, [ + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + } + ]); + + if (targets.length > 0) { + proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets)); + } + + try { + // Add peer data to olm + olmJobs.push( + addPeerData( + client.clientId, + resource.siteId, + generateRemoteSubnets([resource]), + generateAliasConfig([resource]) + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + ); + } else { + throw error; + } + } + } + } + } + + // Handle removals + if (resourcesToRemove.length > 0) { + const removedResources = await trx + .select() + .from(siteResources) + .where(inArray(siteResources.siteResourceId, resourcesToRemove)); + + // Group by site for proxy updates + const removedBySite = new Map(); + for (const resource of removedResources) { + if (!removedBySite.has(resource.siteId)) { + removedBySite.set(resource.siteId, []); + } + removedBySite.get(resource.siteId)!.push(resource); + } + + // Remove subnet proxy targets for each site + for (const [siteId, resources] of removedBySite.entries()) { + const [newt] = await trx + .select({ newtId: newts.newtId }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping proxy updates` + ); + continue; + } + + for (const resource of resources) { + const targets = generateSubnetProxyTargets(resource, [ + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + } + ]); + + if (targets.length > 0) { + proxyJobs.push( + removeSubnetProxyTargets(newt.newtId, targets) + ); + } + + try { + // Remove peer data from olm + olmJobs.push( + removePeerData( + client.clientId, + resource.siteId, + generateRemoteSubnets([resource]), + generateAliasConfig([resource]) + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + ); + } else { + throw error; + } + } + } + } + } + + await Promise.all([...proxyJobs, ...olmJobs]); +} diff --git a/server/routers/auth/securityKey.ts b/server/routers/auth/securityKey.ts index cde2f61a..eed2328d 100644 --- a/server/routers/auth/securityKey.ts +++ b/server/routers/auth/securityKey.ts @@ -52,7 +52,7 @@ setInterval(async () => { await db .delete(webauthnChallenge) .where(lt(webauthnChallenge.expiresAt, now)); - logger.debug("Cleaned up expired security key challenges"); + // logger.debug("Cleaned up expired security key challenges"); } catch (error) { logger.error("Failed to clean up expired security key challenges", error); } diff --git a/server/routers/client/createClient.ts b/server/routers/client/createClient.ts index 908ea689..160006e1 100644 --- a/server/routers/client/createClient.ts +++ b/server/routers/client/createClient.ts @@ -24,18 +24,19 @@ import { isIpInCidr } from "@server/lib/ip"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { generateId } from "@server/auth/sessions/app"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const createClientParamsSchema = z.strictObject({ - orgId: z.string() - }); + orgId: z.string() +}); const createClientSchema = z.strictObject({ - name: z.string().min(1).max(255), - olmId: z.string(), - secret: z.string(), - subnet: z.string(), - type: z.enum(["olm"]) - }); + name: z.string().min(1).max(255), + olmId: z.string(), + secret: z.string(), + subnet: z.string(), + type: z.enum(["olm"]) +}); export type CreateClientBody = z.infer; @@ -186,6 +187,7 @@ export async function createClient( ); } + let newClient: Client | null = null; await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node const exitNodesList = await listExitNodes(orgId); @@ -204,7 +206,7 @@ export async function createClient( ); } - const [newClient] = await trx + [newClient] = await trx .insert(clients) .values({ exitNodeId: randomExitNode.exitNodeId, @@ -244,13 +246,15 @@ export async function createClient( dateCreated: moment().toISOString() }); - return response(res, { - data: newClient, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED - }); + await rebuildClientAssociationsFromClient(newClient, trx); + }); + + return response(res, { + data: newClient, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/client/createUserClient.ts b/server/routers/client/createUserClient.ts index f49a0783..e5b5ea8f 100644 --- a/server/routers/client/createUserClient.ts +++ b/server/routers/client/createUserClient.ts @@ -21,6 +21,7 @@ import { isValidIP } from "@server/lib/validators"; import { isIpInCidr } from "@server/lib/ip"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const paramsSchema = z .object({ @@ -191,6 +192,7 @@ export async function createUserClient( ); } + let newClient: Client | null = null; await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node const exitNodesList = await listExitNodes(orgId); @@ -209,7 +211,7 @@ export async function createUserClient( ); } - const [newClient] = await trx + [newClient] = await trx .insert(clients) .values({ exitNodeId: randomExitNode.exitNodeId, @@ -232,13 +234,15 @@ export async function createUserClient( clientId: newClient.clientId }); - return response(res, { - data: newClient, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED - }); + await rebuildClientAssociationsFromClient(newClient, trx); + }); + + return response(res, { + data: newClient, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/client/deleteClient.ts b/server/routers/client/deleteClient.ts index 34019a53..775708ce 100644 --- a/server/routers/client/deleteClient.ts +++ b/server/routers/client/deleteClient.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, olms } from "@server/db"; import { clients, clientSitesAssociationsCache } from "@server/db"; import { eq } from "drizzle-orm"; import response from "@server/lib/response"; @@ -9,10 +9,12 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { sendTerminateClient } from "./terminate"; const deleteClientSchema = z.strictObject({ - clientId: z.string().transform(Number).pipe(z.int().positive()) - }); + clientId: z.string().transform(Number).pipe(z.int().positive()) +}); registry.registerPath({ method: "delete", @@ -68,19 +70,27 @@ export async function deleteClient( } await db.transaction(async (trx) => { - // Delete the client-site associations first - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, clientId)); - // Then delete the client itself - await trx.delete(clients).where(eq(clients.clientId, clientId)); + const [deletedClient] = await trx + .delete(clients) + .where(eq(clients.clientId, clientId)) + .returning(); - // this is a machine client + const [olm] = await trx + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + + // this is a machine client so we also delete the olm if (!client.userId && client.olmId) { - await trx - .delete(clients) - .where(eq(clients.olmId, client.olmId)); + await trx.delete(olms).where(eq(olms.olmId, client.olmId)); + } + + await rebuildClientAssociationsFromClient(deletedClient, trx); + + if (olm) { + await sendTerminateClient(deletedClient.clientId, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion } }); diff --git a/server/routers/client/terminate.ts b/server/routers/client/terminate.ts new file mode 100644 index 00000000..dc49ef05 --- /dev/null +++ b/server/routers/client/terminate.ts @@ -0,0 +1,22 @@ +import { sendToClient } from "#dynamic/routers/ws"; +import { db, olms } from "@server/db"; +import { eq } from "drizzle-orm"; + +export async function sendTerminateClient(clientId: number, olmId?: string | null) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; + } + + await sendToClient(olmId, { + type: `olm/terminate`, + data: {} + }); +} diff --git a/server/routers/olm/deleteUserOlm.ts b/server/routers/olm/deleteUserOlm.ts index 88e791db..83a3d16f 100644 --- a/server/routers/olm/deleteUserOlm.ts +++ b/server/routers/olm/deleteUserOlm.ts @@ -1,5 +1,5 @@ import { NextFunction, Request, Response } from "express"; -import { db } from "@server/db"; +import { Client, db } from "@server/db"; import { olms, clients, clientSitesAssociationsCache } from "@server/db"; import { eq } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; @@ -9,6 +9,8 @@ import { z } from "zod"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { sendTerminateClient } from "../client/terminate"; const paramsSchema = z .object({ @@ -54,20 +56,30 @@ export async function deleteUserOlm( .from(clients) .where(eq(clients.olmId, olmId)); - // Delete client-site associations for each associated client - for (const client of associatedClients) { - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - } - + let deletedClient: Client | null = null; // Delete all associated clients if (associatedClients.length > 0) { - await trx.delete(clients).where(eq(clients.olmId, olmId)); + [deletedClient] = await trx + .delete(clients) + .where(eq(clients.olmId, olmId)) + .returning(); } // Finally, delete the OLM itself - await trx.delete(olms).where(eq(olms.olmId, olmId)); + const [olm] = await trx + .delete(olms) + .where(eq(olms.olmId, olmId)) + .returning(); + + if (deletedClient) { + await rebuildClientAssociationsFromClient(deletedClient, trx); + if (olm) { + await sendTerminateClient( + deletedClient.clientId, + olm.olmId + ); // the olmId needs to be provided because it cant look it up after deletion + } + } }); return response(res, { diff --git a/server/routers/olm/getUserOlm.ts b/server/routers/olm/getUserOlm.ts index 50b32fd8..aa9b89af 100644 --- a/server/routers/olm/getUserOlm.ts +++ b/server/routers/olm/getUserOlm.ts @@ -1,6 +1,6 @@ import { NextFunction, Request, Response } from "express"; import { db } from "@server/db"; -import { olms, clients, clientSites } from "@server/db"; +import { olms } from "@server/db"; import { eq, and } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 048e6baa..2ee5c120 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -31,6 +31,7 @@ import { getNextAvailableClientSubnet } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); @@ -60,6 +61,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { olm.name || "User Device", // doNotCreateNewClient ? true : false true // for now never create a new client automatically because we create the users clients when they are added to the org + // this means that the rebuildClientAssociationsFromClient call below issue is not a problem ); client = clientRes; @@ -99,6 +101,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .from(clients) .where(eq(clients.clientId, olm.clientId)) .limit(1); + + [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, client.orgId)) + .limit(1); } if (!client) { @@ -205,13 +213,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { `Found ${sitesData.length} sites for client ${client.clientId}` ); - if (sitesData.length === 0) { - sendToClient(olm.olmId, { - type: "olm/register/no-sites", - data: {} - }); - } - // Process each site for (const { sites: site } of sitesData) { if (!site.exitNodeId) { @@ -462,6 +463,8 @@ async function getOrCreateOrgClient( }); } + await rebuildClientAssociationsFromClient(newClient, trx); // TODO: this will try to messages to the olm which has not connected yet - is that a problem? + client = newClient; } diff --git a/server/routers/siteResource/addClientToSiteResource.ts b/server/routers/siteResource/addClientToSiteResource.ts index 8fb6afdc..587294e5 100644 --- a/server/routers/siteResource/addClientToSiteResource.ts +++ b/server/routers/siteResource/addClientToSiteResource.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addClientToSiteResourceBodySchema = z .object({ @@ -136,7 +136,7 @@ export async function addClientToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/addRoleToSiteResource.ts b/server/routers/siteResource/addRoleToSiteResource.ts index 859ca5be..542ca535 100644 --- a/server/routers/siteResource/addRoleToSiteResource.ts +++ b/server/routers/siteResource/addRoleToSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addRoleToSiteResourceBodySchema = z .object({ @@ -146,7 +146,7 @@ export async function addRoleToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/addUserToSiteResource.ts b/server/routers/siteResource/addUserToSiteResource.ts index 411d37b4..c9d1f30a 100644 --- a/server/routers/siteResource/addUserToSiteResource.ts +++ b/server/routers/siteResource/addUserToSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addUserToSiteResourceBodySchema = z .object({ @@ -115,7 +115,7 @@ export async function addUserToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index ecbb7768..1d9cd6aa 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -17,7 +17,7 @@ import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import { getUniqueSiteResourceName } from "@server/db/names"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; import { getNextAvailableAliasAddress } from "@server/lib/ip"; const createSiteResourceParamsSchema = z.strictObject({ @@ -278,7 +278,7 @@ export async function createSiteResource( ); } - await rebuildClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role + await rebuildClientAssociationsFromSiteResource(newSiteResource, trx); // we need to call this because we added to the admin role }); if (!newSiteResource) { diff --git a/server/routers/siteResource/deleteSiteResource.ts b/server/routers/siteResource/deleteSiteResource.ts index 75f2c3f2..a7175608 100644 --- a/server/routers/siteResource/deleteSiteResource.ts +++ b/server/routers/siteResource/deleteSiteResource.ts @@ -9,7 +9,7 @@ import { eq, and } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const deleteSiteResourceParamsSchema = z.strictObject({ siteResourceId: z.string().transform(Number).pipe(z.int().positive()), @@ -106,7 +106,7 @@ export async function deleteSiteResource( ); } - await rebuildClientAssociations(removedSiteResource, trx); + await rebuildClientAssociationsFromSiteResource(removedSiteResource, trx); }); logger.info( diff --git a/server/routers/siteResource/removeClientFromSiteResource.ts b/server/routers/siteResource/removeClientFromSiteResource.ts index d46e5d67..c6a5dfe8 100644 --- a/server/routers/siteResource/removeClientFromSiteResource.ts +++ b/server/routers/siteResource/removeClientFromSiteResource.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeClientFromSiteResourceBodySchema = z .object({ @@ -142,7 +142,7 @@ export async function removeClientFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/removeRoleFromSiteResource.ts b/server/routers/siteResource/removeRoleFromSiteResource.ts index c4c68e06..0041ed83 100644 --- a/server/routers/siteResource/removeRoleFromSiteResource.ts +++ b/server/routers/siteResource/removeRoleFromSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeRoleFromSiteResourceBodySchema = z .object({ @@ -151,7 +151,7 @@ export async function removeRoleFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/removeUserFromSiteResource.ts b/server/routers/siteResource/removeUserFromSiteResource.ts index 8a90b752..280a01f2 100644 --- a/server/routers/siteResource/removeUserFromSiteResource.ts +++ b/server/routers/siteResource/removeUserFromSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeUserFromSiteResourceBodySchema = z .object({ @@ -121,7 +121,7 @@ export async function removeUserFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceClients.ts b/server/routers/siteResource/setSiteResourceClients.ts index 974b27cc..0a25b7e9 100644 --- a/server/routers/siteResource/setSiteResourceClients.ts +++ b/server/routers/siteResource/setSiteResourceClients.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, inArray } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceClientsBodySchema = z .object({ @@ -124,7 +124,7 @@ export async function setSiteResourceClients( .values(clientIds.map((clientId) => ({ clientId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceRoles.ts b/server/routers/siteResource/setSiteResourceRoles.ts index df44e02b..7aa07de1 100644 --- a/server/routers/siteResource/setSiteResourceRoles.ts +++ b/server/routers/siteResource/setSiteResourceRoles.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and, ne, inArray } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceRolesBodySchema = z .object({ @@ -147,7 +147,7 @@ export async function setSiteResourceRoles( .values(roleIds.map((roleId) => ({ roleId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceUsers.ts b/server/routers/siteResource/setSiteResourceUsers.ts index 8ef9a0ab..4dae0ada 100644 --- a/server/routers/siteResource/setSiteResourceUsers.ts +++ b/server/routers/siteResource/setSiteResourceUsers.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceUsersBodySchema = z .object({ @@ -102,7 +102,7 @@ export async function setSiteResourceUsers( .values(userIds.map((userId) => ({ userId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index d66d2cb8..51a18af9 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -25,7 +25,7 @@ import { } from "@server/lib/ip"; import { getClientSiteResourceAccess, - rebuildClientAssociations + rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const updateSiteResourceParamsSchema = z.strictObject({ @@ -224,7 +224,7 @@ export async function updateSiteResource( ); } - const { mergedAllClients } = await rebuildClientAssociations( + const { mergedAllClients } = await rebuildClientAssociationsFromSiteResource( existingSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below trx ); diff --git a/server/routers/user/addUserRole.ts b/server/routers/user/addUserRole.ts index 915ea64a..9404d94f 100644 --- a/server/routers/user/addUserRole.ts +++ b/server/routers/user/addUserRole.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { clients, db, UserOrg } from "@server/db"; import { userOrgs, roles } from "@server/db"; import { eq, and } from "drizzle-orm"; import response from "@server/lib/response"; @@ -10,11 +10,12 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import stoi from "@server/lib/stoi"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const addUserRoleParamsSchema = z.strictObject({ - userId: z.string(), - roleId: z.string().transform(stoi).pipe(z.number()) - }); + userId: z.string(), + roleId: z.string().transform(stoi).pipe(z.number()) +}); export type AddUserRoleResponse = z.infer; @@ -72,7 +73,9 @@ export async function addUserRole( const existingUser = await db .select() .from(userOrgs) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId))) + .where( + and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId)) + ) .limit(1); if (existingUser.length === 0) { @@ -108,14 +111,39 @@ export async function addUserRole( ); } - const newUserRole = await db - .update(userOrgs) - .set({ roleId }) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId))) - .returning(); + let newUserRole: UserOrg | null = null; + await db.transaction(async (trx) => { + [newUserRole] = await trx + .update(userOrgs) + .set({ roleId }) + .where( + and( + eq(userOrgs.userId, userId), + eq(userOrgs.orgId, role.orgId) + ) + ) + .returning(); + + // get the client associated with this user in this org + const [orgClient] = await trx + .select() + .from(clients) + .where( + and( + eq(clients.userId, userId), + eq(clients.orgId, role.orgId) + ) + ) + .limit(1); + + if (orgClient) { + // we just changed the user's role, so we need to rebuild client associations and what they have access to + await rebuildClientAssociationsFromClient(orgClient, trx); + } + }); return response(res, { - data: newUserRole[0], + data: newUserRole, success: true, error: false, message: "Role added to user successfully", diff --git a/server/routers/user/adminRemoveUser.ts b/server/routers/user/adminRemoveUser.ts index 02ad56d6..ae7f9f47 100644 --- a/server/routers/user/adminRemoveUser.ts +++ b/server/routers/user/adminRemoveUser.ts @@ -8,10 +8,11 @@ import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const removeUserSchema = z.strictObject({ - userId: z.string() - }); + userId: z.string() +}); export async function adminRemoveUser( req: Request, @@ -50,7 +51,11 @@ export async function adminRemoveUser( ); } - await db.delete(users).where(eq(users.userId, userId)); + await db.transaction(async (trx) => { + await trx.delete(users).where(eq(users.userId, userId)); + + await calculateUserClientsForOrgs(userId, trx); + }); return response(res, { data: null, From de83cf9d8ca11fb92a5d2f563616d53281c07603 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 26 Nov 2025 15:35:33 -0500 Subject: [PATCH 03/14] Handle delete org and checking org policy --- server/routers/olm/handleOlmPingMessage.ts | 35 ++++++ .../routers/olm/handleOlmRegisterMessage.ts | 33 ++++- server/routers/org/deleteOrg.ts | 117 +++++++++++++----- server/routers/user/addUserRole.ts | 4 +- 4 files changed, 155 insertions(+), 34 deletions(-) diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index ab503d4c..ee9443f5 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -5,6 +5,8 @@ import { clients, Olm } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; import { validateSessionToken } from "@server/auth/sessions/app"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { sendTerminateClient } from "../client/terminate"; // Track if the offline checker interval is running let offlineCheckerInterval: NodeJS.Timeout | null = null; @@ -57,6 +59,9 @@ export const startOlmOfflineChecker = (): void => { // Send a disconnect message to the client if connected try { + await sendTerminateClient(offlineClient.clientId); // terminate first + // wait a moment to ensure the message is sent + await new Promise(resolve => setTimeout(resolve, 1000)); await disconnectClient(offlineClient.olmId); } catch (error) { logger.error( @@ -110,6 +115,36 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { logger.warn("User ID mismatch for olm ping"); return; } + + // get the client + const [client] = await db + .select() + .from(clients) + .where( + and( + eq(clients.olmId, olm.olmId), + eq(clients.userId, olm.userId) + ) + ) + .limit(1); + + if (!client) { + logger.warn("Client not found for olm ping"); + return; + } + + const policyCheck = await checkOrgAccessPolicy({ + orgId: client.orgId, + userId: olm.userId, + session: userToken // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}` + ); + return; + } } if (!olm.clientId) { diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 2ee5c120..53cd9815 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -32,6 +32,8 @@ import { } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { validateSessionToken } from "@server/auth/sessions/app"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); @@ -45,7 +47,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient } = + const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } = message.data; let client: Client | undefined; @@ -78,6 +80,35 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } + if (!olm.userId) { + logger.warn("Olm has no user ID"); + return; + } + + const { session: userSession, user } = + await validateSessionToken(userToken); + if (!userSession || !user) { + logger.warn("Invalid user session for olm ping"); + return; // by returning here we just ignore the ping and the setInterval will force it to disconnect + } + if (user.userId !== olm.userId) { + logger.warn("User ID mismatch for olm ping"); + return; + } + + const policyCheck = await checkOrgAccessPolicy({ + orgId: orgId, + userId: olm.userId, + session: userToken // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}` + ); + return; + } + logger.debug( `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` ); diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index 0e21a8c0..098c5c41 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -1,6 +1,15 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, domains, orgDomains, resources } from "@server/db"; +import { + clients, + clientSiteResourcesAssociationsCache, + clientSitesAssociationsCache, + db, + domains, + olms, + orgDomains, + resources +} from "@server/db"; import { newts, newtSessions, orgs, sites, userActions } from "@server/db"; import { eq, and, inArray, sql } from "drizzle-orm"; import response from "@server/lib/response"; @@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers"; import { OpenAPITags, registry } from "@server/openApi"; const deleteOrgSchema = z.strictObject({ - orgId: z.string() - }); + orgId: z.string() +}); export type DeleteOrgResponse = {}; @@ -69,41 +78,75 @@ export async function deleteOrg( .where(eq(sites.orgId, orgId)) .limit(1); + const orgClients = await db + .select() + .from(clients) + .where(eq(clients.orgId, orgId)); + const deletedNewtIds: string[] = []; + const olmsToTerminate: string[] = []; await db.transaction(async (trx) => { - if (sites) { - for (const site of orgSites) { - if (site.pubKey) { - if (site.type == "wireguard") { - await deletePeer(site.exitNodeId!, site.pubKey); - } else if (site.type == "newt") { - // get the newt on the site by querying the newt table for siteId - const [deletedNewt] = await trx - .delete(newts) - .where(eq(newts.siteId, site.siteId)) - .returning(); - if (deletedNewt) { - deletedNewtIds.push(deletedNewt.newtId); + for (const site of orgSites) { + if (site.pubKey) { + if (site.type == "wireguard") { + await deletePeer(site.exitNodeId!, site.pubKey); + } else if (site.type == "newt") { + // get the newt on the site by querying the newt table for siteId + const [deletedNewt] = await trx + .delete(newts) + .where(eq(newts.siteId, site.siteId)) + .returning(); + if (deletedNewt) { + deletedNewtIds.push(deletedNewt.newtId); - // delete all of the sessions for the newt - await trx - .delete(newtSessions) - .where( - eq( - newtSessions.newtId, - deletedNewt.newtId - ) - ); - } + // delete all of the sessions for the newt + await trx + .delete(newtSessions) + .where( + eq(newtSessions.newtId, deletedNewt.newtId) + ); } } - - logger.info(`Deleting site ${site.siteId}`); - await trx - .delete(sites) - .where(eq(sites.siteId, site.siteId)); } + + logger.info(`Deleting site ${site.siteId}`); + await trx.delete(sites).where(eq(sites.siteId, site.siteId)); + } + for (const client of orgClients) { + const [olm] = await trx + .select() + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + + if (olm) { + olmsToTerminate.push(olm.olmId); + } + + logger.info(`Deleting client ${client.clientId}`); + await trx + .delete(clients) + .where(eq(clients.clientId, client.clientId)); + + // also delete the associations + await trx + .delete(clientSiteResourcesAssociationsCache) + .where( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ); + + await trx + .delete(clientSitesAssociationsCache) + .where( + eq( + clientSitesAssociationsCache.clientId, + client.clientId + ) + ); } const allOrgDomains = await trx @@ -162,6 +205,18 @@ export async function deleteOrg( }); } + for (const olmId of olmsToTerminate) { + sendToClient(olmId, { + type: "olm/terminate", + data: {} + }).catch((error) => { + logger.error( + "Failed to send termination message to olm:", + error + ); + }); + } + return response(res, { data: null, success: true, diff --git a/server/routers/user/addUserRole.ts b/server/routers/user/addUserRole.ts index 9404d94f..32eaa19d 100644 --- a/server/routers/user/addUserRole.ts +++ b/server/routers/user/addUserRole.ts @@ -125,7 +125,7 @@ export async function addUserRole( .returning(); // get the client associated with this user in this org - const [orgClient] = await trx + const orgClients = await trx .select() .from(clients) .where( @@ -136,7 +136,7 @@ export async function addUserRole( ) .limit(1); - if (orgClient) { + for (const orgClient of orgClients) { // we just changed the user's role, so we need to rebuild client associations and what they have access to await rebuildClientAssociationsFromClient(orgClient, trx); } From 5bd31f87f0be0e3a29d9f538a947a724c74d8e2b Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Wed, 26 Nov 2025 15:48:38 -0500 Subject: [PATCH 04/14] only allow one device auth per session --- messages/en-US.json | 4 +-- server/auth/sessions/verifySession.ts | 8 ++++- server/db/pg/schema/schema.ts | 3 +- server/db/sqlite/schema/schema.ts | 3 +- server/routers/auth/verifyDeviceWebAuth.ts | 40 +++++++++++++++++----- 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/messages/en-US.json b/messages/en-US.json index f3ae34a7..76908b29 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -2140,7 +2140,7 @@ "deviceOrganizationsAccess": "Access to all organizations your account has access to", "deviceAuthorize": "Authorize {applicationName}", "deviceConnected": "Device Connected!", - "deviceAuthorizedMessage": "Your device is authorized to access your account.", + "deviceAuthorizedMessage": "Device is authorized to access your account.", "pangolinCloud": "Pangolin Cloud", "viewDevices": "View Devices", "viewDevicesDescription": "Manage your connected devices", @@ -2202,5 +2202,5 @@ "enterIdentifier": "Enter identifier", "identifier": "Identifier", "deviceLoginUseDifferentAccount": "Not you? Use a different account.", - "deviceLoginDeviceRequestingAccessToAccount": "Your device is requesting access to this account." + "deviceLoginDeviceRequestingAccessToAccount": "A device is requesting access to this account." } diff --git a/server/auth/sessions/verifySession.ts b/server/auth/sessions/verifySession.ts index 68a1f17e..01b32ef6 100644 --- a/server/auth/sessions/verifySession.ts +++ b/server/auth/sessions/verifySession.ts @@ -18,13 +18,19 @@ export async function verifySession(req: Request, forceLogin?: boolean) { user: null }; } + if (res.session.deviceAuthUsed) { + return { + session: null, + user: null + }; + } if (!res.session.issuedAt) { return { session: null, user: null }; } - const mins = 3 * 60 * 1000; + const mins = 5 * 60 * 1000; const now = new Date().getTime(); if (now - res.session.issuedAt > mins) { return { diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index d15676a0..120a7aa3 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -287,7 +287,8 @@ export const sessions = pgTable("session", { .notNull() .references(() => users.userId, { onDelete: "cascade" }), expiresAt: bigint("expiresAt", { mode: "number" }).notNull(), - issuedAt: bigint("issuedAt", { mode: "number" }) + issuedAt: bigint("issuedAt", { mode: "number" }), + deviceAuthUsed: boolean("deviceAuthUsed") }); export const newtSessions = pgTable("newtSession", { diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 634afd36..a5f3d0f6 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -415,7 +415,8 @@ export const sessions = sqliteTable("session", { .notNull() .references(() => users.userId, { onDelete: "cascade" }), expiresAt: integer("expiresAt").notNull(), - issuedAt: integer("issuedAt") + issuedAt: integer("issuedAt"), + deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" }) }); export const newtSessions = sqliteTable("newtSession", { diff --git a/server/routers/auth/verifyDeviceWebAuth.ts b/server/routers/auth/verifyDeviceWebAuth.ts index 715b299a..be0e0ff2 100644 --- a/server/routers/auth/verifyDeviceWebAuth.ts +++ b/server/routers/auth/verifyDeviceWebAuth.ts @@ -5,7 +5,7 @@ import { fromError } from "zod-validation-error"; import HttpCode from "@server/types/HttpCode"; import logger from "@server/logger"; import { response } from "@server/lib/response"; -import { db, deviceWebAuthCodes } from "@server/db"; +import { db, deviceWebAuthCodes, sessions } from "@server/db"; import { eq, and, gt } from "drizzle-orm"; import { encodeHexLowerCase } from "@oslojs/encoding"; import { sha256 } from "@oslojs/crypto/sha2"; @@ -44,20 +44,36 @@ export async function verifyDeviceWebAuth( ): Promise { const { user, session } = req; if (!user || !session) { - logger.debug("Unauthorized attempt to verify device web auth code"); - return next(unauthorized()); + return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")); + } + + if (session.deviceAuthUsed) { + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "Device web auth code already used for this session" + ) + ); } if (!session.issuedAt) { - logger.debug("Session missing issuedAt timestamp"); - return next(unauthorized()); + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "Session issuedAt timestamp missing" + ) + ); } // make sure sessions is not older than 5 minutes const now = Date.now(); - if (now - session.issuedAt > 3 * 60 * 1000) { - logger.debug("Session is too old to verify device web auth code"); - return next(unauthorized()); + if (now - session.issuedAt > 5 * 60 * 1000) { + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "Session is too old to verify device web auth code" + ) + ); } const parsedBody = bodySchema.safeParse(req.body); @@ -134,6 +150,14 @@ export async function verifyDeviceWebAuth( }) .where(eq(deviceWebAuthCodes.codeId, deviceCode.codeId)); + // Also update the session to mark that device auth was used + await db + .update(sessions) + .set({ + deviceAuthUsed: true + }) + .where(eq(sessions.sessionId, session.sessionId)); + return response(res, { data: { success: true, From 79f0d60533fc38d375323604c95b0c1e23b3bf15 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 29 Nov 2025 22:57:11 -0500 Subject: [PATCH 05/14] Start working on HP IP changes --- server/routers/gerbil/updateHolePunch.ts | 60 ++++++++++++++++++++---- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 031cd23e..e0564fd6 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -150,7 +150,7 @@ export async function updateAndGenerateEndpointDestinations( throw new Error("Olm not found"); } - const [client] = await db + const [updatedClient] = await db .update(clients) .set({ lastHolePunch: timestamp @@ -158,10 +158,16 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(clients.clientId, olm.clientId)) .returning(); - if (await checkExitNodeOrg(exitNode.exitNodeId, client.orgId) && checkOrg) { + if ( + (await checkExitNodeOrg( + exitNode.exitNodeId, + updatedClient.orgId + )) && + checkOrg + ) { // not allowed logger.warn( - `Exit node ${exitNode.exitNodeId} is not allowed for org ${client.orgId}` + `Exit node ${exitNode.exitNodeId} is not allowed for org ${updatedClient.orgId}` ); throw new Error("Exit node not allowed"); } @@ -171,10 +177,14 @@ export async function updateAndGenerateEndpointDestinations( .select({ siteId: sites.siteId, subnet: sites.subnet, - listenPort: sites.listenPort + listenPort: sites.listenPort, + endpoint: clientSitesAssociationsCache.endpoint }) .from(sites) - .innerJoin(clientSitesAssociationsCache, eq(sites.siteId, clientSitesAssociationsCache.siteId)) + .innerJoin( + clientSitesAssociationsCache, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) .where( and( eq(sites.exitNodeId, exitNode.exitNodeId), @@ -188,7 +198,7 @@ export async function updateAndGenerateEndpointDestinations( `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` ); - await db + const [updatedClientSitesAssociationsCache] = await db .update(clientSitesAssociationsCache) .set({ endpoint: `${ip}:${port}` @@ -198,13 +208,27 @@ export async function updateAndGenerateEndpointDestinations( eq(clientSitesAssociationsCache.clientId, olm.clientId), eq(clientSitesAssociationsCache.siteId, site.siteId) ) + ) + .returning(); + + if ( + updatedClientSitesAssociationsCache.endpoint !== site.endpoint // this is the endpoint from the join table not the site + ) { + logger.info( + `ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}` ); + // Handle any additional logic for endpoint change + handleClientEndpointChange( + olm.clientId, + updatedClientSitesAssociationsCache.endpoint! + ); + } } logger.debug( `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` ); - if (!client) { + if (!updatedClient) { logger.warn(`Client not found for olm: ${olmId}`); throw new Error("Client not found"); } @@ -253,7 +277,10 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .limit(1); - if (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId) && checkOrg) { + if ( + (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId)) && + checkOrg + ) { // not allowed logger.warn( `Exit node ${exitNode.exitNodeId} is not allowed for org ${site.orgId}` @@ -273,6 +300,14 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .returning(); + if (updatedSite.endpoint != site.endpoint) { + logger.info( + `Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}` + ); + // Handle any additional logic for endpoint change + handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!); + } + if (!updatedSite || !updatedSite.subnet) { logger.warn(`Site not found: ${newt.siteId}`); throw new Error("Site not found"); @@ -326,3 +361,12 @@ export async function updateAndGenerateEndpointDestinations( } return destinations; } + +function handleSiteEndpointChange(siteId: number, newEndpoint: string) { + // just alert all of the clients connected to this site that the endpoint has changed but only if they are NOT relayed +} + +function handleClientEndpointChange(clientId: number, newEndpoint: string) { + // just alert all of the sites connected to this client that the endpoint has changed but only if they are NOT relayed + +} From dd6b1d88d30271003b28c3c2218ec0c613a24871 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 30 Nov 2025 11:39:16 -0500 Subject: [PATCH 06/14] Update peer data when HP changes --- server/routers/gerbil/updateHolePunch.ts | 141 ++++++++++++++++++++++- server/routers/olm/peers.ts | 4 +- 2 files changed, 139 insertions(+), 6 deletions(-) diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index e0564fd6..f049a5c1 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -19,6 +19,8 @@ import { fromError } from "zod-validation-error"; import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { validateOlmSessionToken } from "@server/auth/sessions/olm"; import { checkExitNodeOrg } from "#dynamic/lib/exitNodes"; +import { updatePeer as updateOlmPeer } from "../olm/peers"; +import { updatePeer as updateNewtPeer } from "../newt/peers"; // Define Zod schema for request validation const updateHolePunchSchema = z.object({ @@ -362,11 +364,142 @@ export async function updateAndGenerateEndpointDestinations( return destinations; } -function handleSiteEndpointChange(siteId: number, newEndpoint: string) { - // just alert all of the clients connected to this site that the endpoint has changed but only if they are NOT relayed +async function handleSiteEndpointChange(siteId: number, newEndpoint: string) { + // Alert all clients connected to this site that the endpoint has changed (only if NOT relayed) + try { + // Get site details + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site || !site.publicKey) { + logger.warn(`Site ${siteId} not found or has no public key`); + return; + } + + // Get all non-relayed clients connected to this site + const connectedClients = await db + .select({ + clientId: clients.clientId, + olmId: olms.olmId, + isRelayed: clientSitesAssociationsCache.isRelayed + }) + .from(clientSitesAssociationsCache) + .innerJoin( + clients, + eq(clientSitesAssociationsCache.clientId, clients.clientId) + ) + .innerJoin(olms, eq(olms.clientId, clients.clientId)) + .where( + and( + eq(clientSitesAssociationsCache.siteId, siteId), + eq(clientSitesAssociationsCache.isRelayed, false) + ) + ); + + // Update each non-relayed client with the new site endpoint + for (const client of connectedClients) { + try { + await updateOlmPeer( + client.clientId, + { + siteId: siteId, + publicKey: site.publicKey, + endpoint: newEndpoint, + }, + client.olmId + ); + logger.debug( + `Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}` + ); + } catch (error) { + logger.error( + `Failed to update client ${client.clientId} with new site endpoint: ${error}` + ); + } + } + } catch (error) { + logger.error( + `Error handling site endpoint change for site ${siteId}: ${error}` + ); + } } -function handleClientEndpointChange(clientId: number, newEndpoint: string) { - // just alert all of the sites connected to this client that the endpoint has changed but only if they are NOT relayed +async function handleClientEndpointChange( + clientId: number, + newEndpoint: string +) { + // Alert all sites connected to this client that the endpoint has changed (only if NOT relayed) + try { + // Get client details + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientId)) + .limit(1); + if (!client || !client.pubKey) { + logger.warn(`Client ${clientId} not found or has no public key`); + return; + } + + // Get all non-relayed sites connected to this client + const connectedSites = await db + .select({ + siteId: sites.siteId, + newtId: newts.newtId, + isRelayed: clientSitesAssociationsCache.isRelayed, + subnet: clients.subnet + }) + .from(clientSitesAssociationsCache) + .innerJoin( + sites, + eq(clientSitesAssociationsCache.siteId, sites.siteId) + ) + .innerJoin(newts, eq(newts.siteId, sites.siteId)) + .innerJoin( + clients, + eq(clientSitesAssociationsCache.clientId, clients.clientId) + ) + .where( + and( + eq(clientSitesAssociationsCache.clientId, clientId), + eq(clientSitesAssociationsCache.isRelayed, false) + ) + ); + + // Update each non-relayed site with the new client endpoint + for (const siteData of connectedSites) { + try { + if (!siteData.subnet) { + logger.warn( + `Client ${clientId} has no subnet, skipping update for site ${siteData.siteId}` + ); + continue; + } + + await updateNewtPeer( + siteData.siteId, + client.pubKey, + { + endpoint: newEndpoint + }, + siteData.newtId + ); + logger.debug( + `Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}` + ); + } catch (error) { + logger.error( + `Failed to update site ${siteData.siteId} with new client endpoint: ${error}` + ); + } + } + } catch (error) { + logger.error( + `Error handling client endpoint change for client ${clientId}: ${error}` + ); + } } diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 1daed53a..69ea2bc9 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -78,8 +78,8 @@ export async function updatePeer( siteId: number; publicKey: string; endpoint: string; - serverIP: string | null; - serverPort: number | null; + serverIP?: string | null; + serverPort?: number | null; remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that }, olmId?: string From 096da391e50a030cd2f293216105a6bb8d220935 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 30 Nov 2025 17:38:12 -0500 Subject: [PATCH 07/14] Add a utility subnet --- server/lib/createUserAccountOrg.ts | 4 ++++ server/lib/readConfigFile.ts | 6 ++++-- server/routers/org/createOrg.ts | 12 ++++++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/server/lib/createUserAccountOrg.ts b/server/lib/createUserAccountOrg.ts index 1406b935..11f4e247 100644 --- a/server/lib/createUserAccountOrg.ts +++ b/server/lib/createUserAccountOrg.ts @@ -18,6 +18,7 @@ import { defaultRoleAllowedActions } from "@server/routers/role"; import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing"; import { createCustomer } from "#dynamic/lib/billing"; import { usageService } from "@server/lib/billing/usageService"; +import config from "@server/lib/config"; export async function createUserAccountOrg( userId: string, @@ -76,6 +77,8 @@ export async function createUserAccountOrg( .from(domains) .where(eq(domains.configManaged, true)); + const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group; + const newOrg = await trx .insert(orgs) .values({ @@ -83,6 +86,7 @@ export async function createUserAccountOrg( name, // subnet subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs? + utilitySubnet: utilitySubnet, createdAt: new Date().toISOString() }) .returning(); diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index 2da8c0a7..dc5ec729 100644 --- a/server/lib/readConfigFile.ts +++ b/server/lib/readConfigFile.ts @@ -249,12 +249,14 @@ export const configSchema = z orgs: z .object({ block_size: z.number().positive().gt(0).optional().default(24), - subnet_group: z.string().optional().default("100.90.128.0/24") + subnet_group: z.string().optional().default("100.90.128.0/24"), + utility_subnet_group: z.string().optional().default("100.96.128.0/24") //just hardcode this for now as well }) .optional() .default({ block_size: 24, - subnet_group: "100.90.128.0/24" + subnet_group: "100.90.128.0/24", + utility_subnet_group: "100.96.128.0/24" }), rate_limits: z .object({ diff --git a/server/routers/org/createOrg.ts b/server/routers/org/createOrg.ts index e44bf021..8276da9a 100644 --- a/server/routers/org/createOrg.ts +++ b/server/routers/org/createOrg.ts @@ -28,10 +28,10 @@ import { FeatureId } from "@server/lib/billing"; import { build } from "@server/build"; const createOrgSchema = z.strictObject({ - orgId: z.string(), - name: z.string().min(1).max(255), - subnet: z.string() - }); + orgId: z.string(), + name: z.string().min(1).max(255), + subnet: z.string() +}); registry.registerPath({ method: "put", @@ -131,12 +131,16 @@ export async function createOrg( .from(domains) .where(eq(domains.configManaged, true)); + const utilitySubnet = + config.getRawConfig().orgs.utility_subnet_group; + const newOrg = await trx .insert(orgs) .values({ orgId, name, subnet, + utilitySubnet, createdAt: new Date().toISOString() }) .returning(); From 92125611e9a5bc8845f424950b7c4b620fbd3897 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 30 Nov 2025 17:49:55 -0500 Subject: [PATCH 08/14] Add validation and fix thrown error from updatePeerData --- .../siteResource/updateSiteResource.ts | 100 +++++++++++++----- 1 file changed, 73 insertions(+), 27 deletions(-) diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index 51a18af9..91b14bf3 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -49,7 +49,44 @@ const updateSiteResourceSchema = z roleIds: z.array(z.int()), clientIds: z.array(z.int()) }) - .strict(); + .strict() + .refine( + (data) => { + if (data.mode === "host" && data.destination) { + // Check if it's a valid IP address using zod (v4 or v6) + const isValidIP = z + .union([z.ipv4(), z.ipv6()]) + .safeParse(data.destination).success; + + // Check if it's a valid domain (hostname pattern, TLD not required) + const domainRegex = + /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/; + const isValidDomain = domainRegex.test(data.destination); + + return isValidIP || isValidDomain; + } + return true; + }, + { + message: + "Destination must be a valid IP address or domain name for host mode" + } + ) + .refine( + (data) => { + if (data.mode === "cidr" && data.destination) { + // Check if it's a valid CIDR (v4 or v6) + const isValidCIDR = z + .union([z.cidrv4(), z.cidrv6()]) + .safeParse(data.destination).success; + return isValidCIDR; + } + return true; + }, + { + message: "Destination must be a valid CIDR notation for cidr mode" + } + ); export type UpdateSiteResourceBody = z.infer; export type UpdateSiteResourceResponse = SiteResource; @@ -224,10 +261,11 @@ export async function updateSiteResource( ); } - const { mergedAllClients } = await rebuildClientAssociationsFromSiteResource( - existingSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below - trx - ); + const { mergedAllClients } = + await rebuildClientAssociationsFromSiteResource( + existingSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below + trx + ); // after everything is rebuilt above we still need to update the targets and remote subnets if the destination changed if ( @@ -263,28 +301,36 @@ export async function updateSiteResource( let olmJobs: Promise[] = []; for (const client of mergedAllClients) { // we also need to update the remote subnets on the olms for each client that has access to this site - olmJobs.push( - updatePeerData( - client.clientId, - updatedSiteResource.siteId, - { - oldRemoteSubnets: generateRemoteSubnets([ - existingSiteResource - ]), - newRemoteSubnets: generateRemoteSubnets([ - updatedSiteResource - ]) - }, - { - oldAliases: generateAliasConfig([ - existingSiteResource - ]), - newAliases: generateAliasConfig([ - updatedSiteResource - ]) - } - ) - ); + try { + olmJobs.push( + updatePeerData( + client.clientId, + updatedSiteResource.siteId, + { + oldRemoteSubnets: generateRemoteSubnets([ + existingSiteResource + ]), + newRemoteSubnets: generateRemoteSubnets([ + updatedSiteResource + ]) + }, + { + oldAliases: generateAliasConfig([ + existingSiteResource + ]), + newAliases: generateAliasConfig([ + updatedSiteResource + ]) + } + ) + ); + } catch (error) { + logger.warn( + // this is okay because sometimes the olm is not online to receive the update or associated with the client yet + `Error updating peer data for client ${client.clientId}:`, + error + ); + } } await Promise.all(olmJobs); From 8c62dfa70621c29d8cffcc1fc5d86dbd7862ea2c Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Mon, 1 Dec 2025 12:36:02 -0500 Subject: [PATCH 09/14] respond with relative code expiration time --- server/routers/auth/startDeviceWebAuth.ts | 7 +++++-- src/app/auth/login/device/page.tsx | 2 -- src/components/DeviceLoginForm.tsx | 9 +++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/server/routers/auth/startDeviceWebAuth.ts b/server/routers/auth/startDeviceWebAuth.ts index 8897e73f..925df67f 100644 --- a/server/routers/auth/startDeviceWebAuth.ts +++ b/server/routers/auth/startDeviceWebAuth.ts @@ -22,7 +22,7 @@ export type StartDeviceWebAuthBody = z.infer; export type StartDeviceWebAuthResponse = { code: string; - expiresAt: number; + expiresInSeconds: number; }; // Helper function to generate device code in format A1AJ-N5JD @@ -131,10 +131,13 @@ export async function startDeviceWebAuth( createdAt: Date.now() }); + // calculate relative expiration in seconds + const expiresInSeconds = Math.floor((expiresAt - Date.now()) / 1000); + return response(res, { data: { code, - expiresAt + expiresInSeconds }, success: true, error: false, diff --git a/src/app/auth/login/device/page.tsx b/src/app/auth/login/device/page.tsx index a19174d0..07c804fb 100644 --- a/src/app/auth/login/device/page.tsx +++ b/src/app/auth/login/device/page.tsx @@ -15,8 +15,6 @@ export default async function DeviceLoginPage({ searchParams }: Props) { const params = await searchParams; const code = params.code || ""; - console.log("user", user); - if (!user) { const redirectDestination = code ? `/auth/login/device?code=${encodeURIComponent(code)}` diff --git a/src/components/DeviceLoginForm.tsx b/src/components/DeviceLoginForm.tsx index 8b6d460c..1eeeb5ae 100644 --- a/src/components/DeviceLoginForm.tsx +++ b/src/components/DeviceLoginForm.tsx @@ -84,6 +84,9 @@ export default function DeviceLoginForm({ if (!data.code.includes("-") && data.code.length === 8) { data.code = data.code.slice(0, 4) + "-" + data.code.slice(4); } + + await new Promise((resolve) => setTimeout(resolve, 300)); + // First check - get metadata const res = await api.post( "/device-web-auth/verify?forceLogin=true", @@ -93,8 +96,6 @@ export default function DeviceLoginForm({ } ); - await new Promise((resolve) => setTimeout(resolve, 500)); // artificial delay for better UX - if (res.data.success && res.data.data.metadata) { setMetadata(res.data.data.metadata); setCode(data.code.toUpperCase()); @@ -116,14 +117,14 @@ export default function DeviceLoginForm({ setLoading(true); try { + await new Promise((resolve) => setTimeout(resolve, 300)); + // Final verify await api.post("/device-web-auth/verify", { code: code, verify: true }); - await new Promise((resolve) => setTimeout(resolve, 500)); // artificial delay for better UX - // Redirect to success page router.push("/auth/login/device/success"); } catch (e: any) { From a623604e9630aa28c30a5951c6249c224ae9a631 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 11:51:20 -0500 Subject: [PATCH 10/14] Improve holepunching --- server/auth/sessions/app.ts | 16 +- server/db/pg/schema/schema.ts | 5 +- server/db/sqlite/schema/schema.ts | 6 +- server/lib/readConfigFile.ts | 5 + server/lib/rebuildClientAssociations.ts | 68 +------ server/private/routers/hybrid.ts | 5 +- server/routers/gerbil/updateHolePunch.ts | 54 +++-- server/routers/newt/handleGetConfigMessage.ts | 12 +- server/routers/olm/getOlmToken.ts | 89 ++++++++- server/routers/olm/handleOlmPingMessage.ts | 2 +- .../routers/olm/handleOlmRegisterMessage.ts | 79 +++----- server/routers/olm/handleOlmRelayMessage.ts | 18 +- .../olm/handleOlmServerPeerAddMessage.ts | 185 ++++++++++++++++++ server/routers/olm/index.ts | 1 + server/routers/olm/peers.ts | 38 ++++ server/routers/ws/messageHandlers.ts | 14 +- 16 files changed, 427 insertions(+), 170 deletions(-) create mode 100644 server/routers/olm/handleOlmServerPeerAddMessage.ts diff --git a/server/auth/sessions/app.ts b/server/auth/sessions/app.ts index 0e3da100..73b220fa 100644 --- a/server/auth/sessions/app.ts +++ b/server/auth/sessions/app.ts @@ -36,13 +36,15 @@ export async function createSession( const sessionId = encodeHexLowerCase( sha256(new TextEncoder().encode(token)) ); - const session: Session = { - sessionId: sessionId, - userId, - expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), - issuedAt: new Date().getTime() - }; - await db.insert(sessions).values(session); + const [session] = await db + .insert(sessions) + .values({ + sessionId: sessionId, + userId, + expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), + issuedAt: new Date().getTime() + }) + .returning(); return session; } diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 120a7aa3..32b1252f 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -288,7 +288,7 @@ export const sessions = pgTable("session", { .references(() => users.userId, { onDelete: "cascade" }), expiresAt: bigint("expiresAt", { mode: "number" }).notNull(), issuedAt: bigint("issuedAt", { mode: "number" }), - deviceAuthUsed: boolean("deviceAuthUsed") + deviceAuthUsed: boolean("deviceAuthUsed").notNull().default(false) }); export const newtSessions = pgTable("newtSession", { @@ -665,7 +665,8 @@ export const clientSitesAssociationsCache = pgTable( .notNull(), siteId: integer("siteId").notNull(), isRelayed: boolean("isRelayed").notNull().default(false), - endpoint: varchar("endpoint") + endpoint: varchar("endpoint"), + publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } ); diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index a5f3d0f6..8b42a461 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -1,6 +1,7 @@ import { randomUUID } from "crypto"; import { InferSelectModel } from "drizzle-orm"; import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core"; +import { no } from "zod/v4/locales"; export const domains = sqliteTable("domains", { domainId: text("domainId").primaryKey(), @@ -372,7 +373,8 @@ export const clientSitesAssociationsCache = sqliteTable( isRelayed: integer("isRelayed", { mode: "boolean" }) .notNull() .default(false), - endpoint: text("endpoint") + endpoint: text("endpoint"), + publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } ); @@ -417,6 +419,8 @@ export const sessions = sqliteTable("session", { expiresAt: integer("expiresAt").notNull(), issuedAt: integer("issuedAt"), deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" }) + .notNull() + .default(false) }); export const newtSessions = sqliteTable("newtSession", { diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index dc5ec729..fe0dd593 100644 --- a/server/lib/readConfigFile.ts +++ b/server/lib/readConfigFile.ts @@ -229,6 +229,11 @@ export const configSchema = z .default(51820) .transform(stoi) .pipe(portSchema), + clients_start_port: portSchema + .optional() + .default(21820) + .transform(stoi) + .pipe(portSchema), base_endpoint: z .string() .optional() diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index f1cbea0c..4be2c92a 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -25,6 +25,7 @@ import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; import { + initPeerAddHandshake as holepunchSiteAdd, addPeer as olmAddPeer, deletePeer as olmDeletePeer } from "@server/routers/olm/peers"; @@ -464,65 +465,16 @@ async function handleMessagesForSiteClients( } if (isAdd) { - // TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES - // BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS - // AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES - const isRelayed = true; - - newtJobs.push( - newtAddPeer( + await holepunchSiteAdd( // this will kick off the add peer process for the client + client.clientId, + { siteId, - { - publicKey: client.pubKey, - allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client - // endpoint: isRelayed ? "" : clientSite.endpoint - endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint - }, - newt.newtId - ) - ); - - // TODO: should we have this here? - const allSiteResources = await db // only get the site resources that this client has access to - .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - siteResources.siteResourceId, - clientSiteResourcesAssociationsCache.siteResourceId - ) - ) - .where( - and( - eq(siteResources.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) - ) - ); - - olmJobs.push( - olmAddPeer( - client.clientId, - { - siteId: site.siteId, - endpoint: - isRelayed || !site.endpoint - ? `${exitNode.endpoint}:21820` - : site.endpoint, - publicKey: site.publicKey, - serverIP: site.address, - serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnets( - allSiteResources.map( - ({ siteResources }) => siteResources - ) - ) - }, - olm.olmId - ) + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + } + }, + olm.olmId ); } diff --git a/server/private/routers/hybrid.ts b/server/private/routers/hybrid.ts index f78fb592..29e766b9 100644 --- a/server/private/routers/hybrid.ts +++ b/server/private/routers/hybrid.ts @@ -1369,7 +1369,7 @@ const updateHolePunchSchema = z.object({ port: z.number(), timestamp: z.number(), reachableAt: z.string().optional(), - publicKey: z.string().optional() + publicKey: z.string() // this is the client public key }); hybridRouter.post( "/gerbil/update-hole-punch", @@ -1408,7 +1408,7 @@ hybridRouter.post( ); } - const { olmId, newtId, ip, port, timestamp, token, reachableAt } = + const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } = parsedParams.data; const destinations = await updateAndGenerateEndpointDestinations( @@ -1418,6 +1418,7 @@ hybridRouter.post( port, timestamp, token, + publicKey, exitNode, true ); diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index f049a5c1..c5f0d2b8 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -30,8 +30,9 @@ const updateHolePunchSchema = z.object({ ip: z.string(), port: z.number(), timestamp: z.number(), + publicKey: z.string(), reachableAt: z.string().optional(), - publicKey: z.string().optional() + exitNodePublicKey: z.string().optional() }); // New response type with multi-peer destination support @@ -65,23 +66,26 @@ export async function updateHolePunch( timestamp, token, reachableAt, - publicKey + publicKey, // this is the client's current public key for this session + exitNodePublicKey } = parsedParams.data; let exitNode: ExitNode | undefined; - if (publicKey) { + if (exitNodePublicKey) { // Get the exit node by public key [exitNode] = await db .select() .from(exitNodes) - .where(eq(exitNodes.publicKey, publicKey)); + .where(eq(exitNodes.publicKey, exitNodePublicKey)); } else { // FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL =<1.1.0 [exitNode] = await db.select().from(exitNodes).limit(1); } if (!exitNode) { - logger.warn(`Exit node not found for publicKey: ${publicKey}`); + logger.warn( + `Exit node not found for publicKey: ${exitNodePublicKey}` + ); return next( createHttpError(HttpCode.NOT_FOUND, "Exit node not found") ); @@ -94,12 +98,13 @@ export async function updateHolePunch( port, timestamp, token, + publicKey, exitNode ); - logger.debug( - `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` - ); + // logger.debug( + // `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` + // ); // Return the new multi-peer structure return res.status(HttpCode.OK).send({ @@ -123,6 +128,7 @@ export async function updateAndGenerateEndpointDestinations( port: number, timestamp: number, token: string, + publicKey: string, exitNode: ExitNode, checkOrg = false ) { @@ -130,9 +136,9 @@ export async function updateAndGenerateEndpointDestinations( const destinations: PeerDestination[] = []; if (olmId) { - logger.debug( - `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` - ); + // logger.debug( + // `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` + // ); const { session, olm: olmSession } = await validateOlmSessionToken(token); @@ -180,6 +186,7 @@ export async function updateAndGenerateEndpointDestinations( siteId: sites.siteId, subnet: sites.subnet, listenPort: sites.listenPort, + publicKey: sites.publicKey, endpoint: clientSitesAssociationsCache.endpoint }) .from(sites) @@ -200,10 +207,19 @@ export async function updateAndGenerateEndpointDestinations( `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` ); + // if the public key or endpoint has changed, update it otherwise continue + if ( + site.endpoint === `${ip}:${port}` && + site.publicKey === publicKey + ) { + continue; + } + const [updatedClientSitesAssociationsCache] = await db .update(clientSitesAssociationsCache) .set({ - endpoint: `${ip}:${port}` + endpoint: `${ip}:${port}`, + publicKey: publicKey }) .where( and( @@ -227,9 +243,9 @@ export async function updateAndGenerateEndpointDestinations( } } - logger.debug( - `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` - ); + // logger.debug( + // `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` + // ); if (!updatedClient) { logger.warn(`Client not found for olm: ${olmId}`); throw new Error("Client not found"); @@ -245,9 +261,9 @@ export async function updateAndGenerateEndpointDestinations( } } } else if (newtId) { - logger.debug( - `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` - ); + // logger.debug( + // `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` + // ); const { session, newt: newtSession } = await validateNewtSessionToken(token); @@ -407,7 +423,7 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) { { siteId: siteId, publicKey: site.publicKey, - endpoint: newEndpoint, + endpoint: newEndpoint }, client.olmId ); diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index fbbcb4fb..36105d7e 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -79,12 +79,12 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { // TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly) } - // if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) { - // logger.warn( - // `Site ${existingSite.siteId} last hole punch is too old, skipping` - // ); - // return; - // } + if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) { + logger.warn( + `Site ${existingSite.siteId} last hole punch is too old, skipping` + ); + return; + } // update the endpoint and the public key const [site] = await db diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index c26f5936..33f5fa2d 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -1,9 +1,9 @@ import { generateSessionToken } from "@server/auth/sessions/app"; -import { db } from "@server/db"; +import { clients, db, ExitNode, exitNodes, sites, clientSitesAssociationsCache } from "@server/db"; import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; -import { eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; @@ -15,11 +15,13 @@ import { import { verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; import config from "@server/lib/config"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; export const olmGetTokenBodySchema = z.object({ olmId: z.string(), secret: z.string(), - token: z.string().optional() + token: z.string().optional(), + orgId: z.string().optional() }); export type OlmGetTokenBody = z.infer; @@ -40,7 +42,7 @@ export async function getOlmToken( ); } - const { olmId, secret, token } = parsedBody.data; + const { olmId, secret, token, orgId } = parsedBody.data; try { if (token) { @@ -61,11 +63,12 @@ export async function getOlmToken( } } - const existingOlmRes = await db + const [existingOlm] = await db .select() .from(olms) .where(eq(olms.olmId, olmId)); - if (!existingOlmRes || !existingOlmRes.length) { + + if (!existingOlm) { return next( createHttpError( HttpCode.BAD_REQUEST, @@ -74,12 +77,11 @@ export async function getOlmToken( ); } - const existingOlm = existingOlmRes[0]; - const validSecret = await verifyPassword( secret, existingOlm.secretHash ); + if (!validSecret) { if (config.getRawConfig().app.log_failed_attempts) { logger.info( @@ -96,11 +98,78 @@ export async function getOlmToken( const resToken = generateSessionToken(); await createOlmSession(resToken, existingOlm.olmId); + let orgIdToUse = orgId; + if (!orgIdToUse) { + if (!existingOlm.clientId) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Olm is not associated with a client, orgId is required" + ) + ); + } + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, existingOlm.clientId)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Olm's associated client not found, orgId is required" + ) + ); + } + + orgIdToUse = client.orgId; + } + + // Get all exit nodes from sites where the client has peers + const clientSites = await db + .select() + .from(clientSitesAssociationsCache) + .innerJoin( + sites, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) + .where(eq(clientSitesAssociationsCache.clientId, existingOlm.clientId!)); + + // Extract unique exit node IDs + const exitNodeIds = Array.from( + new Set( + clientSites + .map(({ sites: site }) => site.exitNodeId) + .filter((id): id is number => id !== null) + ) + ); + + let allExitNodes: ExitNode[] = []; + if (exitNodeIds.length > 0) { + allExitNodes = await db + .select() + .from(exitNodes) + .where(inArray(exitNodes.exitNodeId, exitNodeIds)); + } + + const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { + return { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + }; + }); + logger.debug("Token created successfully"); - return response<{ token: string }>(res, { + return response<{ + token: string; + exitNodes: { publicKey: string; endpoint: string }[]; + }>(res, { data: { - token: resToken + token: resToken, + exitNodes: exitNodesHpData }, success: true, error: false, diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index ee9443f5..4bcbbb8b 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -59,7 +59,7 @@ export const startOlmOfflineChecker = (): void => { // Send a disconnect message to the client if connected try { - await sendTerminateClient(offlineClient.clientId); // terminate first + await sendTerminateClient(offlineClient.clientId, offlineClient.olmId); // terminate first // wait a moment to ensure the message is sent await new Promise(resolve => setTimeout(resolve, 1000)); await disconnectClient(offlineClient.olmId); diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 53cd9815..ce40e832 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -34,21 +34,28 @@ import { generateRemoteSubnets } from "@server/lib/ip"; import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; import { validateSessionToken } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); const { message, client: c, sendToClient } = context; const olm = c as Olm; - const now = new Date().getTime() / 1000; + const now = Math.floor(Date.now() / 1000); if (!olm) { logger.warn("Olm not found"); return; } - const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } = - message.data; + const { + publicKey, + relay, + olmVersion, + orgId, + doNotCreateNewClient, + token: userToken + } = message.data; let client: Client | undefined; let org: Org | undefined; @@ -63,7 +70,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { olm.name || "User Device", // doNotCreateNewClient ? true : false true // for now never create a new client automatically because we create the users clients when they are added to the org - // this means that the rebuildClientAssociationsFromClient call below issue is not a problem + // this means that the rebuildClientAssociationsFromClient call below issue is not a problem ); client = clientRes; @@ -113,12 +120,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` ); - await db - .update(olms) - .set({ - clientId: client.clientId - }) - .where(eq(olms.olmId, olm.olmId)); + if (olm.clientId !== client.clientId) { // we only need to do this if the client is changing + await db + .update(olms) + .set({ + clientId: client.clientId + }) + .where(eq(olms.olmId, olm.olmId)); + } } else { if (!olm.clientId) { logger.warn("Olm has no client ID!"); @@ -159,41 +168,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - if (client.exitNodeId) { - // TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER - - // Get the exit node - const allExitNodes = await listExitNodes(client.orgId, true); // FILTER THE ONLINE ONES - - const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { - return { - publicKey: exitNode.publicKey, - endpoint: exitNode.endpoint - }; - }); - - // Send holepunch message - await sendToClient(olm.olmId, { - type: "olm/wg/holepunch/all", - data: { - exitNodes: exitNodesHpData - } - }); - - if (!olmVersion) { - // THIS IS FOR BACKWARDS COMPATIBILITY - // THE OLDER CLIENTS DID NOT SEND THE VERSION - await sendToClient(olm.olmId, { - type: "olm/wg/holepunch", - data: { - serverPubKey: allExitNodes[0].publicKey, - endpoint: allExitNodes[0].endpoint - } - }); - } - } - - if (olmVersion) { + if (olmVersion && olm.version !== olmVersion) { await db .update(olms) .set({ @@ -202,10 +177,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(olms.olmId, olm.olmId)); } - // if (now - (client.lastHolePunch || 0) > 6) { - // logger.warn("Client last hole punch is too old, skipping all sites"); - // return; - // } + // this prevents us from accepting a register from an olm that has not hole punched yet. + // the olm will pump the register so we can keep checking + if (now - (client.lastHolePunch || 0) > 5) { + logger.warn( + "Client last hole punch is too old; skipping this register" + ); + return; + } if (client.pubKey !== publicKey) { logger.info( @@ -319,7 +298,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.warn(`Exit node not found for site ${site.siteId}`); continue; } - endpoint = `${exitNode.endpoint}:21820`; + endpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; } const allSiteResources = await db // only get the site resources that this client has access to diff --git a/server/routers/olm/handleOlmRelayMessage.ts b/server/routers/olm/handleOlmRelayMessage.ts index 153c4e7c..5479ccbb 100644 --- a/server/routers/olm/handleOlmRelayMessage.ts +++ b/server/routers/olm/handleOlmRelayMessage.ts @@ -2,7 +2,7 @@ import { db, exitNodes, sites } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; import { clients, clientSitesAssociationsCache, Olm } from "@server/db"; import { and, eq } from "drizzle-orm"; -import { updatePeer } from "../newt/peers"; +import { updatePeer as newtUpdatePeer } from "../newt/peers"; import logger from "@server/logger"; export const handleOlmRelayMessage: MessageHandler = async (context) => { @@ -79,18 +79,20 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => { ); // update the peer on the exit node - await updatePeer(siteId, client.pubKey, { - endpoint: "" // this removes the endpoint + await newtUpdatePeer(siteId, client.pubKey, { + endpoint: "" // this removes the endpoint so the exit node knows to relay }); - sendToClient(olm.olmId, { - type: "olm/wg/peer/relay", + return { + message: { + type: "olm/wg/peer/relay", data: { siteId: siteId, endpoint: exitNode.endpoint, publicKey: exitNode.publicKey } - }); - - return; + }, + broadcast: false, + excludeSender: false + }; }; diff --git a/server/routers/olm/handleOlmServerPeerAddMessage.ts b/server/routers/olm/handleOlmServerPeerAddMessage.ts new file mode 100644 index 00000000..3d0d61b2 --- /dev/null +++ b/server/routers/olm/handleOlmServerPeerAddMessage.ts @@ -0,0 +1,185 @@ +import { + Client, + clientSiteResourcesAssociationsCache, + db, + ExitNode, + Org, + orgs, + roleClients, + roles, + siteResources, + Transaction, + userClients, + userOrgs, + users +} from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { + clients, + clientSitesAssociationsCache, + exitNodes, + Olm, + olms, + sites +} from "@server/db"; +import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; +import { addPeer, deletePeer } from "../newt/peers"; +import logger from "@server/logger"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; +import { + generateAliasConfig, + getNextAvailableClientSubnet +} from "@server/lib/ip"; +import { generateRemoteSubnets } from "@server/lib/ip"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { validateSessionToken } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; +import { + addPeer as newtAddPeer, + deletePeer as newtDeletePeer +} from "@server/routers/newt/peers"; + +export const handleOlmServerPeerAddMessage: MessageHandler = async ( + context +) => { + logger.info("Handling register olm message!"); + const { message, client: c, sendToClient } = context; + const olm = c as Olm; + + const now = Math.floor(Date.now() / 1000); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + const { siteId } = message.data; + + // get the site + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site) { + logger.error( + `handleOlmServerPeerAddMessage: Site with ID ${siteId} not found` + ); + return; + } + + if (!site.endpoint) { + logger.error( + `handleOlmServerPeerAddMessage: Site with ID ${siteId} has no endpoint` + ); + return; + } + + // get the client + + if (!olm.clientId) { + logger.error( + `handleOlmServerPeerAddMessage: Olm with ID ${olm.olmId} has no clientId` + ); + return; + } + + const [client] = await db + .select() + .from(clients) + .where(and(eq(clients.clientId, olm.clientId))) + .limit(1); + + if (!client) { + logger.error( + `handleOlmServerPeerAddMessage: Client with ID ${olm.clientId} not found` + ); + return; + } + + if (!client.pubKey) { + logger.error( + `handleOlmServerPeerAddMessage: Client with ID ${client.clientId} has no public key` + ); + return; + } + + let endpoint: string | null = null; + + + const currentSessionSiteAssociationCaches = await db + .select() + .from(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + isNotNull(clientSitesAssociationsCache.endpoint), + eq(clientSitesAssociationsCache.publicKey, client.pubKey) // limit it to the current session its connected with otherwise the endpoint could be stale + ) + ); + + // pick an endpoint + for (const assoc of currentSessionSiteAssociationCaches) { + if (assoc.endpoint) { + endpoint = assoc.endpoint; + break; + } + } + + if (!endpoint) { + logger.error( + `handleOlmServerPeerAddMessage: No endpoint found for client ${client.clientId}` + ); + return; + } + + await newtAddPeer(siteId, { + publicKey: client.pubKey, + allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client + endpoint: endpoint // this is the client's endpoint with reference to the site's exit node + }); + + const allSiteResources = await db // only get the site resources that this client has access to + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); + + // Return connect message with all site configurations + return { + message: { + type: "olm/wg/peer/add", + data: { + siteId: site.siteId, + endpoint: site.endpoint, + publicKey: site.publicKey, + serverIP: site.address, + serverPort: site.listenPort, + remoteSubnets: generateRemoteSubnets( + allSiteResources.map(({ siteResources }) => siteResources) + ), + aliases: generateAliasConfig( + allSiteResources.map(({ siteResources }) => siteResources) + ) + } + }, + broadcast: false, + excludeSender: false + }; +}; diff --git a/server/routers/olm/index.ts b/server/routers/olm/index.ts index 7adbf859..0fc65d92 100644 --- a/server/routers/olm/index.ts +++ b/server/routers/olm/index.ts @@ -7,3 +7,4 @@ export * from "./deleteUserOlm"; export * from "./listUserOlms"; export * from "./deleteUserOlm"; export * from "./getUserOlm"; +export * from "./handleOlmServerPeerAddMessage"; \ No newline at end of file diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 69ea2bc9..7651f0a9 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -3,6 +3,7 @@ import { clients, olms, newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; import { sendToClient } from "#dynamic/routers/ws"; import logger from "@server/logger"; +import { exit } from "process"; export async function addPeer( clientId: number, @@ -110,3 +111,40 @@ export async function updatePeer( logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`); } + +export async function initPeerAddHandshake( + clientId: number, + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }, + olmId?: string +) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; + } + + await sendToClient(olmId, { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: peer.siteId, + exitNode: { + publicKey: peer.exitNode.publicKey, + endpoint: peer.exitNode.endpoint + } + } + }); + + logger.info(`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`); +} diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index cbb023b3..b92e7530 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -11,23 +11,25 @@ import { handleOlmRegisterMessage, handleOlmRelayMessage, handleOlmPingMessage, - startOlmOfflineChecker + startOlmOfflineChecker, + handleOlmServerPeerAddMessage } from "../olm"; import { handleHealthcheckStatusMessage } from "../target"; import { MessageHandler } from "./types"; export const messageHandlers: Record = { - "newt/wg/register": handleNewtRegisterMessage, + "olm/wg/server/peer/add": handleOlmServerPeerAddMessage, "olm/wg/register": handleOlmRegisterMessage, - "newt/wg/get-config": handleGetConfigMessage, - "newt/receive-bandwidth": handleReceiveBandwidthMessage, "olm/wg/relay": handleOlmRelayMessage, "olm/ping": handleOlmPingMessage, + "newt/wg/register": handleNewtRegisterMessage, + "newt/wg/get-config": handleGetConfigMessage, + "newt/receive-bandwidth": handleReceiveBandwidthMessage, "newt/socket/status": handleDockerStatusMessage, "newt/socket/containers": handleDockerContainersMessage, "newt/ping/request": handleNewtPingRequestMessage, "newt/blueprint/apply": handleApplyBlueprintMessage, - "newt/healthcheck/status": handleHealthcheckStatusMessage, + "newt/healthcheck/status": handleHealthcheckStatusMessage }; -startOlmOfflineChecker(); // this is to handle the offline check for olms \ No newline at end of file +startOlmOfflineChecker(); // this is to handle the offline check for olms From b5e94d44ae1bb6fbb3746769ed530409009dff9b Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 15:44:25 -0500 Subject: [PATCH 11/14] Fix switching orgs having connections from other orgs --- server/lib/rebuildClientAssociations.ts | 24 +- server/routers/olm/getOlmToken.ts | 48 +++- .../routers/olm/handleOlmRegisterMessage.ts | 235 +++--------------- 3 files changed, 95 insertions(+), 212 deletions(-) diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 4be2c92a..810acdef 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -465,7 +465,8 @@ async function handleMessagesForSiteClients( } if (isAdd) { - await holepunchSiteAdd( // this will kick off the add peer process for the client + await holepunchSiteAdd( + // this will kick off the add peer process for the client client.clientId, { siteId, @@ -728,7 +729,19 @@ export async function rebuildClientAssociationsFromClient( const userSiteResourceIds = await trx .select({ siteResourceId: userSiteResources.siteResourceId }) .from(userSiteResources) - .where(eq(userSiteResources.userId, client.userId)); + .innerJoin( + siteResources, + eq( + siteResources.siteResourceId, + userSiteResources.siteResourceId + ) + ) + .where( + and( + eq(userSiteResources.userId, client.userId), + eq(siteResources.orgId, client.orgId) + ) + ); // this needs to be locked onto this org or else cross-org access could happen newSiteResourceIds.push( ...userSiteResourceIds.map((r) => r.siteResourceId) @@ -738,7 +751,12 @@ export async function rebuildClientAssociationsFromClient( const roleIds = await trx .select({ roleId: userOrgs.roleId }) .from(userOrgs) - .where(eq(userOrgs.userId, client.userId)) + .where( + and( + eq(userOrgs.userId, client.userId), + eq(userOrgs.orgId, client.orgId) + ) + ) // this needs to be locked onto this org or else cross-org access could happen .then((rows) => rows.map((row) => row.roleId)); if (roleIds.length > 0) { diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index 33f5fa2d..cea8386c 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -1,5 +1,12 @@ import { generateSessionToken } from "@server/auth/sessions/app"; -import { clients, db, ExitNode, exitNodes, sites, clientSitesAssociationsCache } from "@server/db"; +import { + clients, + db, + ExitNode, + exitNodes, + sites, + clientSitesAssociationsCache +} from "@server/db"; import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; @@ -99,6 +106,7 @@ export async function getOlmToken( await createOlmSession(resToken, existingOlm.olmId); let orgIdToUse = orgId; + let clientIdToUse; if (!orgIdToUse) { if (!existingOlm.clientId) { return next( @@ -114,7 +122,7 @@ export async function getOlmToken( .from(clients) .where(eq(clients.clientId, existingOlm.clientId)) .limit(1); - + if (!client) { return next( createHttpError( @@ -125,6 +133,40 @@ export async function getOlmToken( } orgIdToUse = client.orgId; + clientIdToUse = client.clientId; + } else { + // we did provide the org + const [client] = await db + .select() + .from(clients) + .where(eq(clients.orgId, orgIdToUse)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No client found for provided orgId" + ) + ); + } + + if (existingOlm.clientId !== client.clientId) { + // we only need to do this if the client is changing + + logger.debug( + `Switching olm client ${existingOlm.olmId} to org ${orgId} for user ${existingOlm.userId}` + ); + + await db + .update(olms) + .set({ + clientId: client.clientId + }) + .where(eq(olms.olmId, existingOlm.olmId)); + } + + clientIdToUse = client.clientId; } // Get all exit nodes from sites where the client has peers @@ -135,7 +177,7 @@ export async function getOlmToken( sites, eq(sites.siteId, clientSitesAssociationsCache.siteId) ) - .where(eq(clientSitesAssociationsCache.clientId, existingOlm.clientId!)); + .where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!)); // Extract unique exit node IDs const exitNodeIds = Array.from( diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index ce40e832..6f3e59c1 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -48,45 +48,36 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - const { - publicKey, - relay, - olmVersion, - orgId, - doNotCreateNewClient, - token: userToken - } = message.data; + const { publicKey, relay, olmVersion, orgId, userToken } = message.data; - let client: Client | undefined; - let org: Org | undefined; + if (!olm.clientId) { + logger.warn("Olm client ID not found"); + return; + } + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, olm.clientId)) + .limit(1); + + if (!client) { + logger.warn("Client ID not found"); + return; + } + + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, client.orgId)) + .limit(1); + + if (!org) { + logger.warn("Org not found"); + return; + } if (orgId) { - try { - const { client: clientRes, org: orgRes } = - await getOrCreateOrgClient( - orgId, - olm.userId, - olm.olmId, - olm.name || "User Device", - // doNotCreateNewClient ? true : false - true // for now never create a new client automatically because we create the users clients when they are added to the org - // this means that the rebuildClientAssociationsFromClient call below issue is not a problem - ); - - client = clientRes; - org = orgRes; - } catch (err) { - logger.error( - `Error switching olm client ${olm.olmId} to org ${orgId}: ${err}` - ); - return; - } - - if (!client) { - logger.warn("Client not found"); - return; - } - if (!olm.userId) { logger.warn("Olm has no user ID"); return; @@ -95,11 +86,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { const { session: userSession, user } = await validateSessionToken(userToken); if (!userSession || !user) { - logger.warn("Invalid user session for olm ping"); + logger.warn("Invalid user session for olm register"); return; // by returning here we just ignore the ping and the setInterval will force it to disconnect } if (user.userId !== olm.userId) { - logger.warn("User ID mismatch for olm ping"); + logger.warn("User ID mismatch for olm register"); return; } @@ -115,48 +106,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ); return; } - - logger.debug( - `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` - ); - - if (olm.clientId !== client.clientId) { // we only need to do this if the client is changing - await db - .update(olms) - .set({ - clientId: client.clientId - }) - .where(eq(olms.olmId, olm.olmId)); - } - } else { - if (!olm.clientId) { - logger.warn("Olm has no client ID!"); - return; - } - - logger.debug(`Using last connected org for client ${olm.clientId}`); - - [client] = await db - .select() - .from(clients) - .where(eq(clients.clientId, olm.clientId)) - .limit(1); - - [org] = await db - .select() - .from(orgs) - .where(eq(orgs.orgId, client.orgId)) - .limit(1); - } - - if (!client) { - logger.warn("Client ID not found"); - return; - } - - if (!org) { - logger.warn("Org not found"); - return; } logger.debug( @@ -357,129 +306,3 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { excludeSender: false }; }; - -async function getOrCreateOrgClient( - orgId: string, - userId: string | null, - olmId: string, - name: string, - doNotCreateNewClient: boolean, - trx: Transaction | typeof db = db -): Promise<{ - client: Client; - org: Org; -}> { - // get the org - const [org] = await trx - .select() - .from(orgs) - .where(eq(orgs.orgId, orgId)) - .limit(1); - - if (!org) { - throw new Error("Org not found"); - } - - if (!org.subnet) { - throw new Error("Org has no subnet defined"); - } - - // check if the user has a client in the org and if not then create a client for them - const [existingClient] = await trx - .select() - .from(clients) - .where( - and( - eq(clients.orgId, orgId), - userId ? eq(clients.userId, userId) : isNull(clients.userId), // we dont check the user id if it is null because the olm is not tied to a user? - eq(clients.olmId, olmId) - ) - ) // checking the olmid here because we want to create a new client PER OLM PER ORG - .limit(1); - - let client = existingClient; - if (!client && !doNotCreateNewClient) { - logger.debug( - `Client does not exist in org ${orgId}, creating new client for user ${userId}` - ); - - if (!userId) { - throw new Error("User ID is required to create client in org"); - } - - // Verify that the user belongs to the org - const [userOrg] = await trx - .select() - .from(userOrgs) - .where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId))) - .limit(1); - - if (!userOrg) { - throw new Error("User does not belong to org"); - } - - // TODO: more intelligent way to pick the exit node - const exitNodesList = await listExitNodes(orgId); - const randomExitNode = - exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; - - const [adminRole] = await trx - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); - - if (!adminRole) { - throw new Error("Admin role not found"); - } - - const newSubnet = await getNextAvailableClientSubnet(orgId); - if (!newSubnet) { - throw new Error("No available subnet found"); - } - - const subnet = newSubnet.split("/")[0]; - const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org - - const [newClient] = await trx - .insert(clients) - .values({ - exitNodeId: randomExitNode.exitNodeId, - orgId, - name, - subnet: updatedSubnet, - type: "olm", - userId: userId, - olmId: olmId // to lock this client to the olm even as the olm moves between clients in different orgs - }) - .returning(); - - await trx.insert(roleClients).values({ - roleId: adminRole.roleId, - clientId: newClient.clientId - }); - - await trx.insert(userClients).values({ - // we also want to make sure that the user can see their own client if they are not an admin - userId, - clientId: newClient.clientId - }); - - if (userOrg.roleId != adminRole.roleId) { - // make sure the user can access the client - trx.insert(userClients).values({ - userId, - clientId: newClient.clientId - }); - } - - await rebuildClientAssociationsFromClient(newClient, trx); // TODO: this will try to messages to the olm which has not connected yet - is that a problem? - - client = newClient; - } - - return { - client: client, - org: org - }; -} From beea28daf3932d51ddaa99eefb36978440739893 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 16:20:10 -0500 Subject: [PATCH 12/14] Handle hp oddities --- server/routers/gerbil/updateHolePunch.ts | 10 +++++++-- .../routers/olm/handleOlmRegisterMessage.ts | 21 ++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index c5f0d2b8..208f3a98 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -230,7 +230,9 @@ export async function updateAndGenerateEndpointDestinations( .returning(); if ( - updatedClientSitesAssociationsCache.endpoint !== site.endpoint // this is the endpoint from the join table not the site + updatedClientSitesAssociationsCache.endpoint !== + site.endpoint && // this is the endpoint from the join table not the site + updatedClient.pubKey === publicKey // only trigger if the client's public key matches the current public key which means it has registered so we dont prematurely send the update ) { logger.info( `ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}` @@ -318,7 +320,11 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .returning(); - if (updatedSite.endpoint != site.endpoint) { + if ( + updatedSite.endpoint != site.endpoint && + updatedSite.publicKey == publicKey + ) { + // only trigger if the site's public key matches the current public key which means it has registered so we dont prematurely send the update logger.info( `Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}` ); diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 6f3e59c1..7cde2c76 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -126,15 +126,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(olms.olmId, olm.olmId)); } - // this prevents us from accepting a register from an olm that has not hole punched yet. - // the olm will pump the register so we can keep checking - if (now - (client.lastHolePunch || 0) > 5) { - logger.warn( - "Client last hole punch is too old; skipping this register" - ); - return; - } - if (client.pubKey !== publicKey) { logger.info( "Public key mismatch. Updating public key and clearing session info..." @@ -172,8 +163,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { `Found ${sitesData.length} sites for client ${client.clientId}` ); + // this prevents us from accepting a register from an olm that has not hole punched yet. + // the olm will pump the register so we can keep checking + // TODO: I still think there is a better way to do this rather than locking it out here but ??? + if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) { + logger.warn( + "Client last hole punch is too old and we have sites to send; skipping this register" + ); + return; + } + // Process each site - for (const { sites: site } of sitesData) { + for (const { sites: site, clientSitesAssociationsCache: association } of sitesData) { if (!site.exitNodeId) { logger.warn( `Site ${site.siteId} does not have exit node, skipping` From a7e32d401353658144b301d052631d70e9b66e96 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 19:57:23 -0500 Subject: [PATCH 13/14] Fix bugs with updating a resource --- server/routers/gerbil/updateHolePunch.ts | 6 +- .../siteResource/updateSiteResource.ts | 93 ++++++++++--------- 2 files changed, 52 insertions(+), 47 deletions(-) diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 208f3a98..e1fa7c4c 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -203,9 +203,9 @@ export async function updateAndGenerateEndpointDestinations( // Update clientSites for each site on this exit node for (const site of sitesOnExitNode) { - logger.debug( - `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` - ); + // logger.debug( + // `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` + // ); // if the public key or endpoint has changed, update it otherwise continue if ( diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index 91b14bf3..5321f971 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -268,10 +268,13 @@ export async function updateSiteResource( ); // after everything is rebuilt above we still need to update the targets and remote subnets if the destination changed - if ( + const destinationChanged = existingSiteResource.destination !== - updatedSiteResource.destination - ) { + updatedSiteResource.destination; + const aliasChanged = + existingSiteResource.alias !== updatedSiteResource.alias; + + if (destinationChanged || aliasChanged) { const [newt] = await trx .select() .from(newts) @@ -284,53 +287,55 @@ export async function updateSiteResource( ); } - const oldTargets = generateSubnetProxyTargets( - existingSiteResource, - mergedAllClients - ); - const newTargets = generateSubnetProxyTargets( - updatedSiteResource, - mergedAllClients - ); + // Only update targets on newt if destination changed + if (destinationChanged) { + const oldTargets = generateSubnetProxyTargets( + existingSiteResource, + mergedAllClients + ); + const newTargets = generateSubnetProxyTargets( + updatedSiteResource, + mergedAllClients + ); - await updateTargets(newt.newtId, { - oldTargets: oldTargets, - newTargets: newTargets - }); + await updateTargets(newt.newtId, { + oldTargets: oldTargets, + newTargets: newTargets + }); + } + // Update olms for both destination and alias changes let olmJobs: Promise[] = []; for (const client of mergedAllClients) { // we also need to update the remote subnets on the olms for each client that has access to this site - try { - olmJobs.push( - updatePeerData( - client.clientId, - updatedSiteResource.siteId, - { - oldRemoteSubnets: generateRemoteSubnets([ - existingSiteResource - ]), - newRemoteSubnets: generateRemoteSubnets([ - updatedSiteResource - ]) - }, - { - oldAliases: generateAliasConfig([ - existingSiteResource - ]), - newAliases: generateAliasConfig([ - updatedSiteResource - ]) - } - ) - ); - } catch (error) { - logger.warn( + olmJobs.push( + updatePeerData( + client.clientId, + updatedSiteResource.siteId, + { + oldRemoteSubnets: generateRemoteSubnets([ + existingSiteResource + ]), + newRemoteSubnets: generateRemoteSubnets([ + updatedSiteResource + ]) + }, + { + oldAliases: generateAliasConfig([ + existingSiteResource + ]), + newAliases: generateAliasConfig([ + updatedSiteResource + ]) + } + ).catch((error) => { // this is okay because sometimes the olm is not online to receive the update or associated with the client yet - `Error updating peer data for client ${client.clientId}:`, - error - ); - } + logger.warn( + `Error updating peer data for client ${client.clientId}:`, + error + ); + }) + ); } await Promise.all(olmJobs); From 152fb47ca4cd4cdbf4dd9e7c0297e8804ea80663 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 2 Dec 2025 11:17:08 -0500 Subject: [PATCH 14/14] Handle unrelay and relaying better --- server/lib/rebuildClientAssociations.ts | 80 +++------------- .../routers/olm/handleOlmRegisterMessage.ts | 7 +- server/routers/olm/handleOlmRelayMessage.ts | 11 +-- .../olm/handleOlmServerPeerAddMessage.ts | 2 + server/routers/olm/handleOlmUnRelayMessage.ts | 96 +++++++++++++++++++ server/routers/olm/index.ts | 3 +- server/routers/olm/peers.ts | 1 + server/routers/org/deleteOrg.ts | 2 +- server/routers/ws/messageHandlers.ts | 4 +- 9 files changed, 125 insertions(+), 81 deletions(-) create mode 100644 server/routers/olm/handleOlmUnRelayMessage.ts diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 810acdef..00156c01 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -1011,76 +1011,18 @@ async function handleMessagesForClientSites( continue; } - // Add peer to newt - const isRelayed = true; // Default to relaying for new connections - newtJobs.push( - newtAddPeer( - site.siteId, - { - publicKey: client.pubKey, - allowedIps: [`${client.subnet.split("/")[0]}/32`], - endpoint: isRelayed ? "" : "" - }, - newt.newtId - ) + await holepunchSiteAdd( + // this will kick off the add peer process for the client + client.clientId, + { + siteId: site.siteId, + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + } + }, + olmId ); - - // Get all site resources for this site that the client has access to - const accessibleResources = await trx - .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - siteResources.siteResourceId, - clientSiteResourcesAssociationsCache.siteResourceId - ) - ) - .where( - and( - eq(siteResources.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) - ) - ); - try { - // Add peer to olm - olmJobs.push( - olmAddPeer( - client.clientId, - { - siteId: site.siteId, - endpoint: - isRelayed || !site.endpoint - ? `${exitNode.endpoint}:21820` - : site.endpoint, - publicKey: site.publicKey, - serverIP: site.address || "", - serverPort: site.listenPort || 0, - remoteSubnets: generateRemoteSubnets( - accessibleResources.map( - ({ siteResources }) => siteResources - ) - ) - }, - olmId - ) - ); - } catch (error) { - // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send - if ( - error instanceof Error && - error.message.includes("not found") - ) { - logger.debug( - `Olm data not found for client ${client.clientId}, skipping removal` - ); - } else { - throw error; - } - } } // Update exit node destinations diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 7cde2c76..696da748 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -237,7 +237,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ); } - let endpoint = site.endpoint; + let relayEndpoint: string | undefined = undefined; if (relay) { const [exitNode] = await db .select() @@ -248,7 +248,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.warn(`Exit node not found for site ${site.siteId}`); continue; } - endpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; + relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; } const allSiteResources = await db // only get the site resources that this client has access to @@ -274,7 +274,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { // Add site configuration to the array siteConfigurations.push({ siteId: site.siteId, - endpoint: endpoint, + relayEndpoint: relayEndpoint, // this can be undefined now if not relayed + endpoint: site.endpoint, publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, diff --git a/server/routers/olm/handleOlmRelayMessage.ts b/server/routers/olm/handleOlmRelayMessage.ts index 5479ccbb..595b35ba 100644 --- a/server/routers/olm/handleOlmRelayMessage.ts +++ b/server/routers/olm/handleOlmRelayMessage.ts @@ -85,12 +85,11 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => { return { message: { - type: "olm/wg/peer/relay", - data: { - siteId: siteId, - endpoint: exitNode.endpoint, - publicKey: exitNode.publicKey - } + type: "olm/wg/peer/relay", + data: { + siteId: siteId, + relayEndpoint: exitNode.endpoint + } }, broadcast: false, excludeSender: false diff --git a/server/routers/olm/handleOlmServerPeerAddMessage.ts b/server/routers/olm/handleOlmServerPeerAddMessage.ts index 3d0d61b2..2e5009eb 100644 --- a/server/routers/olm/handleOlmServerPeerAddMessage.ts +++ b/server/routers/olm/handleOlmServerPeerAddMessage.ts @@ -135,6 +135,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async ( return; } + // NOTE: here we are always starting direct to the peer and will relay later + await newtAddPeer(siteId, { publicKey: client.pubKey, allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client diff --git a/server/routers/olm/handleOlmUnRelayMessage.ts b/server/routers/olm/handleOlmUnRelayMessage.ts new file mode 100644 index 00000000..5f47a095 --- /dev/null +++ b/server/routers/olm/handleOlmUnRelayMessage.ts @@ -0,0 +1,96 @@ +import { db, exitNodes, sites } from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { clients, clientSitesAssociationsCache, Olm } from "@server/db"; +import { and, eq } from "drizzle-orm"; +import { updatePeer as newtUpdatePeer } from "../newt/peers"; +import logger from "@server/logger"; + +export const handleOlmUnRelayMessage: MessageHandler = async (context) => { + const { message, client: c, sendToClient } = context; + const olm = c as Olm; + + logger.info("Handling unrelay olm message!"); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + if (!olm.clientId) { + logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? + return; + } + + const clientId = olm.clientId; + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientId)) + .limit(1); + + if (!client) { + logger.warn("Client not found"); + return; + } + + // make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old + if (!client.pubKey) { + logger.warn("Client has no endpoint or listen port"); + return; + } + + const { siteId } = message.data; + + // Get the site + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site) { + logger.warn("Site not found or has no exit node"); + return; + } + + const [clientSiteAssociation] = await db + .update(clientSitesAssociationsCache) + .set({ + isRelayed: false + }) + .where( + and( + eq(clientSitesAssociationsCache.clientId, olm.clientId), + eq(clientSitesAssociationsCache.siteId, siteId) + ) + ) + .returning(); + + if (!clientSiteAssociation) { + logger.warn("Client-Site association not found"); + return; + } + + if (!clientSiteAssociation.endpoint) { + logger.warn("Client-Site association has no endpoint, cannot unrelay"); + return; + } + + // update the peer on the exit node + await newtUpdatePeer(siteId, client.pubKey, { + endpoint: clientSiteAssociation.endpoint // this is the endpoint of the client to connect directly to the exit node + }); + + return { + message: { + type: "olm/wg/peer/unrelay", + data: { + siteId: siteId, + endpoint: site.endpoint + } + }, + broadcast: false, + excludeSender: false + }; +}; diff --git a/server/routers/olm/index.ts b/server/routers/olm/index.ts index 0fc65d92..e671dd42 100644 --- a/server/routers/olm/index.ts +++ b/server/routers/olm/index.ts @@ -7,4 +7,5 @@ export * from "./deleteUserOlm"; export * from "./listUserOlms"; export * from "./deleteUserOlm"; export * from "./getUserOlm"; -export * from "./handleOlmServerPeerAddMessage"; \ No newline at end of file +export * from "./handleOlmServerPeerAddMessage"; +export * from "./handleOlmUnRelayMessage"; \ No newline at end of file diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 7651f0a9..87c634cc 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -103,6 +103,7 @@ export async function updatePeer( siteId: peer.siteId, publicKey: peer.publicKey, endpoint: peer.endpoint, + relayEndpoint: peer.serverIP, serverIP: peer.serverIP, serverPort: peer.serverPort, remoteSubnets: peer.remoteSubnets diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index 098c5c41..35dc7503 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -193,7 +193,7 @@ export async function deleteOrg( // Send termination messages outside of transaction to prevent blocking for (const newtId of deletedNewtIds) { const payload = { - type: `newt/terminate`, + type: `newt/wg/terminate`, data: {} }; // Don't await this to prevent blocking the response diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index b92e7530..acd1aef0 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -12,7 +12,8 @@ import { handleOlmRelayMessage, handleOlmPingMessage, startOlmOfflineChecker, - handleOlmServerPeerAddMessage + handleOlmServerPeerAddMessage, + handleOlmUnRelayMessage } from "../olm"; import { handleHealthcheckStatusMessage } from "../target"; import { MessageHandler } from "./types"; @@ -21,6 +22,7 @@ export const messageHandlers: Record = { "olm/wg/server/peer/add": handleOlmServerPeerAddMessage, "olm/wg/register": handleOlmRegisterMessage, "olm/wg/relay": handleOlmRelayMessage, + "olm/wg/unrelay": handleOlmUnRelayMessage, "olm/ping": handleOlmPingMessage, "newt/wg/register": handleNewtRegisterMessage, "newt/wg/get-config": handleGetConfigMessage,