diff --git a/server/lib/calculateUserClientsForOrgs.ts b/server/lib/calculateUserClientsForOrgs.ts new file mode 100644 index 00000000..f8578f0f --- /dev/null +++ b/server/lib/calculateUserClientsForOrgs.ts @@ -0,0 +1,258 @@ +import { clients, clientSites, db, olms, orgs, roleClients, roles, userClients, userOrgs, Transaction } from "@server/db"; +import { eq, and, notInArray } from "drizzle-orm"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; +import { getNextAvailableClientSubnet } from "@server/lib/ip"; +import logger from "@server/logger"; + +export async function calculateUserClientsForOrgs( + userId: string, + trx?: Transaction +): Promise { + const execute = async (transaction: Transaction) => { + // Get all OLMs for this user + const userOlms = await transaction + .select() + .from(olms) + .where(eq(olms.userId, userId)); + + if (userOlms.length === 0) { + // No OLMs for this user, but we should still clean up any orphaned clients + await cleanupOrphanedClients(userId, transaction); + return; + } + + // Get all user orgs + const allUserOrgs = await transaction + .select() + .from(userOrgs) + .where(eq(userOrgs.userId, userId)); + + const userOrgIds = allUserOrgs.map((uo) => uo.orgId); + + // For each OLM, ensure there's a client in each org the user is in + for (const olm of userOlms) { + for (const userOrg of allUserOrgs) { + const orgId = userOrg.orgId; + + const [org] = await transaction + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)); + + if (!org) { + logger.warn( + `Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): org not found` + ); + continue; + } + + if (!org.subnet) { + logger.warn( + `Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): org has no subnet configured` + ); + continue; + } + + // Get admin role for this org (needed for access grants) + const [adminRole] = await transaction + .select() + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) + .limit(1); + + if (!adminRole) { + logger.warn( + `Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no admin role found` + ); + continue; + } + + // Check if a client already exists for this OLM+user+org combination + const [existingClient] = await transaction + .select() + .from(clients) + .where( + and( + eq(clients.userId, userId), + eq(clients.orgId, orgId), + eq(clients.olmId, olm.olmId) + ) + ) + .limit(1); + + if (existingClient) { + // Ensure admin role has access to the client + const [existingRoleClient] = await transaction + .select() + .from(roleClients) + .where( + and( + eq(roleClients.roleId, adminRole.roleId), + eq(roleClients.clientId, existingClient.clientId) + ) + ) + .limit(1); + + if (!existingRoleClient) { + await transaction.insert(roleClients).values({ + roleId: adminRole.roleId, + clientId: existingClient.clientId + }); + logger.debug( + `Granted admin role access to existing client ${existingClient.clientId} for OLM ${olm.olmId} in org ${orgId} (user ${userId})` + ); + } + + // Ensure user has access to the client + const [existingUserClient] = await transaction + .select() + .from(userClients) + .where( + and( + eq(userClients.userId, userId), + eq(userClients.clientId, existingClient.clientId) + ) + ) + .limit(1); + + if (!existingUserClient) { + await transaction.insert(userClients).values({ + userId, + clientId: existingClient.clientId + }); + logger.debug( + `Granted user access to existing client ${existingClient.clientId} for OLM ${olm.olmId} in org ${orgId} (user ${userId})` + ); + } + + logger.debug( + `Client already exists for OLM ${olm.olmId} in org ${orgId} (user ${userId}), skipping creation` + ); + continue; + } + + // Get exit nodes for this org + const exitNodesList = await listExitNodes(orgId); + + if (exitNodesList.length === 0) { + logger.warn( + `Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no exit nodes found` + ); + continue; + } + + const randomExitNode = + exitNodesList[ + Math.floor(Math.random() * exitNodesList.length) + ]; + + // Get next available subnet + const newSubnet = await getNextAvailableClientSubnet(orgId); + if (!newSubnet) { + logger.warn( + `Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no available subnet found` + ); + continue; + } + + const subnet = newSubnet.split("/")[0]; + const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; + + // Create the client + const [newClient] = await transaction + .insert(clients) + .values({ + userId, + orgId: userOrg.orgId, + exitNodeId: randomExitNode.exitNodeId, + name: olm.name || "User Client", + subnet: updatedSubnet, + olmId: olm.olmId, + type: "olm" + }) + .returning(); + + // Grant admin role access to the client + await transaction.insert(roleClients).values({ + roleId: adminRole.roleId, + clientId: newClient.clientId + }); + + // Grant user access to the client + await transaction.insert(userClients).values({ + userId, + clientId: newClient.clientId + }); + + logger.debug( + `Created client for OLM ${olm.olmId} in org ${orgId} (user ${userId}) with access granted to admin role and user` + ); + } + } + + // Clean up clients in orgs the user is no longer in + await cleanupOrphanedClients(userId, transaction, userOrgIds); + }; + + if (trx) { + // Use provided transaction + await execute(trx); + } else { + // Create new transaction + await db.transaction(async (transaction) => { + await execute(transaction); + }); + } +} + +async function cleanupOrphanedClients( + userId: string, + trx: Transaction, + userOrgIds: string[] = [] +): Promise { + // Find all OLM clients for this user that should be deleted + // If userOrgIds is empty, delete all OLM clients (user has no orgs) + // If userOrgIds has values, delete clients in orgs they're not in + const clientsToDelete = await trx + .select({ clientId: clients.clientId }) + .from(clients) + .where( + userOrgIds.length > 0 + ? and( + eq(clients.userId, userId), + notInArray(clients.orgId, userOrgIds) + ) + : and(eq(clients.userId, userId)) + ); + + // Delete client-site associations first, then delete the clients + for (const client of clientsToDelete) { + await trx + .delete(clientSites) + .where(eq(clientSites.clientId, client.clientId)); + } + + if (clientsToDelete.length > 0) { + await trx + .delete(clients) + .where( + userOrgIds.length > 0 + ? and( + eq(clients.userId, userId), + notInArray(clients.orgId, userOrgIds) + ) + : and(eq(clients.userId, userId)) + ); + + if (userOrgIds.length === 0) { + logger.debug( + `Deleted all ${clientsToDelete.length} OLM client(s) for user ${userId} (user has no orgs)` + ); + } else { + logger.debug( + `Deleted ${clientsToDelete.length} orphaned OLM client(s) for user ${userId} in orgs they're no longer in` + ); + } + } +} + diff --git a/server/routers/idp/validateOidcCallback.ts b/server/routers/idp/validateOidcCallback.ts index 7d1da1c5..114f3422 100644 --- a/server/routers/idp/validateOidcCallback.ts +++ b/server/routers/idp/validateOidcCallback.ts @@ -33,6 +33,7 @@ import { UserType } from "@server/types/UserTypes"; import { FeatureId } from "@server/lib/billing"; import { usageService } from "@server/lib/billing/usageService"; import { build } from "@server/build"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const ensureTrailingSlash = (url: string): string => { return url; @@ -364,10 +365,18 @@ export async function validateOidcCallback( ); if (!existingUserOrgs.length) { - // delete the user - // await db - // .delete(users) - // .where(eq(users.userId, existingUser.userId)); + // delete all auto -provisioned user orgs + await db + .delete(userOrgs) + .where( + and( + eq(userOrgs.userId, existingUser.userId), + eq(userOrgs.autoProvisioned, true) + ) + ); + + await calculateUserClientsForOrgs(existingUser.userId); + return next( createHttpError( HttpCode.UNAUTHORIZED, @@ -513,6 +522,8 @@ export async function validateOidcCallback( userCount: userCount.length }); } + + await calculateUserClientsForOrgs(userId!, trx); }); for (const orgCount of orgUserCounts) { @@ -553,6 +564,24 @@ export async function validateOidcCallback( ); } + // check for existing user orgs + const existingUserOrgs = await db + .select() + .from(userOrgs) + .where(and(eq(userOrgs.userId, existingUser.userId))); + + if (!existingUserOrgs.length) { + logger.debug( + "No existing user orgs found for non-auto-provisioned IdP" + ); + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + `User with username ${userIdentifier} is unprovisioned. This user must be added to an organization before logging in.` + ) + ); + } + const token = generateSessionToken(); const sess = await createSession(token, existingUser.userId); const isSecure = req.protocol === "https"; diff --git a/server/routers/olm/createUserOlm.ts b/server/routers/olm/createUserOlm.ts index 9b64c1a9..e9bcfa8a 100644 --- a/server/routers/olm/createUserOlm.ts +++ b/server/routers/olm/createUserOlm.ts @@ -1,8 +1,7 @@ import { NextFunction, Request, Response } from "express"; -import { db } from "@server/db"; +import { db, olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import { z } from "zod"; -import { olms } from "@server/db"; import createHttpError from "http-errors"; import response from "@server/lib/response"; import moment from "moment"; @@ -10,6 +9,7 @@ import { generateId } from "@server/auth/sessions/app"; import { fromError } from "zod-validation-error"; import { hashPassword } from "@server/auth/password"; import { OpenAPITags, registry } from "@server/openApi"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const bodySchema = z .object({ @@ -81,12 +81,16 @@ export async function createUserOlm( const secretHash = await hashPassword(secret); - await db.insert(olms).values({ - olmId: olmId, - userId, - name, - secretHash, - dateCreated: moment().toISOString() + await db.transaction(async (trx) => { + await trx.insert(olms).values({ + olmId: olmId, + userId, + name, + secretHash, + dateCreated: moment().toISOString() + }); + + await calculateUserClientsForOrgs(userId, trx); }); return response(res, { diff --git a/server/routers/user/acceptInvite.ts b/server/routers/user/acceptInvite.ts index 5e4264f9..7e64770f 100644 --- a/server/routers/user/acceptInvite.ts +++ b/server/routers/user/acceptInvite.ts @@ -12,6 +12,7 @@ import { checkValidInvite } from "@server/auth/checkValidInvite"; import { verifySession } from "@server/auth/sessions/verifySession"; import { usageService } from "@server/lib/billing/usageService"; import { FeatureId } from "@server/lib/billing"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const acceptInviteBodySchema = z .object({ @@ -131,6 +132,8 @@ export async function acceptInvite( .select() .from(userOrgs) .where(eq(userOrgs.orgId, existingInvite.orgId)); + + await calculateUserClientsForOrgs(existingUser[0].userId, trx); }); if (totalUsers) { diff --git a/server/routers/user/createOrgUser.ts b/server/routers/user/createOrgUser.ts index 29f94641..1e88add5 100644 --- a/server/routers/user/createOrgUser.ts +++ b/server/routers/user/createOrgUser.ts @@ -15,6 +15,7 @@ import { FeatureId } from "@server/lib/billing"; import { build } from "@server/build"; import { getOrgTierData } from "#dynamic/lib/billing"; import { TierId } from "@server/lib/billing/tiers"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const paramsSchema = z .object({ @@ -89,14 +90,7 @@ export async function createOrgUser( } const { orgId } = parsedParams.data; - const { - username, - email, - name, - type, - idpId, - roleId - } = parsedBody.data; + const { username, email, name, type, idpId, roleId } = parsedBody.data; if (build == "saas") { const usage = await usageService.getUsage(orgId, FeatureId.USERS); @@ -202,7 +196,9 @@ export async function createOrgUser( ) ); + let userId: string | undefined; if (existingUser) { + userId = existingUser.userId; const [existingOrgUser] = await trx .select() .from(userOrgs) @@ -232,7 +228,7 @@ export async function createOrgUser( }) .returning(); } else { - const userId = generateId(15); + userId = generateId(15); const [newUser] = await trx .insert(users) @@ -244,7 +240,7 @@ export async function createOrgUser( type: "oidc", idpId, dateCreated: new Date().toISOString(), - emailVerified: true, + emailVerified: true }) .returning(); @@ -264,6 +260,8 @@ export async function createOrgUser( .select() .from(userOrgs) .where(eq(userOrgs.orgId, orgId)); + + await calculateUserClientsForOrgs(userId, trx); }); if (orgUsers) { diff --git a/server/routers/user/removeUserOrg.ts b/server/routers/user/removeUserOrg.ts index babccdd0..8bad16d9 100644 --- a/server/routers/user/removeUserOrg.ts +++ b/server/routers/user/removeUserOrg.ts @@ -13,6 +13,7 @@ import { usageService } from "@server/lib/billing/usageService"; import { FeatureId } from "@server/lib/billing"; import { build } from "@server/build"; import { UserType } from "@server/types/UserTypes"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const removeUserSchema = z .object({ @@ -120,22 +121,24 @@ export async function removeUserOrg( .from(userOrgs) .where(eq(userOrgs.orgId, orgId)); - if (build === "saas") { - const [rootUser] = await trx - .select() - .from(users) - .where(eq(users.userId, userId)); + // if (build === "saas") { + // const [rootUser] = await trx + // .select() + // .from(users) + // .where(eq(users.userId, userId)); + // + // const [leftInOrgs] = await trx + // .select({ count: count() }) + // .from(userOrgs) + // .where(eq(userOrgs.userId, userId)); + // + // // if the user is not an internal user and does not belong to any org, delete the entire user + // if (rootUser?.type !== UserType.Internal && !leftInOrgs.count) { + // await trx.delete(users).where(eq(users.userId, userId)); + // } + // } - const [leftInOrgs] = await trx - .select({ count: count() }) - .from(userOrgs) - .where(eq(userOrgs.userId, userId)); - - // if the user is not an internal user and does not belong to any org, delete the entire user - if (rootUser?.type !== UserType.Internal && !leftInOrgs.count) { - await trx.delete(users).where(eq(users.userId, userId)); - } - } + await calculateUserClientsForOrgs(userId, trx); }); if (userCount) {