From ec1f94791ade6cf2c0db27844976b0340fcea529 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 6 Nov 2025 17:59:34 -0800 Subject: [PATCH] Remove siteIds and build associations from user role chnages --- server/lib/rebuildSiteClientAssociations.ts | 389 ++++++++++++++++++ server/routers/client/updateClient.ts | 286 +------------ server/routers/newt/peers.ts | 136 +++--- .../routers/olm/handleOlmRegisterMessage.ts | 5 + server/routers/olm/peers.ts | 76 ++-- .../siteResource/createSiteResource.ts | 112 ++--- .../siteResource/deleteSiteResource.ts | 95 +++-- .../siteResource/setSiteResourceRoles.ts | 8 +- .../siteResource/setSiteResourceUsers.ts | 21 +- 9 files changed, 669 insertions(+), 459 deletions(-) create mode 100644 server/lib/rebuildSiteClientAssociations.ts diff --git a/server/lib/rebuildSiteClientAssociations.ts b/server/lib/rebuildSiteClientAssociations.ts new file mode 100644 index 00000000..5003b8e1 --- /dev/null +++ b/server/lib/rebuildSiteClientAssociations.ts @@ -0,0 +1,389 @@ +import { + Client, + clients, + clientSites, + db, + exitNodes, + newts, + olms, + roleSiteResources, + Site, + SiteResource, + sites, + Transaction, + userOrgs, + users, + userSiteResources +} from "@server/db"; +import { and, eq, inArray } from "drizzle-orm"; + +import { + addPeer as newtAddPeer, + deletePeer as newtDeletePeer +} from "@server/routers/newt/peers"; +import { + addPeer as olmAddPeer, + deletePeer as olmDeletePeer +} from "@server/routers/olm/peers"; +import { sendToExitNode } from "#dynamic/lib/exitNodes"; +import logger from "@server/logger"; + +export async function rebuildSiteClientAssociations( + siteResource: SiteResource, + trx: Transaction | typeof db = db +): Promise { + const siteId = siteResource.siteId; + + // get the site + const [site] = await trx + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site) { + throw new Error(`Site with ID ${siteId} not found`); + } + + const roleIds = await trx + .select() + .from(roleSiteResources) + .where( + eq(roleSiteResources.siteResourceId, siteResource.siteResourceId) + ) + .then((rows) => rows.map((row) => row.roleId)); + + const directUserIds = await trx + .select() + .from(userSiteResources) + .where( + eq(userSiteResources.siteResourceId, siteResource.siteResourceId) + ) + .then((rows) => rows.map((row) => row.userId)); + + // get all of the users in these roles + const userIdsFromRoles = await trx + .select({ + userId: users.userId + }) + .from(userOrgs) + .where(inArray(userOrgs.roleId, roleIds)) + .then((rows) => rows.map((row) => row.userId)); + + const allUserIds = Array.from( + new Set([...directUserIds, ...userIdsFromRoles]) + ); + + const allClients = await trx + .select({ + clientId: clients.clientId, + pubKey: clients.pubKey, + subnet: clients.subnet + }) + .from(clients) + .where(inArray(clients.userId, allUserIds)); + + const allClientIds = allClients.map((client) => client.clientId); + + const existingClientSiteIds = await trx + .select({ + clientId: clientSites.clientId + }) + .from(clientSites) + .where(eq(clientSites.siteId, siteId)) + .then((rows) => rows.map((row) => row.clientId)); + + const clientSitesToAdd = allClientIds.filter( + (clientId) => !existingClientSiteIds.includes(clientId) + ); + + const clientSitesToInsert = allClientIds + .filter((clientId) => !existingClientSiteIds.includes(clientId)) + .map((clientId) => ({ + clientId, + siteId + })); + + if (clientSitesToInsert.length > 0) { + await trx.insert(clientSites).values(clientSitesToInsert); + } + + // Now remove any client-site associations that should no longer exist + const clientSitesToRemove = existingClientSiteIds.filter( + (clientId) => !allClientIds.includes(clientId) + ); + + if (clientSitesToRemove.length > 0) { + await trx + .delete(clientSites) + .where( + and( + eq(clientSites.siteId, siteId), + inArray(clientSites.clientId, clientSitesToRemove) + ) + ); + } + + // Now handle the messages to add/remove peers on both the newt and olm sides + await handleMessagesForSiteClients( + site, + siteId, + allClients, + clientSitesToAdd, + clientSitesToRemove, + trx + ); +} + +async function handleMessagesForSiteClients( + site: Site, + siteId: number, + allClients: { + clientId: number; + pubKey: string | null; + subnet: string | null; + }[], + clientSitesToAdd: number[], + clientSitesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + if (!site.exitNodeId) { + logger.warn( + `Exit node ID not on site ${site.siteId} so there is no reason to update clients because it must be offline` + ); + return; + } + + // get the exit node for the site + const [exitNode] = await trx + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)) + .limit(1); + + if (!exitNode) { + logger.warn( + `Exit node not found for site ${site.siteId} so there is no reason to update clients because it must be offline` + ); + return; + } + + if (!site.publicKey) { + logger.warn( + `Site publicKey not set for site ${site.siteId} so cannot add peers to clients` + ); + return; + } + + const [newt] = await trx + .select({ + newtId: newts.newtId + }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + if (!newt) { + logger.warn( + `Newt not found for site ${siteId} so cannot add peers to clients` + ); + return; + } + + let newtJobs: Promise[] = []; + let olmJobs: Promise[] = []; + let exitNodeJobs: Promise[] = []; + for (const client of allClients) { + // UPDATE THE NEWT + if (!client.subnet || !client.pubKey) { + logger.debug("Client subnet, pubKey or endpoint is not set"); + continue; + } + + // is this an add or a delete? + const isAdd = clientSitesToAdd.includes(client.clientId); + const isDelete = clientSitesToRemove.includes(client.clientId); + + if (!isAdd && !isDelete) { + // nothing to do for this client + continue; + } + + const [olm] = await trx + .select({ + olmId: olms.olmId + }) + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + if (!olm) { + logger.warn( + `Olm not found for client ${client.clientId} so cannot add/delete peers` + ); + continue; + } + + if (isDelete) { + newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId)); + olmJobs.push( + olmDeletePeer( + client.clientId, + siteId, + site.publicKey, + olm.olmId + ) + ); + } + + if (isAdd) { + // TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES + // BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS + // AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES + const isRelayed = true; + + newtJobs.push( + newtAddPeer( + siteId, + { + publicKey: client.pubKey, + allowedIps: [`${client.subnet.split("/")[-1]}/32`], // we want to only allow from that client + // endpoint: isRelayed ? "" : clientSite.endpoint + endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint + }, + newt.newtId + ) + ); + + olmJobs.push( + olmAddPeer( + client.clientId, + { + siteId: site.siteId, + endpoint: + isRelayed || !site.endpoint + ? `${exitNode.endpoint}:21820` + : site.endpoint, + publicKey: site.publicKey, + serverIP: site.address, + serverPort: site.listenPort, + remoteSubnets: site.remoteSubnets + }, + olm.olmId + ) + ); + } + + exitNodeJobs.push(updateClientSiteDestinations(client, trx)); + } + + await Promise.all(exitNodeJobs); + await Promise.all(newtJobs); // do the servers first to make sure they are ready? + await Promise.all(olmJobs); +} + +interface PeerDestination { + destinationIP: string; + destinationPort: number; +} + +// this updates the relay destinations for a client to point to all of the new sites +export async function updateClientSiteDestinations( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + }, + trx: Transaction | typeof db = db +): Promise { + let exitNodeDestinations: { + reachableAt: string; + exitNodeId: number; + type: string; + name: string; + sourceIp: string; + sourcePort: number; + destinations: PeerDestination[]; + }[] = []; + + const sitesData = await trx + .select() + .from(sites) + .innerJoin(clientSites, eq(sites.siteId, clientSites.siteId)) + .leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId)) + .where(eq(clientSites.clientId, client.clientId)); + + for (const site of sitesData) { + if (!site.sites.subnet) { + logger.warn(`Site ${site.sites.siteId} has no subnet, skipping`); + continue; + } + + if (!site.clientSites.endpoint) { + logger.warn(`Site ${site.sites.siteId} has no endpoint, skipping`); + continue; + } + + // find the destinations in the array + let destinations = exitNodeDestinations.find( + (d) => d.reachableAt === site.exitNodes?.reachableAt + ); + + if (!destinations) { + destinations = { + reachableAt: site.exitNodes?.reachableAt || "", + exitNodeId: site.exitNodes?.exitNodeId || 0, + type: site.exitNodes?.type || "", + name: site.exitNodes?.name || "", + sourceIp: site.clientSites.endpoint.split(":")[0] || "", + sourcePort: + parseInt(site.clientSites.endpoint.split(":")[1]) || 0, + destinations: [ + { + destinationIP: site.sites.subnet.split("/")[0], + destinationPort: site.sites.listenPort || 0 + } + ] + }; + } else { + // add to the existing destinations + destinations.destinations.push({ + destinationIP: site.sites.subnet.split("/")[0], + destinationPort: site.sites.listenPort || 0 + }); + } + + // update it in the array + exitNodeDestinations = exitNodeDestinations.filter( + (d) => d.reachableAt !== site.exitNodes?.reachableAt + ); + exitNodeDestinations.push(destinations); + } + + for (const destination of exitNodeDestinations) { + logger.info( + `Updating destinations for exit node at ${destination.reachableAt}` + ); + const payload = { + sourceIp: destination.sourceIp, + sourcePort: destination.sourcePort, + destinations: destination.destinations + }; + logger.info( + `Payload for update-destinations: ${JSON.stringify(payload, null, 2)}` + ); + + // Create an ExitNode-like object for sendToExitNode + const exitNodeForComm = { + exitNodeId: destination.exitNodeId, + type: destination.type, + reachableAt: destination.reachableAt, + name: destination.name + } as any; // Using 'as any' since we know sendToExitNode will handle this correctly + + await sendToExitNode(exitNodeForComm, { + remoteType: "remoteExitNode/update-destinations", + localPath: "/update-destinations", + method: "POST", + data: payload + }); + } +} diff --git a/server/routers/client/updateClient.ts b/server/routers/client/updateClient.ts index 884a9864..f0eef459 100644 --- a/server/routers/client/updateClient.ts +++ b/server/routers/client/updateClient.ts @@ -9,15 +9,6 @@ import logger from "@server/logger"; import { eq, and } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; -import { - addPeer as newtAddPeer, - deletePeer as newtDeletePeer -} from "../newt/peers"; -import { - addPeer as olmAddPeer, - deletePeer as olmDeletePeer -} from "../olm/peers"; -import { sendToExitNode } from "#dynamic/lib/exitNodes"; const updateClientParamsSchema = z .object({ @@ -27,10 +18,7 @@ const updateClientParamsSchema = z const updateClientSchema = z .object({ - name: z.string().min(1).max(255).optional(), - siteIds: z - .array(z.number().int().positive()) - .optional() + name: z.string().min(1).max(255).optional() }) .strict(); @@ -54,11 +42,6 @@ registry.registerPath({ responses: {} }); -interface PeerDestination { - destinationIP: string; - destinationPort: number; -} - export async function updateClient( req: Request, res: Response, @@ -75,7 +58,7 @@ export async function updateClient( ); } - const { name, siteIds } = parsedBody.data; + const { name } = parsedBody.data; const parsedParams = updateClientParamsSchema.safeParse(req.params); if (!parsedParams.success) { @@ -105,266 +88,11 @@ export async function updateClient( ); } - let sitesAdded = []; - let sitesRemoved = []; - - // Fetch existing site associations - const existingSites = await db - .select({ siteId: clientSites.siteId }) - .from(clientSites) - .where(eq(clientSites.clientId, clientId)); - - const existingSiteIds = existingSites.map((site) => site.siteId); - - const siteIdsToProcess = siteIds || []; - // Determine which sites were added and removed - sitesAdded = siteIdsToProcess.filter( - (siteId) => !existingSiteIds.includes(siteId) - ); - sitesRemoved = existingSiteIds.filter( - (siteId) => !siteIdsToProcess.includes(siteId) - ); - - let updatedClient: Client | undefined = undefined; - let sitesData: any; // TODO: define type somehow from the query below - await db.transaction(async (trx) => { - // Update client name if provided - if (name) { - await trx - .update(clients) - .set({ name }) - .where(eq(clients.clientId, clientId)); - } - - // Update site associations if provided - // Remove sites that are no longer associated - for (const siteId of sitesRemoved) { - await trx - .delete(clientSites) - .where( - and( - eq(clientSites.clientId, clientId), - eq(clientSites.siteId, siteId) - ) - ); - } - - // Add new site associations - for (const siteId of sitesAdded) { - await trx.insert(clientSites).values({ - clientId, - siteId - }); - } - - // Fetch the updated client - [updatedClient] = await trx - .select() - .from(clients) - .where(eq(clients.clientId, clientId)) - .limit(1); - - // get all sites for this client and join with exit nodes with site.exitNodeId - sitesData = await trx - .select() - .from(sites) - .innerJoin(clientSites, eq(sites.siteId, clientSites.siteId)) - .leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId)) - .where(eq(clientSites.clientId, client.clientId)); - }); - - logger.info( - `Adding ${sitesAdded.length} new sites to client ${client.clientId}` - ); - for (const siteId of sitesAdded) { - if (!client.subnet || !client.pubKey) { - logger.debug("Client subnet, pubKey or endpoint is not set"); - continue; - } - - // TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES - // BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS - // AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES - const isRelayed = true; - - const site = await newtAddPeer(siteId, { - publicKey: client.pubKey, - allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client - // endpoint: isRelayed ? "" : clientSite.endpoint - endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint - }); - - if (!site) { - logger.debug("Failed to add peer to newt - missing site"); - continue; - } - - if (!site.endpoint || !site.publicKey) { - logger.debug("Site endpoint or publicKey is not set"); - continue; - } - - let endpoint; - - if (isRelayed) { - if (!site.exitNodeId) { - logger.warn( - `Site ${site.siteId} has no exit node, skipping` - ); - return null; - } - - // get the exit node for the site - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)) - .limit(1); - - if (!exitNode) { - logger.warn(`Exit node not found for site ${site.siteId}`); - return null; - } - - endpoint = `${exitNode.endpoint}:21820`; - } else { - if (!site.endpoint) { - logger.warn( - `Site ${site.siteId} has no endpoint, skipping` - ); - return null; - } - endpoint = site.endpoint; - } - - await olmAddPeer(client.clientId, { - siteId: site.siteId, - endpoint: endpoint, - publicKey: site.publicKey, - serverIP: site.address, - serverPort: site.listenPort, - remoteSubnets: site.remoteSubnets - }); - } - - logger.info( - `Removing ${sitesRemoved.length} sites from client ${client.clientId}` - ); - for (const siteId of sitesRemoved) { - if (!client.pubKey) { - logger.debug("Client pubKey is not set"); - continue; - } - const site = await newtDeletePeer(siteId, client.pubKey); - if (!site) { - logger.debug("Failed to delete peer from newt - missing site"); - continue; - } - if (!site.endpoint || !site.publicKey) { - logger.debug("Site endpoint or publicKey is not set"); - continue; - } - await olmDeletePeer(client.clientId, site.siteId, site.publicKey); - } - - if (!updatedClient || !sitesData) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - `Failed to update client` - ) - ); - } - - let exitNodeDestinations: { - reachableAt: string; - exitNodeId: number; - type: string; - name: string; - sourceIp: string; - sourcePort: number; - destinations: PeerDestination[]; - }[] = []; - - for (const site of sitesData) { - if (!site.sites.subnet) { - logger.warn( - `Site ${site.sites.siteId} has no subnet, skipping` - ); - continue; - } - - if (!site.clientSites.endpoint) { - logger.warn( - `Site ${site.sites.siteId} has no endpoint, skipping` - ); - continue; - } - - // find the destinations in the array - let destinations = exitNodeDestinations.find( - (d) => d.reachableAt === site.exitNodes?.reachableAt - ); - - if (!destinations) { - destinations = { - reachableAt: site.exitNodes?.reachableAt || "", - exitNodeId: site.exitNodes?.exitNodeId || 0, - type: site.exitNodes?.type || "", - name: site.exitNodes?.name || "", - sourceIp: site.clientSites.endpoint.split(":")[0] || "", - sourcePort: - parseInt(site.clientSites.endpoint.split(":")[1]) || 0, - destinations: [ - { - destinationIP: site.sites.subnet.split("/")[0], - destinationPort: site.sites.listenPort || 0 - } - ] - }; - } else { - // add to the existing destinations - destinations.destinations.push({ - destinationIP: site.sites.subnet.split("/")[0], - destinationPort: site.sites.listenPort || 0 - }); - } - - // update it in the array - exitNodeDestinations = exitNodeDestinations.filter( - (d) => d.reachableAt !== site.exitNodes?.reachableAt - ); - exitNodeDestinations.push(destinations); - } - - for (const destination of exitNodeDestinations) { - logger.info( - `Updating destinations for exit node at ${destination.reachableAt}` - ); - const payload = { - sourceIp: destination.sourceIp, - sourcePort: destination.sourcePort, - destinations: destination.destinations - }; - logger.info( - `Payload for update-destinations: ${JSON.stringify(payload, null, 2)}` - ); - - // Create an ExitNode-like object for sendToExitNode - const exitNodeForComm = { - exitNodeId: destination.exitNodeId, - type: destination.type, - reachableAt: destination.reachableAt, - name: destination.name - } as any; // Using 'as any' since we know sendToExitNode will handle this correctly - - await sendToExitNode(exitNodeForComm, { - remoteType: "remoteExitNode/update-destinations", - localPath: "/update-destinations", - method: "POST", - data: payload - }); - } + const updatedClient = await db + .update(clients) + .set({ name }) + .where(eq(clients.clientId, clientId)) + .returning(); return response(res, { data: updatedClient, diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts index 03dc3460..e0c1596b 100644 --- a/server/routers/newt/peers.ts +++ b/server/routers/newt/peers.ts @@ -1,4 +1,4 @@ -import { db } from "@server/db"; +import { db, Site } from "@server/db"; import { newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; import { sendToClient } from "#dynamic/routers/ws"; @@ -10,65 +10,74 @@ export async function addPeer( publicKey: string; allowedIps: string[]; endpoint: string; - } + }, + newtId?: string ) { - const [site] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); - if (!site) { - throw new Error(`Exit node with ID ${siteId} not found`); + let site: Site | null = null; + if (!newtId) { + [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + if (!site) { + throw new Error(`Exit node with ID ${siteId} not found`); + } + + // get the newt on the site + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + if (!newt) { + throw new Error(`Site found for site ${siteId}`); + } + newtId = newt.newtId; } - // get the newt on the site - const [newt] = await db - .select() - .from(newts) - .where(eq(newts.siteId, siteId)) - .limit(1); - if (!newt) { - throw new Error(`Site found for site ${siteId}`); - } - - sendToClient(newt.newtId, { + await sendToClient(newtId, { type: "newt/wg/peer/add", data: peer }); - logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`); + logger.info(`Added peer ${peer.publicKey} to newt ${newtId}`); return site; } -export async function deletePeer(siteId: number, publicKey: string) { - const [site] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); - if (!site) { - throw new Error(`Site with ID ${siteId} not found`); +export async function deletePeer(siteId: number, publicKey: string, newtId?: string) { + let site: Site | null = null; + if (!newtId) { + [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + if (!site) { + throw new Error(`Site with ID ${siteId} not found`); + } + + // get the newt on the site + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + if (!newt) { + throw new Error(`Newt not found for site ${siteId}`); + } + newtId = newt.newtId; } - // get the newt on the site - const [newt] = await db - .select() - .from(newts) - .where(eq(newts.siteId, siteId)) - .limit(1); - if (!newt) { - throw new Error(`Newt not found for site ${siteId}`); - } - - sendToClient(newt.newtId, { + await sendToClient(newtId, { type: "newt/wg/peer/remove", data: { publicKey } }); - logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`); + logger.info(`Deleted peer ${publicKey} from newt ${newtId}`); return site; } @@ -79,28 +88,33 @@ export async function updatePeer( peer: { allowedIps?: string[]; endpoint?: string; - } + }, + newtId?: string ) { - const [site] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); - if (!site) { - throw new Error(`Site with ID ${siteId} not found`); + let site: Site | null = null; + if (!newtId) { + [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + if (!site) { + throw new Error(`Site with ID ${siteId} not found`); + } + + // get the newt on the site + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + if (!newt) { + throw new Error(`Newt not found for site ${siteId}`); + } + newtId = newt.newtId; } - // get the newt on the site - const [newt] = await db - .select() - .from(newts) - .where(eq(newts.siteId, siteId)) - .limit(1); - if (!newt) { - throw new Error(`Newt not found for site ${siteId}`); - } - - sendToClient(newt.newtId, { + await sendToClient(newtId, { type: "newt/wg/peer/update", data: { publicKey, @@ -108,7 +122,7 @@ export async function updatePeer( } }); - logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`); + logger.info(`Updated peer ${publicKey} on newt ${newtId}`); return site; } diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index fdf3cbe1..643a1f44 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -388,6 +388,11 @@ async function getOrCreateOrgClient( clientId: newClient.clientId }); + await trx.insert(userClients).values({ // we also want to make sure that the user can see their own client if they are not an admin + userId, + clientId: newClient.clientId + }); + if (userOrg.roleId != adminRole.roleId) { // make sure the user can access the client trx.insert(userClients).values({ diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 396866a1..c712ea65 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -13,18 +13,22 @@ export async function addPeer( serverIP: string | null; serverPort: number | null; remoteSubnets: string | null; // optional, comma-separated list of subnets that this site can access - } + }, + olmId?: string ) { - const [olm] = await db - .select() - .from(olms) - .where(eq(olms.clientId, clientId)) - .limit(1); - if (!olm) { - throw new Error(`Olm with ID ${clientId} not found`); + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; } - await sendToClient(olm.olmId, { + await sendToClient(olmId, { type: "olm/wg/peer/add", data: { siteId: peer.siteId, @@ -36,20 +40,28 @@ export async function addPeer( } }); - logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`); + logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`); } -export async function deletePeer(clientId: number, siteId: number, publicKey: string) { - const [olm] = await db - .select() - .from(olms) - .where(eq(olms.clientId, clientId)) - .limit(1); - if (!olm) { - throw new Error(`Olm with ID ${clientId} not found`); +export async function deletePeer( + clientId: number, + siteId: number, + publicKey: string, + olmId?: string +) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; } - await sendToClient(olm.olmId, { + await sendToClient(olmId, { type: "olm/wg/peer/remove", data: { publicKey, @@ -57,7 +69,7 @@ export async function deletePeer(clientId: number, siteId: number, publicKey: st } }); - logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`); + logger.info(`Deleted peer ${publicKey} from olm ${olmId}`); } export async function updatePeer( @@ -69,18 +81,22 @@ export async function updatePeer( serverIP: string | null; serverPort: number | null; remoteSubnets?: string | null; // optional, comma-separated list of subnets that - } + }, + olmId?: string ) { - const [olm] = await db - .select() - .from(olms) - .where(eq(olms.clientId, clientId)) - .limit(1); - if (!olm) { - throw new Error(`Olm with ID ${clientId} not found`); + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; } - await sendToClient(olm.olmId, { + await sendToClient(olmId, { type: "olm/wg/peer/update", data: { siteId: peer.siteId, @@ -92,5 +108,5 @@ export async function updatePeer( } }); - logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`); + logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`); } diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index e7f8bd75..1200c38b 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -11,6 +11,7 @@ import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import { addTargets } from "../client/targets"; import { getUniqueSiteResourceName } from "@server/db/names"; +import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations"; const createSiteResourceParamsSchema = z .object({ @@ -29,7 +30,8 @@ const createSiteResourceSchema = z destination: z.string().min(1), enabled: z.boolean().default(true), alias: z.string().optional() - }).strict() + }) + .strict() .refine( (data) => { if (data.mode === "port") { @@ -145,61 +147,75 @@ export async function createSiteResource( const niceId = await getUniqueSiteResourceName(orgId); - // Create the site resource - const [newSiteResource] = await db - .insert(siteResources) - .values({ - siteId, - niceId, - orgId, - name, - mode, - protocol: mode === "port" ? protocol : null, - proxyPort: mode === "port" ? proxyPort : null, - destinationPort: mode === "port" ? destinationPort : null, - destination, - enabled, - alias: alias || null - }) - .returning(); + let newSiteResource: SiteResource | undefined; + await db.transaction(async (trx) => { + // Create the site resource + [newSiteResource] = await trx + .insert(siteResources) + .values({ + siteId, + niceId, + orgId, + name, + mode, + protocol: mode === "port" ? protocol : null, + proxyPort: mode === "port" ? proxyPort : null, + destinationPort: mode === "port" ? destinationPort : null, + destination, + enabled, + alias: alias || null + }) + .returning(); - const adminRole = await db - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); - - if (adminRole.length === 0) { - return next( - createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) - ); - } - - await db.insert(roleSiteResources).values({ - roleId: adminRole[0].roleId, - siteResourceId: newSiteResource.siteResourceId - }); - - // Only add targets for port mode - if (mode === "port" && protocol && proxyPort && destinationPort) { - const [newt] = await db + const [adminRole] = await trx .select() - .from(newts) - .where(eq(newts.siteId, site.siteId)) + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) .limit(1); - if (!newt) { + if (!adminRole) { return next( - createHttpError(HttpCode.NOT_FOUND, "Newt not found") + createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) ); } - await addTargets( - newt.newtId, - destination, - destinationPort, - protocol, - proxyPort + await trx.insert(roleSiteResources).values({ + roleId: adminRole.roleId, + siteResourceId: newSiteResource.siteResourceId + }); + + // Only add targets for port mode + if (mode === "port" && protocol && proxyPort && destinationPort) { + const [newt] = await trx + .select() + .from(newts) + .where(eq(newts.siteId, site.siteId)) + .limit(1); + + if (!newt) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Newt not found") + ); + } + + await addTargets( + newt.newtId, + destination, + destinationPort, + protocol, + proxyPort + ); + } + + await rebuildSiteClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role + }); + + if (!newSiteResource) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Site resource creation failed" + ) ); } diff --git a/server/routers/siteResource/deleteSiteResource.ts b/server/routers/siteResource/deleteSiteResource.ts index b43dcd27..bbd84233 100644 --- a/server/routers/siteResource/deleteSiteResource.ts +++ b/server/routers/siteResource/deleteSiteResource.ts @@ -10,10 +10,14 @@ import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import { removeTargets } from "../client/targets"; +import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations"; const deleteSiteResourceParamsSchema = z .object({ - siteResourceId: z.string().transform(Number).pipe(z.number().int().positive()), + siteResourceId: z + .string() + .transform(Number) + .pipe(z.number().int().positive()), siteId: z.string().transform(Number).pipe(z.number().int().positive()), orgId: z.string() }) @@ -40,7 +44,9 @@ export async function deleteSiteResource( next: NextFunction ): Promise { try { - const parsedParams = deleteSiteResourceParamsSchema.safeParse(req.params); + const parsedParams = deleteSiteResourceParamsSchema.safeParse( + req.params + ); if (!parsedParams.success) { return next( createHttpError( @@ -66,53 +72,61 @@ export async function deleteSiteResource( const [existingSiteResource] = await db .select() .from(siteResources) - .where(and( - eq(siteResources.siteResourceId, siteResourceId), - eq(siteResources.siteId, siteId), - eq(siteResources.orgId, orgId) - )) + .where(and(eq(siteResources.siteResourceId, siteResourceId))) .limit(1); if (!existingSiteResource) { return next( - createHttpError( - HttpCode.NOT_FOUND, - "Site resource not found" - ) + createHttpError(HttpCode.NOT_FOUND, "Site resource not found") ); } - // Delete the site resource - await db - .delete(siteResources) - .where(and( - eq(siteResources.siteResourceId, siteResourceId), - eq(siteResources.siteId, siteId), - eq(siteResources.orgId, orgId) - )); + await db.transaction(async (trx) => { + // Delete the site resource + await trx + .delete(siteResources) + .where( + and( + eq(siteResources.siteResourceId, siteResourceId), + eq(siteResources.siteId, siteId), + eq(siteResources.orgId, orgId) + ) + ); - // Only remove targets for port mode - if (existingSiteResource.mode === "port" && existingSiteResource.protocol && existingSiteResource.proxyPort && existingSiteResource.destinationPort) { - const [newt] = await db - .select() - .from(newts) - .where(eq(newts.siteId, site.siteId)) - .limit(1); + // Only remove targets for port mode + if ( + existingSiteResource.mode === "port" && + existingSiteResource.protocol && + existingSiteResource.proxyPort && + existingSiteResource.destinationPort + ) { + const [newt] = await trx + .select() + .from(newts) + .where(eq(newts.siteId, site.siteId)) + .limit(1); - if (!newt) { - return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found")); + if (!newt) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Newt not found") + ); + } + + await removeTargets( + newt.newtId, + existingSiteResource.destination, + existingSiteResource.destinationPort, + existingSiteResource.protocol, + existingSiteResource.proxyPort + ); } - await removeTargets( - newt.newtId, - existingSiteResource.destination, - existingSiteResource.destinationPort, - existingSiteResource.protocol, - existingSiteResource.proxyPort - ); - } + await rebuildSiteClientAssociations(existingSiteResource, trx); + }); - logger.info(`Deleted site resource ${siteResourceId} for site ${siteId}`); + logger.info( + `Deleted site resource ${siteResourceId} for site ${siteId}` + ); return response(res, { data: { message: "Site resource deleted successfully" }, @@ -123,6 +137,11 @@ export async function deleteSiteResource( }); } catch (error) { logger.error("Error deleting site resource:", error); - return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to delete site resource")); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to delete site resource" + ) + ); } } diff --git a/server/routers/siteResource/setSiteResourceRoles.ts b/server/routers/siteResource/setSiteResourceRoles.ts index ba312134..3be0ee11 100644 --- a/server/routers/siteResource/setSiteResourceRoles.ts +++ b/server/routers/siteResource/setSiteResourceRoles.ts @@ -9,6 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and, ne } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations"; const setSiteResourceRolesBodySchema = z .object({ @@ -62,7 +63,9 @@ export async function setSiteResourceRoles( const { roleIds } = parsedBody.data; - const parsedParams = setSiteResourceRolesParamsSchema.safeParse(req.params); + const parsedParams = setSiteResourceRolesParamsSchema.safeParse( + req.params + ); if (!parsedParams.success) { return next( createHttpError( @@ -136,6 +139,8 @@ export async function setSiteResourceRoles( .returning() ) ); + + await rebuildSiteClientAssociations(siteResource, trx); }); return response(res, { @@ -152,4 +157,3 @@ export async function setSiteResourceRoles( ); } } - diff --git a/server/routers/siteResource/setSiteResourceUsers.ts b/server/routers/siteResource/setSiteResourceUsers.ts index f100f6d6..ea913732 100644 --- a/server/routers/siteResource/setSiteResourceUsers.ts +++ b/server/routers/siteResource/setSiteResourceUsers.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, siteResources } from "@server/db"; import { userSiteResources } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; @@ -9,6 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations"; const setSiteResourceUsersBodySchema = z .object({ @@ -74,6 +75,22 @@ export async function setSiteResourceUsers( const { siteResourceId } = parsedParams.data; + // get the site resource + const [siteResource] = await db + .select() + .from(siteResources) + .where(eq(siteResources.siteResourceId, siteResourceId)) + .limit(1); + + if (!siteResource) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Site resource not found" + ) + ); + } + await db.transaction(async (trx) => { await trx .delete(userSiteResources) @@ -87,6 +104,8 @@ export async function setSiteResourceUsers( .returning() ) ); + + await rebuildSiteClientAssociations(siteResource, trx); }); return response(res, {