From b5e94d44ae1bb6fbb3746769ed530409009dff9b Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 15:44:25 -0500 Subject: [PATCH] Fix switching orgs having connections from other orgs --- server/lib/rebuildClientAssociations.ts | 24 +- server/routers/olm/getOlmToken.ts | 48 +++- .../routers/olm/handleOlmRegisterMessage.ts | 235 +++--------------- 3 files changed, 95 insertions(+), 212 deletions(-) diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 4be2c92a..810acdef 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -465,7 +465,8 @@ async function handleMessagesForSiteClients( } if (isAdd) { - await holepunchSiteAdd( // this will kick off the add peer process for the client + await holepunchSiteAdd( + // this will kick off the add peer process for the client client.clientId, { siteId, @@ -728,7 +729,19 @@ export async function rebuildClientAssociationsFromClient( const userSiteResourceIds = await trx .select({ siteResourceId: userSiteResources.siteResourceId }) .from(userSiteResources) - .where(eq(userSiteResources.userId, client.userId)); + .innerJoin( + siteResources, + eq( + siteResources.siteResourceId, + userSiteResources.siteResourceId + ) + ) + .where( + and( + eq(userSiteResources.userId, client.userId), + eq(siteResources.orgId, client.orgId) + ) + ); // this needs to be locked onto this org or else cross-org access could happen newSiteResourceIds.push( ...userSiteResourceIds.map((r) => r.siteResourceId) @@ -738,7 +751,12 @@ export async function rebuildClientAssociationsFromClient( const roleIds = await trx .select({ roleId: userOrgs.roleId }) .from(userOrgs) - .where(eq(userOrgs.userId, client.userId)) + .where( + and( + eq(userOrgs.userId, client.userId), + eq(userOrgs.orgId, client.orgId) + ) + ) // this needs to be locked onto this org or else cross-org access could happen .then((rows) => rows.map((row) => row.roleId)); if (roleIds.length > 0) { diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index 33f5fa2d..cea8386c 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -1,5 +1,12 @@ import { generateSessionToken } from "@server/auth/sessions/app"; -import { clients, db, ExitNode, exitNodes, sites, clientSitesAssociationsCache } from "@server/db"; +import { + clients, + db, + ExitNode, + exitNodes, + sites, + clientSitesAssociationsCache +} from "@server/db"; import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; @@ -99,6 +106,7 @@ export async function getOlmToken( await createOlmSession(resToken, existingOlm.olmId); let orgIdToUse = orgId; + let clientIdToUse; if (!orgIdToUse) { if (!existingOlm.clientId) { return next( @@ -114,7 +122,7 @@ export async function getOlmToken( .from(clients) .where(eq(clients.clientId, existingOlm.clientId)) .limit(1); - + if (!client) { return next( createHttpError( @@ -125,6 +133,40 @@ export async function getOlmToken( } orgIdToUse = client.orgId; + clientIdToUse = client.clientId; + } else { + // we did provide the org + const [client] = await db + .select() + .from(clients) + .where(eq(clients.orgId, orgIdToUse)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No client found for provided orgId" + ) + ); + } + + if (existingOlm.clientId !== client.clientId) { + // we only need to do this if the client is changing + + logger.debug( + `Switching olm client ${existingOlm.olmId} to org ${orgId} for user ${existingOlm.userId}` + ); + + await db + .update(olms) + .set({ + clientId: client.clientId + }) + .where(eq(olms.olmId, existingOlm.olmId)); + } + + clientIdToUse = client.clientId; } // Get all exit nodes from sites where the client has peers @@ -135,7 +177,7 @@ export async function getOlmToken( sites, eq(sites.siteId, clientSitesAssociationsCache.siteId) ) - .where(eq(clientSitesAssociationsCache.clientId, existingOlm.clientId!)); + .where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!)); // Extract unique exit node IDs const exitNodeIds = Array.from( diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index ce40e832..6f3e59c1 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -48,45 +48,36 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - const { - publicKey, - relay, - olmVersion, - orgId, - doNotCreateNewClient, - token: userToken - } = message.data; + const { publicKey, relay, olmVersion, orgId, userToken } = message.data; - let client: Client | undefined; - let org: Org | undefined; + if (!olm.clientId) { + logger.warn("Olm client ID not found"); + return; + } + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, olm.clientId)) + .limit(1); + + if (!client) { + logger.warn("Client ID not found"); + return; + } + + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, client.orgId)) + .limit(1); + + if (!org) { + logger.warn("Org not found"); + return; + } if (orgId) { - try { - const { client: clientRes, org: orgRes } = - await getOrCreateOrgClient( - orgId, - olm.userId, - olm.olmId, - olm.name || "User Device", - // doNotCreateNewClient ? true : false - true // for now never create a new client automatically because we create the users clients when they are added to the org - // this means that the rebuildClientAssociationsFromClient call below issue is not a problem - ); - - client = clientRes; - org = orgRes; - } catch (err) { - logger.error( - `Error switching olm client ${olm.olmId} to org ${orgId}: ${err}` - ); - return; - } - - if (!client) { - logger.warn("Client not found"); - return; - } - if (!olm.userId) { logger.warn("Olm has no user ID"); return; @@ -95,11 +86,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { const { session: userSession, user } = await validateSessionToken(userToken); if (!userSession || !user) { - logger.warn("Invalid user session for olm ping"); + logger.warn("Invalid user session for olm register"); return; // by returning here we just ignore the ping and the setInterval will force it to disconnect } if (user.userId !== olm.userId) { - logger.warn("User ID mismatch for olm ping"); + logger.warn("User ID mismatch for olm register"); return; } @@ -115,48 +106,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ); return; } - - logger.debug( - `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` - ); - - if (olm.clientId !== client.clientId) { // we only need to do this if the client is changing - await db - .update(olms) - .set({ - clientId: client.clientId - }) - .where(eq(olms.olmId, olm.olmId)); - } - } else { - if (!olm.clientId) { - logger.warn("Olm has no client ID!"); - return; - } - - logger.debug(`Using last connected org for client ${olm.clientId}`); - - [client] = await db - .select() - .from(clients) - .where(eq(clients.clientId, olm.clientId)) - .limit(1); - - [org] = await db - .select() - .from(orgs) - .where(eq(orgs.orgId, client.orgId)) - .limit(1); - } - - if (!client) { - logger.warn("Client ID not found"); - return; - } - - if (!org) { - logger.warn("Org not found"); - return; } logger.debug( @@ -357,129 +306,3 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { excludeSender: false }; }; - -async function getOrCreateOrgClient( - orgId: string, - userId: string | null, - olmId: string, - name: string, - doNotCreateNewClient: boolean, - trx: Transaction | typeof db = db -): Promise<{ - client: Client; - org: Org; -}> { - // get the org - const [org] = await trx - .select() - .from(orgs) - .where(eq(orgs.orgId, orgId)) - .limit(1); - - if (!org) { - throw new Error("Org not found"); - } - - if (!org.subnet) { - throw new Error("Org has no subnet defined"); - } - - // check if the user has a client in the org and if not then create a client for them - const [existingClient] = await trx - .select() - .from(clients) - .where( - and( - eq(clients.orgId, orgId), - userId ? eq(clients.userId, userId) : isNull(clients.userId), // we dont check the user id if it is null because the olm is not tied to a user? - eq(clients.olmId, olmId) - ) - ) // checking the olmid here because we want to create a new client PER OLM PER ORG - .limit(1); - - let client = existingClient; - if (!client && !doNotCreateNewClient) { - logger.debug( - `Client does not exist in org ${orgId}, creating new client for user ${userId}` - ); - - if (!userId) { - throw new Error("User ID is required to create client in org"); - } - - // Verify that the user belongs to the org - const [userOrg] = await trx - .select() - .from(userOrgs) - .where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId))) - .limit(1); - - if (!userOrg) { - throw new Error("User does not belong to org"); - } - - // TODO: more intelligent way to pick the exit node - const exitNodesList = await listExitNodes(orgId); - const randomExitNode = - exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; - - const [adminRole] = await trx - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); - - if (!adminRole) { - throw new Error("Admin role not found"); - } - - const newSubnet = await getNextAvailableClientSubnet(orgId); - if (!newSubnet) { - throw new Error("No available subnet found"); - } - - const subnet = newSubnet.split("/")[0]; - const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org - - const [newClient] = await trx - .insert(clients) - .values({ - exitNodeId: randomExitNode.exitNodeId, - orgId, - name, - subnet: updatedSubnet, - type: "olm", - userId: userId, - olmId: olmId // to lock this client to the olm even as the olm moves between clients in different orgs - }) - .returning(); - - await trx.insert(roleClients).values({ - roleId: adminRole.roleId, - clientId: newClient.clientId - }); - - await trx.insert(userClients).values({ - // we also want to make sure that the user can see their own client if they are not an admin - userId, - clientId: newClient.clientId - }); - - if (userOrg.roleId != adminRole.roleId) { - // make sure the user can access the client - trx.insert(userClients).values({ - userId, - clientId: newClient.clientId - }); - } - - await rebuildClientAssociationsFromClient(newClient, trx); // TODO: this will try to messages to the olm which has not connected yet - is that a problem? - - client = newClient; - } - - return { - client: client, - org: org - }; -}