Fix switching orgs having connections from other orgs

This commit is contained in:
Owen
2025-12-01 15:44:25 -05:00
parent a623604e96
commit b5e94d44ae
3 changed files with 95 additions and 212 deletions

View File

@@ -465,7 +465,8 @@ async function handleMessagesForSiteClients(
}
if (isAdd) {
await holepunchSiteAdd( // this will kick off the add peer process for the client
await holepunchSiteAdd(
// this will kick off the add peer process for the client
client.clientId,
{
siteId,
@@ -728,7 +729,19 @@ export async function rebuildClientAssociationsFromClient(
const userSiteResourceIds = await trx
.select({ siteResourceId: userSiteResources.siteResourceId })
.from(userSiteResources)
.where(eq(userSiteResources.userId, client.userId));
.innerJoin(
siteResources,
eq(
siteResources.siteResourceId,
userSiteResources.siteResourceId
)
)
.where(
and(
eq(userSiteResources.userId, client.userId),
eq(siteResources.orgId, client.orgId)
)
); // this needs to be locked onto this org or else cross-org access could happen
newSiteResourceIds.push(
...userSiteResourceIds.map((r) => r.siteResourceId)
@@ -738,7 +751,12 @@ export async function rebuildClientAssociationsFromClient(
const roleIds = await trx
.select({ roleId: userOrgs.roleId })
.from(userOrgs)
.where(eq(userOrgs.userId, client.userId))
.where(
and(
eq(userOrgs.userId, client.userId),
eq(userOrgs.orgId, client.orgId)
)
) // this needs to be locked onto this org or else cross-org access could happen
.then((rows) => rows.map((row) => row.roleId));
if (roleIds.length > 0) {

View File

@@ -1,5 +1,12 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import { clients, db, ExitNode, exitNodes, sites, clientSitesAssociationsCache } from "@server/db";
import {
clients,
db,
ExitNode,
exitNodes,
sites,
clientSitesAssociationsCache
} from "@server/db";
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
@@ -99,6 +106,7 @@ export async function getOlmToken(
await createOlmSession(resToken, existingOlm.olmId);
let orgIdToUse = orgId;
let clientIdToUse;
if (!orgIdToUse) {
if (!existingOlm.clientId) {
return next(
@@ -114,7 +122,7 @@ export async function getOlmToken(
.from(clients)
.where(eq(clients.clientId, existingOlm.clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
@@ -125,6 +133,40 @@ export async function getOlmToken(
}
orgIdToUse = client.orgId;
clientIdToUse = client.clientId;
} else {
// we did provide the org
const [client] = await db
.select()
.from(clients)
.where(eq(clients.orgId, orgIdToUse))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No client found for provided orgId"
)
);
}
if (existingOlm.clientId !== client.clientId) {
// we only need to do this if the client is changing
logger.debug(
`Switching olm client ${existingOlm.olmId} to org ${orgId} for user ${existingOlm.userId}`
);
await db
.update(olms)
.set({
clientId: client.clientId
})
.where(eq(olms.olmId, existingOlm.olmId));
}
clientIdToUse = client.clientId;
}
// Get all exit nodes from sites where the client has peers
@@ -135,7 +177,7 @@ export async function getOlmToken(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, existingOlm.clientId!));
.where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!));
// Extract unique exit node IDs
const exitNodeIds = Array.from(

View File

@@ -48,45 +48,36 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
const {
publicKey,
relay,
olmVersion,
orgId,
doNotCreateNewClient,
token: userToken
} = message.data;
const { publicKey, relay, olmVersion, orgId, userToken } = message.data;
let client: Client | undefined;
let org: Org | undefined;
if (!olm.clientId) {
logger.warn("Olm client ID not found");
return;
}
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
if (!client) {
logger.warn("Client ID not found");
return;
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, client.orgId))
.limit(1);
if (!org) {
logger.warn("Org not found");
return;
}
if (orgId) {
try {
const { client: clientRes, org: orgRes } =
await getOrCreateOrgClient(
orgId,
olm.userId,
olm.olmId,
olm.name || "User Device",
// doNotCreateNewClient ? true : false
true // for now never create a new client automatically because we create the users clients when they are added to the org
// this means that the rebuildClientAssociationsFromClient call below issue is not a problem
);
client = clientRes;
org = orgRes;
} catch (err) {
logger.error(
`Error switching olm client ${olm.olmId} to org ${orgId}: ${err}`
);
return;
}
if (!client) {
logger.warn("Client not found");
return;
}
if (!olm.userId) {
logger.warn("Olm has no user ID");
return;
@@ -95,11 +86,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
logger.warn("Invalid user session for olm ping");
logger.warn("Invalid user session for olm register");
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
}
if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm ping");
logger.warn("User ID mismatch for olm register");
return;
}
@@ -115,48 +106,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
);
return;
}
logger.debug(
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
);
if (olm.clientId !== client.clientId) { // we only need to do this if the client is changing
await db
.update(olms)
.set({
clientId: client.clientId
})
.where(eq(olms.olmId, olm.olmId));
}
} else {
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
logger.debug(`Using last connected org for client ${olm.clientId}`);
[client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
[org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, client.orgId))
.limit(1);
}
if (!client) {
logger.warn("Client ID not found");
return;
}
if (!org) {
logger.warn("Org not found");
return;
}
logger.debug(
@@ -357,129 +306,3 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
excludeSender: false
};
};
async function getOrCreateOrgClient(
orgId: string,
userId: string | null,
olmId: string,
name: string,
doNotCreateNewClient: boolean,
trx: Transaction | typeof db = db
): Promise<{
client: Client;
org: Org;
}> {
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error("Org not found");
}
if (!org.subnet) {
throw new Error("Org has no subnet defined");
}
// check if the user has a client in the org and if not then create a client for them
const [existingClient] = await trx
.select()
.from(clients)
.where(
and(
eq(clients.orgId, orgId),
userId ? eq(clients.userId, userId) : isNull(clients.userId), // we dont check the user id if it is null because the olm is not tied to a user?
eq(clients.olmId, olmId)
)
) // checking the olmid here because we want to create a new client PER OLM PER ORG
.limit(1);
let client = existingClient;
if (!client && !doNotCreateNewClient) {
logger.debug(
`Client does not exist in org ${orgId}, creating new client for user ${userId}`
);
if (!userId) {
throw new Error("User ID is required to create client in org");
}
// Verify that the user belongs to the org
const [userOrg] = await trx
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
throw new Error("User does not belong to org");
}
// TODO: more intelligent way to pick the exit node
const exitNodesList = await listExitNodes(orgId);
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
throw new Error("Admin role not found");
}
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
throw new Error("No available subnet found");
}
const subnet = newSubnet.split("/")[0];
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: randomExitNode.exitNodeId,
orgId,
name,
subnet: updatedSubnet,
type: "olm",
userId: userId,
olmId: olmId // to lock this client to the olm even as the olm moves between clients in different orgs
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: newClient.clientId
});
await trx.insert(userClients).values({
// we also want to make sure that the user can see their own client if they are not an admin
userId,
clientId: newClient.clientId
});
if (userOrg.roleId != adminRole.roleId) {
// make sure the user can access the client
trx.insert(userClients).values({
userId,
clientId: newClient.clientId
});
}
await rebuildClientAssociationsFromClient(newClient, trx); // TODO: this will try to messages to the olm which has not connected yet - is that a problem?
client = newClient;
}
return {
client: client,
org: org
};
}