From de83cf9d8ca11fb92a5d2f563616d53281c07603 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 26 Nov 2025 15:35:33 -0500 Subject: [PATCH] Handle delete org and checking org policy --- server/routers/olm/handleOlmPingMessage.ts | 35 ++++++ .../routers/olm/handleOlmRegisterMessage.ts | 33 ++++- server/routers/org/deleteOrg.ts | 117 +++++++++++++----- server/routers/user/addUserRole.ts | 4 +- 4 files changed, 155 insertions(+), 34 deletions(-) diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index ab503d4c..ee9443f5 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -5,6 +5,8 @@ import { clients, Olm } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; import { validateSessionToken } from "@server/auth/sessions/app"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { sendTerminateClient } from "../client/terminate"; // Track if the offline checker interval is running let offlineCheckerInterval: NodeJS.Timeout | null = null; @@ -57,6 +59,9 @@ export const startOlmOfflineChecker = (): void => { // Send a disconnect message to the client if connected try { + await sendTerminateClient(offlineClient.clientId); // terminate first + // wait a moment to ensure the message is sent + await new Promise(resolve => setTimeout(resolve, 1000)); await disconnectClient(offlineClient.olmId); } catch (error) { logger.error( @@ -110,6 +115,36 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { logger.warn("User ID mismatch for olm ping"); return; } + + // get the client + const [client] = await db + .select() + .from(clients) + .where( + and( + eq(clients.olmId, olm.olmId), + eq(clients.userId, olm.userId) + ) + ) + .limit(1); + + if (!client) { + logger.warn("Client not found for olm ping"); + return; + } + + const policyCheck = await checkOrgAccessPolicy({ + orgId: client.orgId, + userId: olm.userId, + session: userToken // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}` + ); + return; + } } if (!olm.clientId) { diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 2ee5c120..53cd9815 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -32,6 +32,8 @@ import { } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { validateSessionToken } from "@server/auth/sessions/app"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); @@ -45,7 +47,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient } = + const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } = message.data; let client: Client | undefined; @@ -78,6 +80,35 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } + if (!olm.userId) { + logger.warn("Olm has no user ID"); + return; + } + + const { session: userSession, user } = + await validateSessionToken(userToken); + if (!userSession || !user) { + logger.warn("Invalid user session for olm ping"); + return; // by returning here we just ignore the ping and the setInterval will force it to disconnect + } + if (user.userId !== olm.userId) { + logger.warn("User ID mismatch for olm ping"); + return; + } + + const policyCheck = await checkOrgAccessPolicy({ + orgId: orgId, + userId: olm.userId, + session: userToken // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}` + ); + return; + } + logger.debug( `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` ); diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index 0e21a8c0..098c5c41 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -1,6 +1,15 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, domains, orgDomains, resources } from "@server/db"; +import { + clients, + clientSiteResourcesAssociationsCache, + clientSitesAssociationsCache, + db, + domains, + olms, + orgDomains, + resources +} from "@server/db"; import { newts, newtSessions, orgs, sites, userActions } from "@server/db"; import { eq, and, inArray, sql } from "drizzle-orm"; import response from "@server/lib/response"; @@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers"; import { OpenAPITags, registry } from "@server/openApi"; const deleteOrgSchema = z.strictObject({ - orgId: z.string() - }); + orgId: z.string() +}); export type DeleteOrgResponse = {}; @@ -69,41 +78,75 @@ export async function deleteOrg( .where(eq(sites.orgId, orgId)) .limit(1); + const orgClients = await db + .select() + .from(clients) + .where(eq(clients.orgId, orgId)); + const deletedNewtIds: string[] = []; + const olmsToTerminate: string[] = []; await db.transaction(async (trx) => { - if (sites) { - for (const site of orgSites) { - if (site.pubKey) { - if (site.type == "wireguard") { - await deletePeer(site.exitNodeId!, site.pubKey); - } else if (site.type == "newt") { - // get the newt on the site by querying the newt table for siteId - const [deletedNewt] = await trx - .delete(newts) - .where(eq(newts.siteId, site.siteId)) - .returning(); - if (deletedNewt) { - deletedNewtIds.push(deletedNewt.newtId); + for (const site of orgSites) { + if (site.pubKey) { + if (site.type == "wireguard") { + await deletePeer(site.exitNodeId!, site.pubKey); + } else if (site.type == "newt") { + // get the newt on the site by querying the newt table for siteId + const [deletedNewt] = await trx + .delete(newts) + .where(eq(newts.siteId, site.siteId)) + .returning(); + if (deletedNewt) { + deletedNewtIds.push(deletedNewt.newtId); - // delete all of the sessions for the newt - await trx - .delete(newtSessions) - .where( - eq( - newtSessions.newtId, - deletedNewt.newtId - ) - ); - } + // delete all of the sessions for the newt + await trx + .delete(newtSessions) + .where( + eq(newtSessions.newtId, deletedNewt.newtId) + ); } } - - logger.info(`Deleting site ${site.siteId}`); - await trx - .delete(sites) - .where(eq(sites.siteId, site.siteId)); } + + logger.info(`Deleting site ${site.siteId}`); + await trx.delete(sites).where(eq(sites.siteId, site.siteId)); + } + for (const client of orgClients) { + const [olm] = await trx + .select() + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + + if (olm) { + olmsToTerminate.push(olm.olmId); + } + + logger.info(`Deleting client ${client.clientId}`); + await trx + .delete(clients) + .where(eq(clients.clientId, client.clientId)); + + // also delete the associations + await trx + .delete(clientSiteResourcesAssociationsCache) + .where( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ); + + await trx + .delete(clientSitesAssociationsCache) + .where( + eq( + clientSitesAssociationsCache.clientId, + client.clientId + ) + ); } const allOrgDomains = await trx @@ -162,6 +205,18 @@ export async function deleteOrg( }); } + for (const olmId of olmsToTerminate) { + sendToClient(olmId, { + type: "olm/terminate", + data: {} + }).catch((error) => { + logger.error( + "Failed to send termination message to olm:", + error + ); + }); + } + return response(res, { data: null, success: true, diff --git a/server/routers/user/addUserRole.ts b/server/routers/user/addUserRole.ts index 9404d94f..32eaa19d 100644 --- a/server/routers/user/addUserRole.ts +++ b/server/routers/user/addUserRole.ts @@ -125,7 +125,7 @@ export async function addUserRole( .returning(); // get the client associated with this user in this org - const [orgClient] = await trx + const orgClients = await trx .select() .from(clients) .where( @@ -136,7 +136,7 @@ export async function addUserRole( ) .limit(1); - if (orgClient) { + for (const orgClient of orgClients) { // we just changed the user's role, so we need to rebuild client associations and what they have access to await rebuildClientAssociationsFromClient(orgClient, trx); }