Handle delete org and checking org policy

This commit is contained in:
Owen
2025-11-26 15:35:33 -05:00
parent ceae787cf5
commit de83cf9d8c
4 changed files with 155 additions and 34 deletions

View File

@@ -5,6 +5,8 @@ import { clients, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
@@ -57,6 +59,9 @@ export const startOlmOfflineChecker = (): void => {
// Send a disconnect message to the client if connected
try {
await sendTerminateClient(offlineClient.clientId); // terminate first
// wait a moment to ensure the message is sent
await new Promise(resolve => setTimeout(resolve, 1000));
await disconnectClient(offlineClient.olmId);
} catch (error) {
logger.error(
@@ -110,6 +115,36 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
logger.warn("User ID mismatch for olm ping");
return;
}
// get the client
const [client] = await db
.select()
.from(clients)
.where(
and(
eq(clients.olmId, olm.olmId),
eq(clients.userId, olm.userId)
)
)
.limit(1);
if (!client) {
logger.warn("Client not found for olm ping");
return;
}
const policyCheck = await checkOrgAccessPolicy({
orgId: client.orgId,
userId: olm.userId,
session: userToken // this is the user token passed in the message
});
if (!policyCheck.allowed) {
logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
);
return;
}
}
if (!olm.clientId) {

View File

@@ -32,6 +32,8 @@ import {
} from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
@@ -45,7 +47,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient } =
const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } =
message.data;
let client: Client | undefined;
@@ -78,6 +80,35 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
if (!olm.userId) {
logger.warn("Olm has no user ID");
return;
}
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
logger.warn("Invalid user session for olm ping");
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
}
if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm ping");
return;
}
const policyCheck = await checkOrgAccessPolicy({
orgId: orgId,
userId: olm.userId,
session: userToken // this is the user token passed in the message
});
if (!policyCheck.allowed) {
logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
);
return;
}
logger.debug(
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
);

View File

@@ -1,6 +1,15 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, domains, orgDomains, resources } from "@server/db";
import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
domains,
olms,
orgDomains,
resources
} from "@server/db";
import { newts, newtSessions, orgs, sites, userActions } from "@server/db";
import { eq, and, inArray, sql } from "drizzle-orm";
import response from "@server/lib/response";
@@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers";
import { OpenAPITags, registry } from "@server/openApi";
const deleteOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export type DeleteOrgResponse = {};
@@ -69,41 +78,75 @@ export async function deleteOrg(
.where(eq(sites.orgId, orgId))
.limit(1);
const orgClients = await db
.select()
.from(clients)
.where(eq(clients.orgId, orgId));
const deletedNewtIds: string[] = [];
const olmsToTerminate: string[] = [];
await db.transaction(async (trx) => {
if (sites) {
for (const site of orgSites) {
if (site.pubKey) {
if (site.type == "wireguard") {
await deletePeer(site.exitNodeId!, site.pubKey);
} else if (site.type == "newt") {
// get the newt on the site by querying the newt table for siteId
const [deletedNewt] = await trx
.delete(newts)
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
deletedNewtIds.push(deletedNewt.newtId);
for (const site of orgSites) {
if (site.pubKey) {
if (site.type == "wireguard") {
await deletePeer(site.exitNodeId!, site.pubKey);
} else if (site.type == "newt") {
// get the newt on the site by querying the newt table for siteId
const [deletedNewt] = await trx
.delete(newts)
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
deletedNewtIds.push(deletedNewt.newtId);
// delete all of the sessions for the newt
await trx
.delete(newtSessions)
.where(
eq(
newtSessions.newtId,
deletedNewt.newtId
)
);
}
// delete all of the sessions for the newt
await trx
.delete(newtSessions)
.where(
eq(newtSessions.newtId, deletedNewt.newtId)
);
}
}
logger.info(`Deleting site ${site.siteId}`);
await trx
.delete(sites)
.where(eq(sites.siteId, site.siteId));
}
logger.info(`Deleting site ${site.siteId}`);
await trx.delete(sites).where(eq(sites.siteId, site.siteId));
}
for (const client of orgClients) {
const [olm] = await trx
.select()
.from(olms)
.where(eq(olms.clientId, client.clientId))
.limit(1);
if (olm) {
olmsToTerminate.push(olm.olmId);
}
logger.info(`Deleting client ${client.clientId}`);
await trx
.delete(clients)
.where(eq(clients.clientId, client.clientId));
// also delete the associations
await trx
.delete(clientSiteResourcesAssociationsCache)
.where(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
);
await trx
.delete(clientSitesAssociationsCache)
.where(
eq(
clientSitesAssociationsCache.clientId,
client.clientId
)
);
}
const allOrgDomains = await trx
@@ -162,6 +205,18 @@ export async function deleteOrg(
});
}
for (const olmId of olmsToTerminate) {
sendToClient(olmId, {
type: "olm/terminate",
data: {}
}).catch((error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
});
}
return response(res, {
data: null,
success: true,

View File

@@ -125,7 +125,7 @@ export async function addUserRole(
.returning();
// get the client associated with this user in this org
const [orgClient] = await trx
const orgClients = await trx
.select()
.from(clients)
.where(
@@ -136,7 +136,7 @@ export async function addUserRole(
)
.limit(1);
if (orgClient) {
for (const orgClient of orgClients) {
// we just changed the user's role, so we need to rebuild client associations and what they have access to
await rebuildClientAssociationsFromClient(orgClient, trx);
}