diff --git a/server/auth/sessions/app.ts b/server/auth/sessions/app.ts index 0e3da100..73b220fa 100644 --- a/server/auth/sessions/app.ts +++ b/server/auth/sessions/app.ts @@ -36,13 +36,15 @@ export async function createSession( const sessionId = encodeHexLowerCase( sha256(new TextEncoder().encode(token)) ); - const session: Session = { - sessionId: sessionId, - userId, - expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), - issuedAt: new Date().getTime() - }; - await db.insert(sessions).values(session); + const [session] = await db + .insert(sessions) + .values({ + sessionId: sessionId, + userId, + expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), + issuedAt: new Date().getTime() + }) + .returning(); return session; } diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 120a7aa3..32b1252f 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -288,7 +288,7 @@ export const sessions = pgTable("session", { .references(() => users.userId, { onDelete: "cascade" }), expiresAt: bigint("expiresAt", { mode: "number" }).notNull(), issuedAt: bigint("issuedAt", { mode: "number" }), - deviceAuthUsed: boolean("deviceAuthUsed") + deviceAuthUsed: boolean("deviceAuthUsed").notNull().default(false) }); export const newtSessions = pgTable("newtSession", { @@ -665,7 +665,8 @@ export const clientSitesAssociationsCache = pgTable( .notNull(), siteId: integer("siteId").notNull(), isRelayed: boolean("isRelayed").notNull().default(false), - endpoint: varchar("endpoint") + endpoint: varchar("endpoint"), + publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } ); diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index a5f3d0f6..8b42a461 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -1,6 +1,7 @@ import { randomUUID } from "crypto"; import { InferSelectModel } from "drizzle-orm"; import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core"; +import { no } from "zod/v4/locales"; export const domains = sqliteTable("domains", { domainId: text("domainId").primaryKey(), @@ -372,7 +373,8 @@ export const clientSitesAssociationsCache = sqliteTable( isRelayed: integer("isRelayed", { mode: "boolean" }) .notNull() .default(false), - endpoint: text("endpoint") + endpoint: text("endpoint"), + publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } ); @@ -417,6 +419,8 @@ export const sessions = sqliteTable("session", { expiresAt: integer("expiresAt").notNull(), issuedAt: integer("issuedAt"), deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" }) + .notNull() + .default(false) }); export const newtSessions = sqliteTable("newtSession", { diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index dc5ec729..fe0dd593 100644 --- a/server/lib/readConfigFile.ts +++ b/server/lib/readConfigFile.ts @@ -229,6 +229,11 @@ export const configSchema = z .default(51820) .transform(stoi) .pipe(portSchema), + clients_start_port: portSchema + .optional() + .default(21820) + .transform(stoi) + .pipe(portSchema), base_endpoint: z .string() .optional() diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index f1cbea0c..4be2c92a 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -25,6 +25,7 @@ import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; import { + initPeerAddHandshake as holepunchSiteAdd, addPeer as olmAddPeer, deletePeer as olmDeletePeer } from "@server/routers/olm/peers"; @@ -464,65 +465,16 @@ async function handleMessagesForSiteClients( } if (isAdd) { - // TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES - // BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS - // AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES - const isRelayed = true; - - newtJobs.push( - newtAddPeer( + await holepunchSiteAdd( // this will kick off the add peer process for the client + client.clientId, + { siteId, - { - publicKey: client.pubKey, - allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client - // endpoint: isRelayed ? "" : clientSite.endpoint - endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint - }, - newt.newtId - ) - ); - - // TODO: should we have this here? - const allSiteResources = await db // only get the site resources that this client has access to - .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - siteResources.siteResourceId, - clientSiteResourcesAssociationsCache.siteResourceId - ) - ) - .where( - and( - eq(siteResources.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) - ) - ); - - olmJobs.push( - olmAddPeer( - client.clientId, - { - siteId: site.siteId, - endpoint: - isRelayed || !site.endpoint - ? `${exitNode.endpoint}:21820` - : site.endpoint, - publicKey: site.publicKey, - serverIP: site.address, - serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnets( - allSiteResources.map( - ({ siteResources }) => siteResources - ) - ) - }, - olm.olmId - ) + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + } + }, + olm.olmId ); } diff --git a/server/private/routers/hybrid.ts b/server/private/routers/hybrid.ts index f78fb592..29e766b9 100644 --- a/server/private/routers/hybrid.ts +++ b/server/private/routers/hybrid.ts @@ -1369,7 +1369,7 @@ const updateHolePunchSchema = z.object({ port: z.number(), timestamp: z.number(), reachableAt: z.string().optional(), - publicKey: z.string().optional() + publicKey: z.string() // this is the client public key }); hybridRouter.post( "/gerbil/update-hole-punch", @@ -1408,7 +1408,7 @@ hybridRouter.post( ); } - const { olmId, newtId, ip, port, timestamp, token, reachableAt } = + const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } = parsedParams.data; const destinations = await updateAndGenerateEndpointDestinations( @@ -1418,6 +1418,7 @@ hybridRouter.post( port, timestamp, token, + publicKey, exitNode, true ); diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index f049a5c1..c5f0d2b8 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -30,8 +30,9 @@ const updateHolePunchSchema = z.object({ ip: z.string(), port: z.number(), timestamp: z.number(), + publicKey: z.string(), reachableAt: z.string().optional(), - publicKey: z.string().optional() + exitNodePublicKey: z.string().optional() }); // New response type with multi-peer destination support @@ -65,23 +66,26 @@ export async function updateHolePunch( timestamp, token, reachableAt, - publicKey + publicKey, // this is the client's current public key for this session + exitNodePublicKey } = parsedParams.data; let exitNode: ExitNode | undefined; - if (publicKey) { + if (exitNodePublicKey) { // Get the exit node by public key [exitNode] = await db .select() .from(exitNodes) - .where(eq(exitNodes.publicKey, publicKey)); + .where(eq(exitNodes.publicKey, exitNodePublicKey)); } else { // FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL =<1.1.0 [exitNode] = await db.select().from(exitNodes).limit(1); } if (!exitNode) { - logger.warn(`Exit node not found for publicKey: ${publicKey}`); + logger.warn( + `Exit node not found for publicKey: ${exitNodePublicKey}` + ); return next( createHttpError(HttpCode.NOT_FOUND, "Exit node not found") ); @@ -94,12 +98,13 @@ export async function updateHolePunch( port, timestamp, token, + publicKey, exitNode ); - logger.debug( - `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` - ); + // logger.debug( + // `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` + // ); // Return the new multi-peer structure return res.status(HttpCode.OK).send({ @@ -123,6 +128,7 @@ export async function updateAndGenerateEndpointDestinations( port: number, timestamp: number, token: string, + publicKey: string, exitNode: ExitNode, checkOrg = false ) { @@ -130,9 +136,9 @@ export async function updateAndGenerateEndpointDestinations( const destinations: PeerDestination[] = []; if (olmId) { - logger.debug( - `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` - ); + // logger.debug( + // `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` + // ); const { session, olm: olmSession } = await validateOlmSessionToken(token); @@ -180,6 +186,7 @@ export async function updateAndGenerateEndpointDestinations( siteId: sites.siteId, subnet: sites.subnet, listenPort: sites.listenPort, + publicKey: sites.publicKey, endpoint: clientSitesAssociationsCache.endpoint }) .from(sites) @@ -200,10 +207,19 @@ export async function updateAndGenerateEndpointDestinations( `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` ); + // if the public key or endpoint has changed, update it otherwise continue + if ( + site.endpoint === `${ip}:${port}` && + site.publicKey === publicKey + ) { + continue; + } + const [updatedClientSitesAssociationsCache] = await db .update(clientSitesAssociationsCache) .set({ - endpoint: `${ip}:${port}` + endpoint: `${ip}:${port}`, + publicKey: publicKey }) .where( and( @@ -227,9 +243,9 @@ export async function updateAndGenerateEndpointDestinations( } } - logger.debug( - `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` - ); + // logger.debug( + // `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` + // ); if (!updatedClient) { logger.warn(`Client not found for olm: ${olmId}`); throw new Error("Client not found"); @@ -245,9 +261,9 @@ export async function updateAndGenerateEndpointDestinations( } } } else if (newtId) { - logger.debug( - `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` - ); + // logger.debug( + // `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` + // ); const { session, newt: newtSession } = await validateNewtSessionToken(token); @@ -407,7 +423,7 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) { { siteId: siteId, publicKey: site.publicKey, - endpoint: newEndpoint, + endpoint: newEndpoint }, client.olmId ); diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index fbbcb4fb..36105d7e 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -79,12 +79,12 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { // TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly) } - // if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) { - // logger.warn( - // `Site ${existingSite.siteId} last hole punch is too old, skipping` - // ); - // return; - // } + if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) { + logger.warn( + `Site ${existingSite.siteId} last hole punch is too old, skipping` + ); + return; + } // update the endpoint and the public key const [site] = await db diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index c26f5936..33f5fa2d 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -1,9 +1,9 @@ import { generateSessionToken } from "@server/auth/sessions/app"; -import { db } from "@server/db"; +import { clients, db, ExitNode, exitNodes, sites, clientSitesAssociationsCache } from "@server/db"; import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; -import { eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; @@ -15,11 +15,13 @@ import { import { verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; import config from "@server/lib/config"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; export const olmGetTokenBodySchema = z.object({ olmId: z.string(), secret: z.string(), - token: z.string().optional() + token: z.string().optional(), + orgId: z.string().optional() }); export type OlmGetTokenBody = z.infer; @@ -40,7 +42,7 @@ export async function getOlmToken( ); } - const { olmId, secret, token } = parsedBody.data; + const { olmId, secret, token, orgId } = parsedBody.data; try { if (token) { @@ -61,11 +63,12 @@ export async function getOlmToken( } } - const existingOlmRes = await db + const [existingOlm] = await db .select() .from(olms) .where(eq(olms.olmId, olmId)); - if (!existingOlmRes || !existingOlmRes.length) { + + if (!existingOlm) { return next( createHttpError( HttpCode.BAD_REQUEST, @@ -74,12 +77,11 @@ export async function getOlmToken( ); } - const existingOlm = existingOlmRes[0]; - const validSecret = await verifyPassword( secret, existingOlm.secretHash ); + if (!validSecret) { if (config.getRawConfig().app.log_failed_attempts) { logger.info( @@ -96,11 +98,78 @@ export async function getOlmToken( const resToken = generateSessionToken(); await createOlmSession(resToken, existingOlm.olmId); + let orgIdToUse = orgId; + if (!orgIdToUse) { + if (!existingOlm.clientId) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Olm is not associated with a client, orgId is required" + ) + ); + } + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, existingOlm.clientId)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Olm's associated client not found, orgId is required" + ) + ); + } + + orgIdToUse = client.orgId; + } + + // Get all exit nodes from sites where the client has peers + const clientSites = await db + .select() + .from(clientSitesAssociationsCache) + .innerJoin( + sites, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) + .where(eq(clientSitesAssociationsCache.clientId, existingOlm.clientId!)); + + // Extract unique exit node IDs + const exitNodeIds = Array.from( + new Set( + clientSites + .map(({ sites: site }) => site.exitNodeId) + .filter((id): id is number => id !== null) + ) + ); + + let allExitNodes: ExitNode[] = []; + if (exitNodeIds.length > 0) { + allExitNodes = await db + .select() + .from(exitNodes) + .where(inArray(exitNodes.exitNodeId, exitNodeIds)); + } + + const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { + return { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + }; + }); + logger.debug("Token created successfully"); - return response<{ token: string }>(res, { + return response<{ + token: string; + exitNodes: { publicKey: string; endpoint: string }[]; + }>(res, { data: { - token: resToken + token: resToken, + exitNodes: exitNodesHpData }, success: true, error: false, diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index ee9443f5..4bcbbb8b 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -59,7 +59,7 @@ export const startOlmOfflineChecker = (): void => { // Send a disconnect message to the client if connected try { - await sendTerminateClient(offlineClient.clientId); // terminate first + await sendTerminateClient(offlineClient.clientId, offlineClient.olmId); // terminate first // wait a moment to ensure the message is sent await new Promise(resolve => setTimeout(resolve, 1000)); await disconnectClient(offlineClient.olmId); diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 53cd9815..ce40e832 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -34,21 +34,28 @@ import { generateRemoteSubnets } from "@server/lib/ip"; import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; import { validateSessionToken } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); const { message, client: c, sendToClient } = context; const olm = c as Olm; - const now = new Date().getTime() / 1000; + const now = Math.floor(Date.now() / 1000); if (!olm) { logger.warn("Olm not found"); return; } - const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } = - message.data; + const { + publicKey, + relay, + olmVersion, + orgId, + doNotCreateNewClient, + token: userToken + } = message.data; let client: Client | undefined; let org: Org | undefined; @@ -63,7 +70,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { olm.name || "User Device", // doNotCreateNewClient ? true : false true // for now never create a new client automatically because we create the users clients when they are added to the org - // this means that the rebuildClientAssociationsFromClient call below issue is not a problem + // this means that the rebuildClientAssociationsFromClient call below issue is not a problem ); client = clientRes; @@ -113,12 +120,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` ); - await db - .update(olms) - .set({ - clientId: client.clientId - }) - .where(eq(olms.olmId, olm.olmId)); + if (olm.clientId !== client.clientId) { // we only need to do this if the client is changing + await db + .update(olms) + .set({ + clientId: client.clientId + }) + .where(eq(olms.olmId, olm.olmId)); + } } else { if (!olm.clientId) { logger.warn("Olm has no client ID!"); @@ -159,41 +168,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - if (client.exitNodeId) { - // TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER - - // Get the exit node - const allExitNodes = await listExitNodes(client.orgId, true); // FILTER THE ONLINE ONES - - const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { - return { - publicKey: exitNode.publicKey, - endpoint: exitNode.endpoint - }; - }); - - // Send holepunch message - await sendToClient(olm.olmId, { - type: "olm/wg/holepunch/all", - data: { - exitNodes: exitNodesHpData - } - }); - - if (!olmVersion) { - // THIS IS FOR BACKWARDS COMPATIBILITY - // THE OLDER CLIENTS DID NOT SEND THE VERSION - await sendToClient(olm.olmId, { - type: "olm/wg/holepunch", - data: { - serverPubKey: allExitNodes[0].publicKey, - endpoint: allExitNodes[0].endpoint - } - }); - } - } - - if (olmVersion) { + if (olmVersion && olm.version !== olmVersion) { await db .update(olms) .set({ @@ -202,10 +177,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(olms.olmId, olm.olmId)); } - // if (now - (client.lastHolePunch || 0) > 6) { - // logger.warn("Client last hole punch is too old, skipping all sites"); - // return; - // } + // this prevents us from accepting a register from an olm that has not hole punched yet. + // the olm will pump the register so we can keep checking + if (now - (client.lastHolePunch || 0) > 5) { + logger.warn( + "Client last hole punch is too old; skipping this register" + ); + return; + } if (client.pubKey !== publicKey) { logger.info( @@ -319,7 +298,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.warn(`Exit node not found for site ${site.siteId}`); continue; } - endpoint = `${exitNode.endpoint}:21820`; + endpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; } const allSiteResources = await db // only get the site resources that this client has access to diff --git a/server/routers/olm/handleOlmRelayMessage.ts b/server/routers/olm/handleOlmRelayMessage.ts index 153c4e7c..5479ccbb 100644 --- a/server/routers/olm/handleOlmRelayMessage.ts +++ b/server/routers/olm/handleOlmRelayMessage.ts @@ -2,7 +2,7 @@ import { db, exitNodes, sites } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; import { clients, clientSitesAssociationsCache, Olm } from "@server/db"; import { and, eq } from "drizzle-orm"; -import { updatePeer } from "../newt/peers"; +import { updatePeer as newtUpdatePeer } from "../newt/peers"; import logger from "@server/logger"; export const handleOlmRelayMessage: MessageHandler = async (context) => { @@ -79,18 +79,20 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => { ); // update the peer on the exit node - await updatePeer(siteId, client.pubKey, { - endpoint: "" // this removes the endpoint + await newtUpdatePeer(siteId, client.pubKey, { + endpoint: "" // this removes the endpoint so the exit node knows to relay }); - sendToClient(olm.olmId, { - type: "olm/wg/peer/relay", + return { + message: { + type: "olm/wg/peer/relay", data: { siteId: siteId, endpoint: exitNode.endpoint, publicKey: exitNode.publicKey } - }); - - return; + }, + broadcast: false, + excludeSender: false + }; }; diff --git a/server/routers/olm/handleOlmServerPeerAddMessage.ts b/server/routers/olm/handleOlmServerPeerAddMessage.ts new file mode 100644 index 00000000..3d0d61b2 --- /dev/null +++ b/server/routers/olm/handleOlmServerPeerAddMessage.ts @@ -0,0 +1,185 @@ +import { + Client, + clientSiteResourcesAssociationsCache, + db, + ExitNode, + Org, + orgs, + roleClients, + roles, + siteResources, + Transaction, + userClients, + userOrgs, + users +} from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { + clients, + clientSitesAssociationsCache, + exitNodes, + Olm, + olms, + sites +} from "@server/db"; +import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; +import { addPeer, deletePeer } from "../newt/peers"; +import logger from "@server/logger"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; +import { + generateAliasConfig, + getNextAvailableClientSubnet +} from "@server/lib/ip"; +import { generateRemoteSubnets } from "@server/lib/ip"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { validateSessionToken } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; +import { + addPeer as newtAddPeer, + deletePeer as newtDeletePeer +} from "@server/routers/newt/peers"; + +export const handleOlmServerPeerAddMessage: MessageHandler = async ( + context +) => { + logger.info("Handling register olm message!"); + const { message, client: c, sendToClient } = context; + const olm = c as Olm; + + const now = Math.floor(Date.now() / 1000); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + const { siteId } = message.data; + + // get the site + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site) { + logger.error( + `handleOlmServerPeerAddMessage: Site with ID ${siteId} not found` + ); + return; + } + + if (!site.endpoint) { + logger.error( + `handleOlmServerPeerAddMessage: Site with ID ${siteId} has no endpoint` + ); + return; + } + + // get the client + + if (!olm.clientId) { + logger.error( + `handleOlmServerPeerAddMessage: Olm with ID ${olm.olmId} has no clientId` + ); + return; + } + + const [client] = await db + .select() + .from(clients) + .where(and(eq(clients.clientId, olm.clientId))) + .limit(1); + + if (!client) { + logger.error( + `handleOlmServerPeerAddMessage: Client with ID ${olm.clientId} not found` + ); + return; + } + + if (!client.pubKey) { + logger.error( + `handleOlmServerPeerAddMessage: Client with ID ${client.clientId} has no public key` + ); + return; + } + + let endpoint: string | null = null; + + + const currentSessionSiteAssociationCaches = await db + .select() + .from(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + isNotNull(clientSitesAssociationsCache.endpoint), + eq(clientSitesAssociationsCache.publicKey, client.pubKey) // limit it to the current session its connected with otherwise the endpoint could be stale + ) + ); + + // pick an endpoint + for (const assoc of currentSessionSiteAssociationCaches) { + if (assoc.endpoint) { + endpoint = assoc.endpoint; + break; + } + } + + if (!endpoint) { + logger.error( + `handleOlmServerPeerAddMessage: No endpoint found for client ${client.clientId}` + ); + return; + } + + await newtAddPeer(siteId, { + publicKey: client.pubKey, + allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client + endpoint: endpoint // this is the client's endpoint with reference to the site's exit node + }); + + const allSiteResources = await db // only get the site resources that this client has access to + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); + + // Return connect message with all site configurations + return { + message: { + type: "olm/wg/peer/add", + data: { + siteId: site.siteId, + endpoint: site.endpoint, + publicKey: site.publicKey, + serverIP: site.address, + serverPort: site.listenPort, + remoteSubnets: generateRemoteSubnets( + allSiteResources.map(({ siteResources }) => siteResources) + ), + aliases: generateAliasConfig( + allSiteResources.map(({ siteResources }) => siteResources) + ) + } + }, + broadcast: false, + excludeSender: false + }; +}; diff --git a/server/routers/olm/index.ts b/server/routers/olm/index.ts index 7adbf859..0fc65d92 100644 --- a/server/routers/olm/index.ts +++ b/server/routers/olm/index.ts @@ -7,3 +7,4 @@ export * from "./deleteUserOlm"; export * from "./listUserOlms"; export * from "./deleteUserOlm"; export * from "./getUserOlm"; +export * from "./handleOlmServerPeerAddMessage"; \ No newline at end of file diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 69ea2bc9..7651f0a9 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -3,6 +3,7 @@ import { clients, olms, newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; import { sendToClient } from "#dynamic/routers/ws"; import logger from "@server/logger"; +import { exit } from "process"; export async function addPeer( clientId: number, @@ -110,3 +111,40 @@ export async function updatePeer( logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`); } + +export async function initPeerAddHandshake( + clientId: number, + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }, + olmId?: string +) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; + } + + await sendToClient(olmId, { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: peer.siteId, + exitNode: { + publicKey: peer.exitNode.publicKey, + endpoint: peer.exitNode.endpoint + } + } + }); + + logger.info(`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`); +} diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index cbb023b3..b92e7530 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -11,23 +11,25 @@ import { handleOlmRegisterMessage, handleOlmRelayMessage, handleOlmPingMessage, - startOlmOfflineChecker + startOlmOfflineChecker, + handleOlmServerPeerAddMessage } from "../olm"; import { handleHealthcheckStatusMessage } from "../target"; import { MessageHandler } from "./types"; export const messageHandlers: Record = { - "newt/wg/register": handleNewtRegisterMessage, + "olm/wg/server/peer/add": handleOlmServerPeerAddMessage, "olm/wg/register": handleOlmRegisterMessage, - "newt/wg/get-config": handleGetConfigMessage, - "newt/receive-bandwidth": handleReceiveBandwidthMessage, "olm/wg/relay": handleOlmRelayMessage, "olm/ping": handleOlmPingMessage, + "newt/wg/register": handleNewtRegisterMessage, + "newt/wg/get-config": handleGetConfigMessage, + "newt/receive-bandwidth": handleReceiveBandwidthMessage, "newt/socket/status": handleDockerStatusMessage, "newt/socket/containers": handleDockerContainersMessage, "newt/ping/request": handleNewtPingRequestMessage, "newt/blueprint/apply": handleApplyBlueprintMessage, - "newt/healthcheck/status": handleHealthcheckStatusMessage, + "newt/healthcheck/status": handleHealthcheckStatusMessage }; -startOlmOfflineChecker(); // this is to handle the offline check for olms \ No newline at end of file +startOlmOfflineChecker(); // this is to handle the offline check for olms