Merge branch 'dev' into feat/resource-policies

This commit is contained in:
Fred KISSIE
2026-02-28 01:08:12 +01:00
214 changed files with 13059 additions and 7647 deletions

View File

@@ -11,11 +11,11 @@
* This file is not licensed under the AGPLv3.
*/
import { accessAuditLog, db, resources } from "@server/db";
import { accessAuditLog, logsDb, resources, db, primaryDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
import { eq, gt, lt, and, count, desc } from "drizzle-orm";
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
import { OpenAPITags } from "@server/openApi";
import { z } from "zod";
import createHttpError from "http-errors";
@@ -115,15 +115,13 @@ function getWhere(data: Q) {
}
export function queryAccess(data: Q) {
return db
return logsDb
.select({
orgId: accessAuditLog.orgId,
action: accessAuditLog.action,
actorType: accessAuditLog.actorType,
actorId: accessAuditLog.actorId,
resourceId: accessAuditLog.resourceId,
resourceName: resources.name,
resourceNiceId: resources.niceId,
ip: accessAuditLog.ip,
location: accessAuditLog.location,
userAgent: accessAuditLog.userAgent,
@@ -133,16 +131,46 @@ export function queryAccess(data: Q) {
actor: accessAuditLog.actor
})
.from(accessAuditLog)
.leftJoin(
resources,
eq(accessAuditLog.resourceId, resources.resourceId)
)
.where(getWhere(data))
.orderBy(desc(accessAuditLog.timestamp), desc(accessAuditLog.id));
}
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryAccess>>) {
// If logs database is the same as main database, we can do a join
// Otherwise, we need to fetch resource details separately
const resourceIds = logs
.map(log => log.resourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0) {
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
}
// Fetch resource details from main database
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
// Create a map for quick lookup
const resourceMap = new Map(
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
);
// Enrich logs with resource details
return logs.map(log => ({
...log,
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
}));
}
export function countAccessQuery(data: Q) {
const countQuery = db
const countQuery = logsDb
.select({ count: count() })
.from(accessAuditLog)
.where(getWhere(data));
@@ -161,7 +189,7 @@ async function queryUniqueFilterAttributes(
);
// Get unique actors
const uniqueActors = await db
const uniqueActors = await logsDb
.selectDistinct({
actor: accessAuditLog.actor
})
@@ -169,7 +197,7 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
// Get unique locations
const uniqueLocations = await db
const uniqueLocations = await logsDb
.selectDistinct({
locations: accessAuditLog.location
})
@@ -177,25 +205,40 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
// Get unique resources with names
const uniqueResources = await db
const uniqueResources = await logsDb
.selectDistinct({
id: accessAuditLog.resourceId,
name: resources.name
id: accessAuditLog.resourceId
})
.from(accessAuditLog)
.leftJoin(
resources,
eq(accessAuditLog.resourceId, resources.resourceId)
)
.where(baseConditions);
// Fetch resource names from main database for the unique resource IDs
const resourceIds = uniqueResources
.map(row => row.id)
.filter((id): id is number => id !== null);
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
if (resourceIds.length > 0) {
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
resourcesWithNames = resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}));
}
return {
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
resources: resourcesWithNames,
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
@@ -243,7 +286,10 @@ export async function queryAccessAuditLogs(
const baseQuery = queryAccess(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const logsRaw = await baseQuery.limit(data.limit).offset(data.offset);
// Enrich with resource details (handles cross-database scenario)
const log = await enrichWithResourceDetails(logsRaw);
const totalCountResult = await countAccessQuery(data);
const totalCount = totalCountResult[0].count;

View File

@@ -11,7 +11,7 @@
* This file is not licensed under the AGPLv3.
*/
import { actionAuditLog, db } from "@server/db";
import { actionAuditLog, logsDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
@@ -97,7 +97,7 @@ function getWhere(data: Q) {
}
export function queryAction(data: Q) {
return db
return logsDb
.select({
orgId: actionAuditLog.orgId,
action: actionAuditLog.action,
@@ -113,7 +113,7 @@ export function queryAction(data: Q) {
}
export function countActionQuery(data: Q) {
const countQuery = db
const countQuery = logsDb
.select({ count: count() })
.from(actionAuditLog)
.where(getWhere(data));
@@ -132,14 +132,14 @@ async function queryUniqueFilterAttributes(
);
// Get unique actors
const uniqueActors = await db
const uniqueActors = await logsDb
.selectDistinct({
actor: actionAuditLog.actor
})
.from(actionAuditLog)
.where(baseConditions);
const uniqueActions = await db
const uniqueActions = await logsDb
.selectDistinct({
action: actionAuditLog.action
})

View File

@@ -15,7 +15,19 @@ import { SubscriptionType } from "./hooks/getSubType";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { Tier } from "@server/types/Tiers";
import logger from "@server/logger";
import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db";
import {
db,
idp,
idpOrg,
loginPage,
loginPageBranding,
loginPageBrandingOrg,
loginPageOrg,
orgs,
resources,
roles,
siteResources
} from "@server/db";
import { eq } from "drizzle-orm";
/**
@@ -59,10 +71,7 @@ async function capRetentionDays(
}
// Get current org settings
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
logger.warn(`Org ${orgId} not found when capping retention days`);
@@ -110,18 +119,13 @@ async function capRetentionDays(
// Apply updates if needed
if (needsUpdate) {
await db
.update(orgs)
.set(updates)
.where(eq(orgs.orgId, orgId));
await db.update(orgs).set(updates).where(eq(orgs.orgId, orgId));
logger.info(
`Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days`
);
} else {
logger.debug(
`No retention day capping needed for org ${orgId}`
);
logger.debug(`No retention day capping needed for org ${orgId}`);
}
}
@@ -134,6 +138,35 @@ export async function handleTierChange(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// Get all orgs that have this orgId as their billingOrgId
const associatedOrgs = await db
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, orgId));
logger.info(
`Found ${associatedOrgs.length} org(s) associated with billing org ${orgId}`
);
// Loop over all associated orgs and apply tier changes
for (const org of associatedOrgs) {
await handleTierChangeForOrg(org.orgId, newTier, previousTier);
}
logger.info(
`Completed tier change handling for all orgs associated with billing org ${orgId}`
);
}
async function handleTierChangeForOrg(
orgId: string,
newTier: SubscriptionType | null,
previousTier?: SubscriptionType | null
): Promise<void> {
logger.info(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// License subscriptions are handled separately and don't use the tier matrix
if (newTier === "license") {
logger.debug(
@@ -254,6 +287,10 @@ async function disableFeature(
await disableAutoProvisioning(orgId);
break;
case TierFeature.SshPam:
await disableSshPam(orgId);
break;
default:
logger.warn(
`Unknown feature ${feature} for org ${orgId}, skipping`
@@ -283,6 +320,12 @@ async function disableDeviceApprovals(orgId: string): Promise<void> {
logger.info(`Disabled device approvals on all roles for org ${orgId}`);
}
async function disableSshPam(orgId: string): Promise<void> {
logger.info(
`Disabled SSH PAM options on all roles and site resources for org ${orgId}`
);
}
async function disableLoginPageBranding(orgId: string): Promise<void> {
const [existingBranding] = await db
.select()
@@ -314,9 +357,7 @@ async function disableLoginPageDomain(orgId: string): Promise<void> {
);
if (existingLoginPage) {
await db
.delete(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId));
await db.delete(loginPageOrg).where(eq(loginPageOrg.orgId, orgId));
await db
.delete(loginPage)

View File

@@ -112,11 +112,13 @@ export async function getOrgSubscriptionsData(
throw new Error(`Not found`);
}
const billingOrgId = org[0].billingOrgId || org[0].orgId;
// Get customer for org
const customer = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.where(eq(customers.orgId, billingOrgId))
.limit(1);
const subscriptionsWithItems: Array<{

View File

@@ -85,10 +85,14 @@ export async function getOrgUsage(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
const organizations = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
FeatureId.ORGINIZATIONS
);
// const egressData = await usageService.getUsage(
// orgId,
// FeatureId.EGRESS_DATA_MB
// );
if (sites) {
usageData.push(sites);
@@ -96,15 +100,18 @@ export async function getOrgUsage(
if (users) {
usageData.push(users);
}
if (egressData) {
usageData.push(egressData);
}
// if (egressData) {
// usageData.push(egressData);
// }
if (domains) {
usageData.push(domains);
}
if (remoteExitNodes) {
usageData.push(remoteExitNodes);
}
if (organizations) {
usageData.push(organizations);
}
const orgLimits = await db
.select()

View File

@@ -25,6 +25,7 @@ import * as logs from "#private/routers/auditLogs";
import * as misc from "#private/routers/misc";
import * as reKey from "#private/routers/re-key";
import * as approval from "#private/routers/approvals";
import * as ssh from "#private/routers/ssh";
import * as resource from "#private/routers/resource";
import * as policy from "#private/routers/policy";
@@ -34,6 +35,7 @@ import {
verifyUserIsServerAdmin,
verifySiteAccess,
verifyClientAccess,
verifyLimits,
verifyLimits
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
@@ -503,9 +505,9 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyClientAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifyClientAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret
@@ -513,9 +515,9 @@ authenticated.post(
authenticated.post(
"/re-key/:siteId/regenerate-site-secret",
verifySiteAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifySiteAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret
@@ -530,3 +532,14 @@ authenticated.put(
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateExitNodeSecret
);
authenticated.post(
"/org/:orgId/ssh/sign-key",
verifyValidLicense,
verifyValidSubscription(tierMatrix.sshPam),
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.signSshKey),
logActionAudit(ActionsEnum.signSshKey),
ssh.signSshKey
);

View File

@@ -26,6 +26,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, InferInsertModel } from "drizzle-orm";
import { build } from "@server/build";
import { validateLocalPath } from "@app/lib/validateLocalPath";
import config from "#private/lib/config";
const paramsSchema = z.strictObject({
@@ -37,14 +38,36 @@ const bodySchema = z.strictObject({
.union([
z.literal(""),
z
.url("Must be a valid URL")
.superRefine(async (url, ctx) => {
.string()
.superRefine(async (urlOrPath, ctx) => {
const parseResult = z.url().safeParse(urlOrPath);
if (!parseResult.success) {
if (build !== "enterprise") {
ctx.addIssue({
code: "custom",
message: "Must be a valid URL"
});
return;
} else {
try {
validateLocalPath(urlOrPath);
} catch (error) {
ctx.addIssue({
code: "custom",
message: "Must be either a valid image URL or a valid pathname starting with `/` and not containing query parameters, `..` or `*`"
});
} finally {
return;
}
}
}
try {
const response = await fetch(url, {
const response = await fetch(urlOrPath, {
method: "HEAD"
}).catch(() => {
// If HEAD fails (CORS or method not allowed), try GET
return fetch(url, { method: "GET" });
return fetch(urlOrPath, { method: "GET" });
});
if (response.status !== 200) {

View File

@@ -28,6 +28,7 @@ import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
import { isSubscribed } from "#private/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import privateConfig from "#private/lib/config";
import { build } from "@server/build";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
@@ -122,12 +123,14 @@ export async function createOrgOidcIdp(
let { autoProvision } = parsedBody.data;
const subscribed = await isSubscribed(
orgId,
tierMatrix.deviceApprovals
);
if (!subscribed) {
autoProvision = false;
if (build == "saas") { // this is not paywalled with a ee license because this whole endpoint is restricted
const subscribed = await isSubscribed(
orgId,
tierMatrix.deviceApprovals
);
if (!subscribed) {
autoProvision = false;
}
}
const key = config.getRawConfig().server.secret!;

View File

@@ -27,6 +27,7 @@ import config from "@server/lib/config";
import { isSubscribed } from "#private/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import privateConfig from "#private/lib/config";
import { build } from "@server/build";
const paramsSchema = z
.object({
@@ -127,12 +128,15 @@ export async function updateOrgOidcIdp(
let { autoProvision } = parsedBody.data;
const subscribed = await isSubscribed(
orgId,
tierMatrix.deviceApprovals
);
if (!subscribed) {
autoProvision = false;
if (build == "saas") {
// this is not paywalled with a ee license because this whole endpoint is restricted
const subscribed = await isSubscribed(
orgId,
tierMatrix.deviceApprovals
);
if (!subscribed) {
autoProvision = false;
}
}
// Check if IDP exists and is of type OIDC

View File

@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
orgs
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -25,7 +32,7 @@ import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNo
import { fromError } from "zod-validation-error";
import { hashPassword, verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray, ne } from "drizzle-orm";
import { getNextAvailableSubnet } from "@server/lib/exitNodes";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
@@ -169,7 +176,17 @@ export async function createRemoteExitNode(
);
}
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => {
if (!existingExitNode) {
@@ -217,19 +234,43 @@ export async function createRemoteExitNode(
});
}
numExitNodeOrgs = await trx
.select()
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId));
});
// calculate if the node is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, orgId)
)
);
if (numExitNodeOrgs) {
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length
);
}
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select()
.from(exitNodeOrgs)
.where(
and(
eq(
exitNodeOrgs.exitNodeId,
existingExitNode.exitNodeId
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.add(
orgId,
FeatureId.REMOTE_EXIT_NODES,
1,
trx
);
}
}
});
const token = generateSessionToken();
await createRemoteExitNodeSession(token, remoteExitNodeId);

View File

@@ -13,9 +13,9 @@
import { NextFunction, Request, Response } from "express";
import { z } from "zod";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes, orgs } from "@server/db";
import { remoteExitNodes } from "@server/db";
import { and, count, eq } from "drizzle-orm";
import { and, count, eq, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -50,7 +50,8 @@ export async function deleteRemoteExitNode(
const [remoteExitNode] = await db
.select()
.from(remoteExitNodes)
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId))
.limit(1);
if (!remoteExitNode) {
return next(
@@ -70,7 +71,17 @@ export async function deleteRemoteExitNode(
);
}
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Org with ID ${orgId} not found`
)
);
}
await db.transaction(async (trx) => {
await trx
.delete(exitNodeOrgs)
@@ -81,38 +92,39 @@ export async function deleteRemoteExitNode(
)
);
const [remainingExitNodeOrgs] = await trx
.select({ count: count() })
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!));
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (remainingExitNodeOrgs.count === 0) {
await trx
.delete(remoteExitNodes)
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select()
.from(exitNodeOrgs)
.where(
eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)
and(
eq(
exitNodeOrgs.exitNodeId,
remoteExitNode.exitNodeId!
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
await trx
.delete(exitNodes)
.where(
eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!)
if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.add(
orgId,
FeatureId.REMOTE_EXIT_NODES,
-1,
trx
);
}
}
numExitNodeOrgs = await trx
.select()
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId));
});
if (numExitNodeOrgs) {
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length
);
}
return response(res, {
data: null,
success: true,

View File

@@ -0,0 +1,14 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
export * from "./signSshKey";

View File

@@ -0,0 +1,476 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import {
db,
newts,
roles,
roundTripMessageTracker,
siteResources,
sites,
userOrgs
} from "@server/db";
import { isLicensedOrSubscribed } from "#private/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { eq, or, and } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
import config from "@server/lib/config";
import { sendToClient } from "#private/routers/ws";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
});
const bodySchema = z
.strictObject({
publicKey: z.string().nonempty(),
resourceId: z.number().int().positive().optional(),
resource: z.string().nonempty().optional() // this is either the nice id or the alias
})
.refine(
(data) => {
const fields = [data.resourceId, data.resource];
const definedFields = fields.filter((field) => field !== undefined);
return definedFields.length === 1;
},
{
message:
"Exactly one of resourceId, niceId, or alias must be provided"
}
);
export type SignSshKeyResponse = {
certificate: string;
messageId: number;
sshUsername: string;
sshHost: string;
resourceId: number;
keyId: string;
validPrincipals: string[];
validAfter: string;
validBefore: string;
expiresIn: number;
};
// registry.registerPath({
// method: "post",
// path: "/org/{orgId}/ssh/sign-key",
// description: "Sign an SSH public key for access to a resource.",
// tags: [OpenAPITags.Org, OpenAPITags.Ssh],
// request: {
// params: paramsSchema,
// body: {
// content: {
// "application/json": {
// schema: bodySchema
// }
// }
// }
// },
// responses: {}
// });
export async function signSshKey(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const {
publicKey,
resourceId,
resource: resourceQueryString
} = parsedBody.data;
const userId = req.user?.userId;
const roleId = req.userOrgRoleId!;
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
const [userOrg] = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not belong to the specified organization"
)
);
}
const isLicensed = await isLicensedOrSubscribed(
orgId,
tierMatrix.sshPam
);
if (!isLicensed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"SSH key signing requires a paid plan"
)
);
}
let usernameToUse;
if (!userOrg.pamUsername) {
if (req.user?.email) {
// Extract username from email (first part before @)
usernameToUse = req.user?.email
.split("@")[0]
.replace(/[^a-zA-Z0-9_-]/g, "");
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Unable to extract username from email"
)
);
}
} else if (req.user?.username) {
usernameToUse = req.user.username;
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "-");
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Username is not valid for SSH certificate"
)
);
}
} else {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User does not have a valid email or username for SSH certificate"
)
);
}
// prefix with p-
usernameToUse = `p-${usernameToUse}`;
// check if we have a existing user in this org with the same
const [existingUserWithSameName] = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.pamUsername, usernameToUse)
)
)
.limit(1);
if (existingUserWithSameName) {
let foundUniqueUsername = false;
for (let attempt = 0; attempt < 20; attempt++) {
const randomNum = Math.floor(Math.random() * 101); // 0 to 100
const candidateUsername = `${usernameToUse}${randomNum}`;
const [existingUser] = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.pamUsername, candidateUsername)
)
)
.limit(1);
if (!existingUser) {
usernameToUse = candidateUsername;
foundUniqueUsername = true;
break;
}
}
if (!foundUniqueUsername) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Unable to generate a unique username for SSH certificate"
)
);
}
}
await db
.update(userOrgs)
.set({ pamUsername: usernameToUse })
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, userId)
)
);
} else {
usernameToUse = userOrg.pamUsername;
}
// Get and decrypt the org's CA keys
const caKeys = await getOrgCAKeys(
orgId,
config.getRawConfig().server.secret!
);
if (!caKeys) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"SSH CA not configured for this organization"
)
);
}
// Verify the resource exists and belongs to the org
// Build the where clause dynamically based on which field is provided
let whereClause;
if (resourceId !== undefined) {
whereClause = eq(siteResources.siteResourceId, resourceId);
} else if (resourceQueryString !== undefined) {
whereClause = or(
eq(siteResources.niceId, resourceQueryString),
eq(siteResources.alias, resourceQueryString)
);
} else {
// This should never happen due to the schema validation, but TypeScript doesn't know that
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"One of resourceId, niceId, or alias must be provided"
)
);
}
const resources = await db
.select()
.from(siteResources)
.where(and(whereClause, eq(siteResources.orgId, orgId)));
if (!resources || resources.length === 0) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Resource not found`)
);
}
if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Multiple resources found matching the criteria`
)
);
}
const resource = resources[0];
if (resource.orgId !== orgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Resource does not belong to the specified organization"
)
);
}
if (resource.mode == "cidr") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"SSHing is not supported for CIDR resources"
)
);
}
// Check if the user has access to the resource
const hasAccess = await canUserAccessSiteResource({
userId: userId,
resourceId: resource.siteResourceId,
roleId: roleId
});
if (!hasAccess) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this resource"
)
);
}
const [roleRow] = await db
.select()
.from(roles)
.where(eq(roles.roleId, roleId))
.limit(1);
let parsedSudoCommands: string[] = [];
let parsedGroups: string[] = [];
try {
parsedSudoCommands = JSON.parse(roleRow?.sshSudoCommands ?? "[]");
if (!Array.isArray(parsedSudoCommands)) parsedSudoCommands = [];
} catch {
parsedSudoCommands = [];
}
try {
parsedGroups = JSON.parse(roleRow?.sshUnixGroups ?? "[]");
if (!Array.isArray(parsedGroups)) parsedGroups = [];
} catch {
parsedGroups = [];
}
const homedir = roleRow?.sshCreateHomeDir ?? null;
const sudoMode = roleRow?.sshSudoMode ?? "none";
// get the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, resource.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found"
)
);
}
// Sign the public key
const now = BigInt(Math.floor(Date.now() / 1000));
// only valid for 5 minutes
const validFor = 300n;
const cert = signPublicKey(caKeys.privateKeyPem, publicKey, {
keyId: `${usernameToUse}@${resource.niceId}`,
validPrincipals: [usernameToUse, resource.niceId],
validAfter: now - 60n, // Start 1 min ago for clock skew
validBefore: now + validFor
});
const [message] = await db
.insert(roundTripMessageTracker)
.values({
wsClientId: newt.newtId,
messageType: `newt/pam/connection`,
sentAt: Math.floor(Date.now() / 1000)
})
.returning();
if (!message) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create message tracker entry"
)
);
}
await sendToClient(newt.newtId, {
type: `newt/pam/connection`,
data: {
messageId: message.messageId,
orgId: orgId,
agentPort: resource.authDaemonPort ?? 22123,
externalAuthDaemon: resource.authDaemonMode === "remote",
agentHost: resource.destination,
caCert: caKeys.publicKeyOpenSSH,
username: usernameToUse,
niceId: resource.niceId,
metadata: {
sudoMode: sudoMode,
sudoCommands: parsedSudoCommands,
homedir: homedir,
groups: parsedGroups
}
}
});
const expiresIn = Number(validFor); // seconds
let sshHost;
if (resource.alias && resource.alias != "") {
sshHost = resource.alias;
} else {
sshHost = resource.destination;
}
return response<SignSshKeyResponse>(res, {
data: {
certificate: cert.certificate,
messageId: message.messageId,
sshUsername: usernameToUse,
sshHost: sshHost,
resourceId: resource.siteResourceId,
keyId: cert.keyId,
validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(),
validBefore: cert.validBefore.toISOString(),
expiresIn
},
success: true,
error: false,
message: "SSH key signed successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error signing SSH key:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred while signing the SSH key"
)
);
}
}