From d30743a428edb6bb57cf9eeaf5bef67657d5edfd Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 3 Nov 2025 15:42:22 -0800 Subject: [PATCH] Update schmea; create client when registering --- server/db/pg/schema/schema.ts | 9 + server/db/sqlite/schema/schema.ts | 18 +- server/private/lib/exitNodes/exitNodes.ts | 2 +- server/routers/client/createClient.ts | 11 +- .../routers/olm/handleOlmRegisterMessage.ts | 193 ++++++++++++++++-- 5 files changed, 200 insertions(+), 33 deletions(-) diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index ffbe820c..7db4af30 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -607,6 +607,10 @@ export const clients = pgTable("clients", { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { onDelete: "set null" }), + userId: text("userId").references(() => users.userId, { + // optionally tied to a user and in this case delete when the user deletes + onDelete: "cascade" + }), name: varchar("name").notNull(), pubKey: varchar("pubKey"), subnet: varchar("subnet").notNull(), @@ -638,6 +642,11 @@ export const olms = pgTable("olms", { dateCreated: varchar("dateCreated").notNull(), version: text("version"), clientId: integer("clientId").references(() => clients.clientId, { + // we will switch this depending on the current org it wants to connect to + onDelete: "set null" + }), + userId: text("userId").references(() => users.userId, { + // optionally tied to a user and in this case delete when the user deletes onDelete: "cascade" }) }); diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 13453d2e..63a50154 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -25,11 +25,10 @@ export const dnsRecords = sqliteTable("dnsRecords", { recordType: text("recordType").notNull(), // "NS" | "CNAME" | "A" | "TXT" baseDomain: text("baseDomain"), - value: text("value").notNull(), - verified: integer("verified", { mode: "boolean" }).notNull().default(false), + value: text("value").notNull(), + verified: integer("verified", { mode: "boolean" }).notNull().default(false) }); - export const orgs = sqliteTable("orgs", { orgId: text("orgId").primaryKey(), name: text("name").notNull(), @@ -142,9 +141,10 @@ export const resources = sqliteTable("resources", { onDelete: "set null" }), headers: text("headers"), // comma-separated list of headers to add to the request - proxyProtocol: integer("proxyProtocol", { mode: "boolean" }).notNull().default(false), + proxyProtocol: integer("proxyProtocol", { mode: "boolean" }) + .notNull() + .default(false), proxyProtocolVersion: integer("proxyProtocolVersion").default(1) - }); export const targets = sqliteTable("targets", { @@ -315,6 +315,10 @@ export const clients = sqliteTable("clients", { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { onDelete: "set null" }), + userId: text("userId").references(() => users.userId, { + // optionally tied to a user and in this case delete when the user deletes + onDelete: "cascade" + }), name: text("name").notNull(), pubKey: text("pubKey"), subnet: text("subnet").notNull(), @@ -347,6 +351,10 @@ export const olms = sqliteTable("olms", { dateCreated: text("dateCreated").notNull(), version: text("version"), clientId: integer("clientId").references(() => clients.clientId, { + // we will switch this depending on the current org it wants to connect to + onDelete: "set null" + }), + userId: text("userId").references(() => users.userId, { // optionally tied to a user and in this case delete when the user deletes onDelete: "cascade" }) }); diff --git a/server/private/lib/exitNodes/exitNodes.ts b/server/private/lib/exitNodes/exitNodes.ts index 10418d5a..77149bb0 100644 --- a/server/private/lib/exitNodes/exitNodes.ts +++ b/server/private/lib/exitNodes/exitNodes.ts @@ -197,7 +197,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud // // set the item in the database if it is offline // if (isActuallyOnline != node.online) { - // await db + // await trx // .update(exitNodes) // .set({ online: isActuallyOnline }) // .where(eq(exitNodes.exitNodeId, node.exitNodeId)); diff --git a/server/routers/client/createClient.ts b/server/routers/client/createClient.ts index cb2bbd6e..90445925 100644 --- a/server/routers/client/createClient.ts +++ b/server/routers/client/createClient.ts @@ -182,14 +182,13 @@ export async function createClient( const randomExitNode = exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; - const adminRole = await trx + const [adminRole] = await trx .select() .from(roles) .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) .limit(1); - if (adminRole.length === 0) { - trx.rollback(); + if (!adminRole) { return next( createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) ); @@ -207,12 +206,12 @@ export async function createClient( .returning(); await trx.insert(roleClients).values({ - roleId: adminRole[0].roleId, + roleId: adminRole.roleId, clientId: newClient.clientId }); - if (req.user && req.userOrgRoleId != adminRole[0].roleId) { - // make sure the user can access the site + if (req.user && req.userOrgRoleId != adminRole.roleId) { + // make sure the user can access the client trx.insert(userClients).values({ userId: req.user?.userId!, clientId: newClient.clientId diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 66128f0e..e9d0ab6f 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,10 +1,22 @@ -import { db, ExitNode } from "@server/db"; +import { + Client, + db, + ExitNode, + orgs, + roleClients, + roles, + Transaction, + userClients, + userOrgs, + users +} from "@server/db"; import { MessageHandler } from "@server/routers/ws"; import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db"; import { and, eq, inArray } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; import { listExitNodes } from "#dynamic/lib/exitNodes"; +import { getNextAvailableClientSubnet } from "@server/lib/ip"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); @@ -17,15 +29,62 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.warn("Olm not found"); return; } - if (!olm.clientId) { - logger.warn("Olm has no client ID!"); + + const { publicKey, relay, olmVersion, orgId, deviceName } = message.data; + let client: Client; + + if (orgId) { + if (!olm.userId) { + logger.warn("Olm has no user ID to verify org change!"); + return; + } + + try { + client = await getOrCreateOrgClient(orgId, olm.userId, deviceName); + } catch (err) { + logger.error( + `Error switching olm client ${olm.olmId} to org ${orgId}: ${err}` + ); + return; + } + + if (!client) { + logger.warn("Client not found"); + return; + } + + logger.debug( + `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` + ); + + await db + .update(olms) + .set({ + clientId: client.clientId + }) + .where(eq(olms.olmId, olm.olmId)); + } else { + if (!olm.clientId) { + logger.warn("Olm has no client ID!"); + return; + } + + logger.debug(`Using last connected org for client ${olm.clientId}`); + + [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, olm.clientId)) + .limit(1); + } + + if (!client) { + logger.warn("Client ID not found"); return; } - const clientId = olm.clientId; - const { publicKey, relay, olmVersion } = message.data; logger.debug( - `Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}` + `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` ); if (!publicKey) { @@ -33,18 +92,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - // Get the client - const [client] = await db - .select() - .from(clients) - .where(eq(clients.clientId, clientId)) - .limit(1); - - if (!client) { - logger.warn("Client not found"); - return; - } - if (client.exitNodeId) { // TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER @@ -103,7 +150,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .set({ pubKey: publicKey }) - .where(eq(clients.clientId, olm.clientId)); + .where(eq(clients.clientId, client.clientId)); // set isRelay to false for all of the client's sites to reset the connection metadata await db @@ -111,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .set({ isRelayed: relay == true }) - .where(eq(clientSites.clientId, olm.clientId)); + .where(eq(clientSites.clientId, client.clientId)); } // Get all sites data @@ -145,7 +192,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { // Validate endpoint and hole punch status if (!site.endpoint) { - logger.warn(`In olm register: site ${site.siteId} has no endpoint, skipping`); + logger.warn( + `In olm register: site ${site.siteId} has no endpoint, skipping` + ); continue; } @@ -240,3 +289,105 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { excludeSender: false }; }; + +async function getOrCreateOrgClient( + orgId: string, + userId: string, + deviceName?: string, + trx: Transaction | typeof db = db +): Promise { + let client: Client; + + // get the org + const [org] = await trx + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + throw new Error("Org not found"); + } + + if (!org.subnet) { + throw new Error("Org has no subnet defined"); + } + + // Verify that the user belongs to the org + const [userOrg] = await trx + .select() + .from(userOrgs) + .where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId))) + .limit(1); + + if (!userOrg) { + throw new Error("User does not belong to org"); + } + + // check if the user has a client in the org and if not then create a client for them + const [existingClient] = await trx + .select() + .from(clients) + .where(and(eq(clients.orgId, orgId), eq(clients.userId, userId))) + .limit(1); + + if (!existingClient) { + logger.debug( + `Client does not exist in org ${orgId}, creating new client for user ${userId}` + ); + + // TODO: more intelligent way to pick the exit node + const exitNodesList = await listExitNodes(orgId); + const randomExitNode = + exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; + + const [adminRole] = await trx + .select() + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) + .limit(1); + + if (!adminRole) { + throw new Error("Admin role not found"); + } + + const newSubnet = await getNextAvailableClientSubnet(orgId); + if (!newSubnet) { + throw new Error("No available subnet found"); + } + + const subnet = newSubnet.split("/")[0]; + const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org + + const [newClient] = await trx + .insert(clients) + .values({ + exitNodeId: randomExitNode.exitNodeId, + orgId, + name: deviceName || "User Device", + subnet: updatedSubnet, + type: "olm", + userId: userId + }) + .returning(); + + await trx.insert(roleClients).values({ + roleId: adminRole.roleId, + clientId: newClient.clientId + }); + + if (userOrg.roleId != adminRole.roleId) { + // make sure the user can access the client + trx.insert(userClients).values({ + userId, + clientId: newClient.clientId + }); + } + + client = newClient; + } else { + client = existingClient; + } + + return client; +}