Update schmea; create client when registering

This commit is contained in:
Owen
2025-11-03 15:42:22 -08:00
parent 43590896e9
commit d30743a428
5 changed files with 200 additions and 33 deletions

View File

@@ -607,6 +607,10 @@ export const clients = pgTable("clients", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
name: varchar("name").notNull(),
pubKey: varchar("pubKey"),
subnet: varchar("subnet").notNull(),
@@ -638,6 +642,11 @@ export const olms = pgTable("olms", {
dateCreated: varchar("dateCreated").notNull(),
version: text("version"),
clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
});

View File

@@ -25,11 +25,10 @@ export const dnsRecords = sqliteTable("dnsRecords", {
recordType: text("recordType").notNull(), // "NS" | "CNAME" | "A" | "TXT"
baseDomain: text("baseDomain"),
value: text("value").notNull(),
verified: integer("verified", { mode: "boolean" }).notNull().default(false),
value: text("value").notNull(),
verified: integer("verified", { mode: "boolean" }).notNull().default(false)
});
export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(),
name: text("name").notNull(),
@@ -142,9 +141,10 @@ export const resources = sqliteTable("resources", {
onDelete: "set null"
}),
headers: text("headers"), // comma-separated list of headers to add to the request
proxyProtocol: integer("proxyProtocol", { mode: "boolean" }).notNull().default(false),
proxyProtocol: integer("proxyProtocol", { mode: "boolean" })
.notNull()
.default(false),
proxyProtocolVersion: integer("proxyProtocolVersion").default(1)
});
export const targets = sqliteTable("targets", {
@@ -315,6 +315,10 @@ export const clients = sqliteTable("clients", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
name: text("name").notNull(),
pubKey: text("pubKey"),
subnet: text("subnet").notNull(),
@@ -347,6 +351,10 @@ export const olms = sqliteTable("olms", {
dateCreated: text("dateCreated").notNull(),
version: text("version"),
clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, { // optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
});

View File

@@ -197,7 +197,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
// // set the item in the database if it is offline
// if (isActuallyOnline != node.online) {
// await db
// await trx
// .update(exitNodes)
// .set({ online: isActuallyOnline })
// .where(eq(exitNodes.exitNodeId, node.exitNodeId));

View File

@@ -182,14 +182,13 @@ export async function createClient(
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const adminRole = await trx
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
trx.rollback();
if (!adminRole) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
@@ -207,12 +206,12 @@ export async function createClient(
.returning();
await trx.insert(roleClients).values({
roleId: adminRole[0].roleId,
roleId: adminRole.roleId,
clientId: newClient.clientId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
if (req.user && req.userOrgRoleId != adminRole.roleId) {
// make sure the user can access the client
trx.insert(userClients).values({
userId: req.user?.userId!,
clientId: newClient.clientId

View File

@@ -1,10 +1,22 @@
import { db, ExitNode } from "@server/db";
import {
Client,
db,
ExitNode,
orgs,
roleClients,
roles,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
@@ -17,15 +29,62 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
const { publicKey, relay, olmVersion, orgId, deviceName } = message.data;
let client: Client;
if (orgId) {
if (!olm.userId) {
logger.warn("Olm has no user ID to verify org change!");
return;
}
try {
client = await getOrCreateOrgClient(orgId, olm.userId, deviceName);
} catch (err) {
logger.error(
`Error switching olm client ${olm.olmId} to org ${orgId}: ${err}`
);
return;
}
if (!client) {
logger.warn("Client not found");
return;
}
logger.debug(
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
);
await db
.update(olms)
.set({
clientId: client.clientId
})
.where(eq(olms.olmId, olm.olmId));
} else {
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
logger.debug(`Using last connected org for client ${olm.clientId}`);
[client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
}
if (!client) {
logger.warn("Client ID not found");
return;
}
const clientId = olm.clientId;
const { publicKey, relay, olmVersion } = message.data;
logger.debug(
`Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}`
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
);
if (!publicKey) {
@@ -33,18 +92,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) {
// TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER
@@ -103,7 +150,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({
pubKey: publicKey
})
.where(eq(clients.clientId, olm.clientId));
.where(eq(clients.clientId, client.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
@@ -111,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({
isRelayed: relay == true
})
.where(eq(clientSites.clientId, olm.clientId));
.where(eq(clientSites.clientId, client.clientId));
}
// Get all sites data
@@ -145,7 +192,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`In olm register: site ${site.siteId} has no endpoint, skipping`);
logger.warn(
`In olm register: site ${site.siteId} has no endpoint, skipping`
);
continue;
}
@@ -240,3 +289,105 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
excludeSender: false
};
};
async function getOrCreateOrgClient(
orgId: string,
userId: string,
deviceName?: string,
trx: Transaction | typeof db = db
): Promise<Client> {
let client: Client;
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error("Org not found");
}
if (!org.subnet) {
throw new Error("Org has no subnet defined");
}
// Verify that the user belongs to the org
const [userOrg] = await trx
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
throw new Error("User does not belong to org");
}
// check if the user has a client in the org and if not then create a client for them
const [existingClient] = await trx
.select()
.from(clients)
.where(and(eq(clients.orgId, orgId), eq(clients.userId, userId)))
.limit(1);
if (!existingClient) {
logger.debug(
`Client does not exist in org ${orgId}, creating new client for user ${userId}`
);
// TODO: more intelligent way to pick the exit node
const exitNodesList = await listExitNodes(orgId);
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
throw new Error("Admin role not found");
}
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
throw new Error("No available subnet found");
}
const subnet = newSubnet.split("/")[0];
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: randomExitNode.exitNodeId,
orgId,
name: deviceName || "User Device",
subnet: updatedSubnet,
type: "olm",
userId: userId
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: newClient.clientId
});
if (userOrg.roleId != adminRole.roleId) {
// make sure the user can access the client
trx.insert(userClients).values({
userId,
clientId: newClient.clientId
});
}
client = newClient;
} else {
client = existingClient;
}
return client;
}