mirror of
https://github.com/fosrl/pangolin.git
synced 2026-05-17 14:34:42 +00:00
Merge branch 'dev' into feat/resource-policies
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { db, requestAuditLog, driver, primaryDb } from "@server/db";
|
||||
import { logsDb, requestAuditLog, driver, primaryLogsDb } from "@server/db";
|
||||
import { registry } from "@server/openApi";
|
||||
import { NextFunction } from "express";
|
||||
import { Request, Response } from "express";
|
||||
@@ -74,12 +74,12 @@ async function query(query: Q) {
|
||||
);
|
||||
}
|
||||
|
||||
const [all] = await primaryDb
|
||||
const [all] = await primaryLogsDb
|
||||
.select({ total: count() })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions);
|
||||
|
||||
const [blocked] = await primaryDb
|
||||
const [blocked] = await primaryLogsDb
|
||||
.select({ total: count() })
|
||||
.from(requestAuditLog)
|
||||
.where(and(baseConditions, eq(requestAuditLog.action, false)));
|
||||
@@ -90,7 +90,7 @@ async function query(query: Q) {
|
||||
|
||||
const DISTINCT_LIMIT = 500;
|
||||
|
||||
const requestsPerCountry = await primaryDb
|
||||
const requestsPerCountry = await primaryLogsDb
|
||||
.selectDistinct({
|
||||
code: requestAuditLog.location,
|
||||
count: totalQ
|
||||
@@ -118,7 +118,7 @@ async function query(query: Q) {
|
||||
const booleanTrue = driver === "pg" ? sql`true` : sql`1`;
|
||||
const booleanFalse = driver === "pg" ? sql`false` : sql`0`;
|
||||
|
||||
const requestsPerDay = await primaryDb
|
||||
const requestsPerDay = await primaryLogsDb
|
||||
.select({
|
||||
day: groupByDayFunction.as("day"),
|
||||
allowedCount:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { db, primaryDb, requestAuditLog, resources } from "@server/db";
|
||||
import { logsDb, primaryLogsDb, requestAuditLog, resources, db, primaryDb } from "@server/db";
|
||||
import { registry } from "@server/openApi";
|
||||
import { NextFunction } from "express";
|
||||
import { Request, Response } from "express";
|
||||
import { eq, gt, lt, and, count, desc } from "drizzle-orm";
|
||||
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
|
||||
import { OpenAPITags } from "@server/openApi";
|
||||
import { z } from "zod";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -107,7 +107,7 @@ function getWhere(data: Q) {
|
||||
}
|
||||
|
||||
export function queryRequest(data: Q) {
|
||||
return primaryDb
|
||||
return primaryLogsDb
|
||||
.select({
|
||||
id: requestAuditLog.id,
|
||||
timestamp: requestAuditLog.timestamp,
|
||||
@@ -129,21 +129,49 @@ export function queryRequest(data: Q) {
|
||||
host: requestAuditLog.host,
|
||||
path: requestAuditLog.path,
|
||||
method: requestAuditLog.method,
|
||||
tls: requestAuditLog.tls,
|
||||
resourceName: resources.name,
|
||||
resourceNiceId: resources.niceId
|
||||
tls: requestAuditLog.tls
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.leftJoin(
|
||||
resources,
|
||||
eq(requestAuditLog.resourceId, resources.resourceId)
|
||||
) // TODO: Is this efficient?
|
||||
.where(getWhere(data))
|
||||
.orderBy(desc(requestAuditLog.timestamp));
|
||||
}
|
||||
|
||||
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryRequest>>) {
|
||||
// If logs database is the same as main database, we can do a join
|
||||
// Otherwise, we need to fetch resource details separately
|
||||
const resourceIds = logs
|
||||
.map(log => log.resourceId)
|
||||
.filter((id): id is number => id !== null && id !== undefined);
|
||||
|
||||
if (resourceIds.length === 0) {
|
||||
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
|
||||
}
|
||||
|
||||
// Fetch resource details from main database
|
||||
const resourceDetails = await primaryDb
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name,
|
||||
niceId: resources.niceId
|
||||
})
|
||||
.from(resources)
|
||||
.where(inArray(resources.resourceId, resourceIds));
|
||||
|
||||
// Create a map for quick lookup
|
||||
const resourceMap = new Map(
|
||||
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
|
||||
);
|
||||
|
||||
// Enrich logs with resource details
|
||||
return logs.map(log => ({
|
||||
...log,
|
||||
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
|
||||
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
|
||||
}));
|
||||
}
|
||||
|
||||
export function countRequestQuery(data: Q) {
|
||||
const countQuery = primaryDb
|
||||
const countQuery = primaryLogsDb
|
||||
.select({ count: count() })
|
||||
.from(requestAuditLog)
|
||||
.where(getWhere(data));
|
||||
@@ -185,36 +213,31 @@ async function queryUniqueFilterAttributes(
|
||||
uniquePaths,
|
||||
uniqueResources
|
||||
] = await Promise.all([
|
||||
primaryDb
|
||||
primaryLogsDb
|
||||
.selectDistinct({ actor: requestAuditLog.actor })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryDb
|
||||
primaryLogsDb
|
||||
.selectDistinct({ locations: requestAuditLog.location })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryDb
|
||||
primaryLogsDb
|
||||
.selectDistinct({ hosts: requestAuditLog.host })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryDb
|
||||
primaryLogsDb
|
||||
.selectDistinct({ paths: requestAuditLog.path })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryDb
|
||||
primaryLogsDb
|
||||
.selectDistinct({
|
||||
id: requestAuditLog.resourceId,
|
||||
name: resources.name
|
||||
id: requestAuditLog.resourceId
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.leftJoin(
|
||||
resources,
|
||||
eq(requestAuditLog.resourceId, resources.resourceId)
|
||||
)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1)
|
||||
]);
|
||||
@@ -231,13 +254,33 @@ async function queryUniqueFilterAttributes(
|
||||
// throw new Error("Too many distinct filter attributes to retrieve. Please refine your time range.");
|
||||
// }
|
||||
|
||||
// Fetch resource names from main database for the unique resource IDs
|
||||
const resourceIds = uniqueResources
|
||||
.map(row => row.id)
|
||||
.filter((id): id is number => id !== null);
|
||||
|
||||
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
|
||||
|
||||
if (resourceIds.length > 0) {
|
||||
const resourceDetails = await primaryDb
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name
|
||||
})
|
||||
.from(resources)
|
||||
.where(inArray(resources.resourceId, resourceIds));
|
||||
|
||||
resourcesWithNames = resourceDetails.map(r => ({
|
||||
id: r.resourceId,
|
||||
name: r.name
|
||||
}));
|
||||
}
|
||||
|
||||
return {
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter(
|
||||
(row): row is { id: number; name: string | null } => row.id !== null
|
||||
),
|
||||
resources: resourcesWithNames,
|
||||
locations: uniqueLocations
|
||||
.map((row) => row.locations)
|
||||
.filter((location): location is string => location !== null),
|
||||
@@ -280,7 +323,10 @@ export async function queryRequestAuditLogs(
|
||||
|
||||
const baseQuery = queryRequest(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
const logsRaw = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
// Enrich with resource details (handles cross-database scenario)
|
||||
const log = await enrichWithResourceDetails(logsRaw);
|
||||
|
||||
const totalCountResult = await countRequestQuery(data);
|
||||
const totalCount = totalCountResult[0].count;
|
||||
|
||||
242
server/routers/auth/deleteMyAccount.ts
Normal file
242
server/routers/auth/deleteMyAccount.ts
Normal file
@@ -0,0 +1,242 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, orgs, userOrgs, users } from "@server/db";
|
||||
import { eq, and, inArray } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { verifySession } from "@server/auth/sessions/verifySession";
|
||||
import {
|
||||
invalidateSession,
|
||||
createBlankSessionTokenCookie
|
||||
} from "@server/auth/sessions/app";
|
||||
import { verifyPassword } from "@server/auth/password";
|
||||
import { verifyTotpCode } from "@server/auth/totp";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#dynamic/lib/billing";
|
||||
import {
|
||||
deleteOrgById,
|
||||
sendTerminationMessages
|
||||
} from "@server/lib/deleteOrg";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
|
||||
const deleteMyAccountBody = z.strictObject({
|
||||
password: z.string().optional(),
|
||||
code: z.string().optional()
|
||||
});
|
||||
|
||||
export type DeleteMyAccountPreviewResponse = {
|
||||
preview: true;
|
||||
orgs: { orgId: string; name: string }[];
|
||||
twoFactorEnabled: boolean;
|
||||
};
|
||||
|
||||
export type DeleteMyAccountCodeRequestedResponse = {
|
||||
codeRequested: true;
|
||||
};
|
||||
|
||||
export type DeleteMyAccountSuccessResponse = {
|
||||
success: true;
|
||||
};
|
||||
|
||||
export async function deleteMyAccount(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const { user, session } = await verifySession(req);
|
||||
if (!user || !session) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Not authenticated")
|
||||
);
|
||||
}
|
||||
|
||||
if (user.serverAdmin) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Server admins cannot delete their account this way"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (user.type !== UserType.Internal) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Account deletion with password is only supported for internal users"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsed = deleteMyAccountBody.safeParse(req.body ?? {});
|
||||
if (!parsed.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsed.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
const { password, code } = parsed.data;
|
||||
|
||||
const userId = user.userId;
|
||||
|
||||
const ownedOrgsRows = await db
|
||||
.select({
|
||||
orgId: userOrgs.orgId,
|
||||
isOwner: userOrgs.isOwner,
|
||||
isBillingOrg: orgs.isBillingOrg
|
||||
})
|
||||
.from(userOrgs)
|
||||
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
|
||||
.where(
|
||||
and(eq(userOrgs.userId, userId), eq(userOrgs.isOwner, true))
|
||||
);
|
||||
|
||||
const orgIds = ownedOrgsRows.map((r) => r.orgId);
|
||||
|
||||
if (build === "saas" && orgIds.length > 0) {
|
||||
const primaryOrgId = ownedOrgsRows.find(
|
||||
(r) => r.isBillingOrg && r.isOwner
|
||||
)?.orgId;
|
||||
if (primaryOrgId) {
|
||||
const { tier, active } = await getOrgTierData(primaryOrgId);
|
||||
if (active && tier) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"You must cancel your subscription before deleting your account"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!password) {
|
||||
const orgsWithNames =
|
||||
orgIds.length > 0
|
||||
? await db
|
||||
.select({
|
||||
orgId: orgs.orgId,
|
||||
name: orgs.name
|
||||
})
|
||||
.from(orgs)
|
||||
.where(inArray(orgs.orgId, orgIds))
|
||||
: [];
|
||||
return response<DeleteMyAccountPreviewResponse>(res, {
|
||||
data: {
|
||||
preview: true,
|
||||
orgs: orgsWithNames.map((o) => ({
|
||||
orgId: o.orgId,
|
||||
name: o.name ?? ""
|
||||
})),
|
||||
twoFactorEnabled: user.twoFactorEnabled ?? false
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Preview",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
const validPassword = await verifyPassword(
|
||||
password,
|
||||
user.passwordHash!
|
||||
);
|
||||
if (!validPassword) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Invalid password")
|
||||
);
|
||||
}
|
||||
|
||||
if (user.twoFactorEnabled) {
|
||||
if (!code) {
|
||||
return response<DeleteMyAccountCodeRequestedResponse>(res, {
|
||||
data: { codeRequested: true },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Two-factor code required",
|
||||
status: HttpCode.ACCEPTED
|
||||
});
|
||||
}
|
||||
const validOTP = await verifyTotpCode(
|
||||
code,
|
||||
user.twoFactorSecret!,
|
||||
user.userId
|
||||
);
|
||||
if (!validOTP) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"The two-factor code you entered is incorrect"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const allDeletedNewtIds: string[] = [];
|
||||
const allOlmsToTerminate: string[] = [];
|
||||
|
||||
for (const row of ownedOrgsRows) {
|
||||
try {
|
||||
const result = await deleteOrgById(row.orgId);
|
||||
allDeletedNewtIds.push(...result.deletedNewtIds);
|
||||
allOlmsToTerminate.push(...result.olmsToTerminate);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`Failed to delete org ${row.orgId} during account deletion`,
|
||||
err
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to delete organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
sendTerminationMessages({
|
||||
deletedNewtIds: allDeletedNewtIds,
|
||||
olmsToTerminate: allOlmsToTerminate
|
||||
});
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx.delete(users).where(eq(users.userId, userId));
|
||||
await calculateUserClientsForOrgs(userId, trx);
|
||||
});
|
||||
|
||||
try {
|
||||
await invalidateSession(session.sessionId);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Failed to invalidate session after account deletion",
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
const isSecure = req.protocol === "https";
|
||||
res.setHeader("Set-Cookie", createBlankSessionTokenCookie(isSecure));
|
||||
|
||||
return response<DeleteMyAccountSuccessResponse>(res, {
|
||||
data: { success: true },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Account deleted successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"An error occurred"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -17,4 +17,5 @@ export * from "./securityKey";
|
||||
export * from "./startDeviceWebAuth";
|
||||
export * from "./verifyDeviceWebAuth";
|
||||
export * from "./pollDeviceWebAuth";
|
||||
export * from "./lookupUser";
|
||||
export * from "./lookupUser";
|
||||
export * from "./deleteMyAccount";
|
||||
@@ -1,7 +1,7 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, users } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { email, z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
@@ -21,7 +21,6 @@ import { hashPassword } from "@server/auth/password";
|
||||
import { checkValidInvite } from "@server/auth/checkValidInvite";
|
||||
import { passwordSchema } from "@server/auth/passwordSchema";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
|
||||
import { build } from "@server/build";
|
||||
import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend";
|
||||
|
||||
@@ -31,7 +30,8 @@ export const signupBodySchema = z.object({
|
||||
inviteToken: z.string().optional(),
|
||||
inviteId: z.string().optional(),
|
||||
termsAcceptedTimestamp: z.string().nullable().optional(),
|
||||
marketingEmailConsent: z.boolean().optional()
|
||||
marketingEmailConsent: z.boolean().optional(),
|
||||
skipVerificationEmail: z.boolean().optional()
|
||||
});
|
||||
|
||||
export type SignUpBody = z.infer<typeof signupBodySchema>;
|
||||
@@ -62,7 +62,8 @@ export async function signup(
|
||||
inviteToken,
|
||||
inviteId,
|
||||
termsAcceptedTimestamp,
|
||||
marketingEmailConsent
|
||||
marketingEmailConsent,
|
||||
skipVerificationEmail
|
||||
} = parsedBody.data;
|
||||
|
||||
const passwordHash = await hashPassword(password);
|
||||
@@ -198,26 +199,6 @@ export async function signup(
|
||||
// orgId: null,
|
||||
// });
|
||||
|
||||
if (build == "saas") {
|
||||
const { success, error, org } = await createUserAccountOrg(
|
||||
userId,
|
||||
email
|
||||
);
|
||||
if (!success) {
|
||||
if (error) {
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create user account and organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const token = generateSessionToken();
|
||||
const sess = await createSession(token, userId);
|
||||
const isSecure = req.protocol === "https";
|
||||
@@ -235,7 +216,13 @@ export async function signup(
|
||||
}
|
||||
|
||||
if (config.getRawConfig().flags?.require_email_verification) {
|
||||
sendEmailVerificationCode(email, userId);
|
||||
if (!skipVerificationEmail) {
|
||||
sendEmailVerificationCode(email, userId);
|
||||
} else {
|
||||
logger.debug(
|
||||
`User ${email} opted out of verification email during signup.`
|
||||
);
|
||||
}
|
||||
|
||||
return response<SignUpResponse>(res, {
|
||||
data: {
|
||||
@@ -243,7 +230,9 @@ export async function signup(
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: `User created successfully. We sent an email to ${email} with a verification code.`,
|
||||
message: skipVerificationEmail
|
||||
? "User created successfully. Please verify your email."
|
||||
: `User created successfully. We sent an email to ${email} with a verification code.`,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { db, orgs, requestAuditLog } from "@server/db";
|
||||
import { logsDb, primaryLogsDb, db, orgs, requestAuditLog } from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq, lt, sql } from "drizzle-orm";
|
||||
import cache from "@server/lib/cache";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
|
||||
import { stripPortFromHost } from "@server/lib/ip";
|
||||
|
||||
@@ -69,7 +69,7 @@ async function flushAuditLogs() {
|
||||
try {
|
||||
// Use a transaction to ensure all inserts succeed or fail together
|
||||
// This prevents index corruption from partial writes
|
||||
await db.transaction(async (tx) => {
|
||||
await logsDb.transaction(async (tx) => {
|
||||
// Batch insert logs in groups of 25 to avoid overwhelming the database
|
||||
const BATCH_DB_SIZE = 25;
|
||||
for (let i = 0; i < logsToWrite.length; i += BATCH_DB_SIZE) {
|
||||
@@ -130,7 +130,7 @@ export async function shutdownAuditLogger() {
|
||||
|
||||
async function getRetentionDays(orgId: string): Promise<number> {
|
||||
// check cache first
|
||||
const cached = cache.get<number>(`org_${orgId}_retentionDays`);
|
||||
const cached = await cache.get<number>(`org_${orgId}_retentionDays`);
|
||||
if (cached !== undefined) {
|
||||
return cached;
|
||||
}
|
||||
@@ -149,7 +149,7 @@ async function getRetentionDays(orgId: string): Promise<number> {
|
||||
}
|
||||
|
||||
// store the result in cache
|
||||
cache.set(
|
||||
await cache.set(
|
||||
`org_${orgId}_retentionDays`,
|
||||
org.settingsLogRetentionDaysRequest,
|
||||
300
|
||||
@@ -162,7 +162,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
|
||||
const cutoffTimestamp = calculateCutoffTimestamp(retentionDays);
|
||||
|
||||
try {
|
||||
await db
|
||||
await logsDb
|
||||
.delete(requestAuditLog)
|
||||
.where(
|
||||
and(
|
||||
|
||||
@@ -37,7 +37,7 @@ import {
|
||||
enforceResourceSessionLength
|
||||
} from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { logRequestAudit } from "./logRequestAudit";
|
||||
import cache from "@server/lib/cache";
|
||||
import { localCache } from "#dynamic/lib/cache";
|
||||
import { APP_VERSION } from "@server/lib/consts";
|
||||
import { isSubscribed } from "#dynamic/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
@@ -137,7 +137,7 @@ export async function verifyResourceSession(
|
||||
headerAuthExtendedCompatibility: ResourceHeaderAuthExtendedCompatibility | null;
|
||||
org: Org;
|
||||
}
|
||||
| undefined = cache.get(resourceCacheKey);
|
||||
| undefined = localCache.get(resourceCacheKey);
|
||||
|
||||
if (!resourceData) {
|
||||
const result = await getResourceByDomain(cleanHost);
|
||||
@@ -161,7 +161,7 @@ export async function verifyResourceSession(
|
||||
}
|
||||
|
||||
resourceData = result;
|
||||
cache.set(resourceCacheKey, resourceData, 5);
|
||||
localCache.set(resourceCacheKey, resourceData, 5);
|
||||
}
|
||||
|
||||
const {
|
||||
@@ -405,7 +405,7 @@ export async function verifyResourceSession(
|
||||
// check for HTTP Basic Auth header
|
||||
const clientHeaderAuthKey = `headerAuth:${clientHeaderAuth}`;
|
||||
if (headerAuth && clientHeaderAuth) {
|
||||
if (cache.get(clientHeaderAuthKey)) {
|
||||
if (localCache.get(clientHeaderAuthKey)) {
|
||||
logger.debug(
|
||||
"Resource allowed because header auth is valid (cached)"
|
||||
);
|
||||
@@ -428,7 +428,7 @@ export async function verifyResourceSession(
|
||||
headerAuth.headerAuthHash
|
||||
)
|
||||
) {
|
||||
cache.set(clientHeaderAuthKey, clientHeaderAuth, 5);
|
||||
localCache.set(clientHeaderAuthKey, clientHeaderAuth, 5);
|
||||
logger.debug("Resource allowed because header auth is valid");
|
||||
|
||||
logRequestAudit(
|
||||
@@ -520,7 +520,7 @@ export async function verifyResourceSession(
|
||||
|
||||
if (resourceSessionToken) {
|
||||
const sessionCacheKey = `session:${resourceSessionToken}`;
|
||||
let resourceSession: any = cache.get(sessionCacheKey);
|
||||
let resourceSession: any = localCache.get(sessionCacheKey);
|
||||
|
||||
if (!resourceSession) {
|
||||
const result = await validateResourceSessionToken(
|
||||
@@ -529,7 +529,7 @@ export async function verifyResourceSession(
|
||||
);
|
||||
|
||||
resourceSession = result?.resourceSession;
|
||||
cache.set(sessionCacheKey, resourceSession, 5);
|
||||
localCache.set(sessionCacheKey, resourceSession, 5);
|
||||
}
|
||||
|
||||
if (resourceSession?.isRequestToken) {
|
||||
@@ -662,7 +662,7 @@ export async function verifyResourceSession(
|
||||
}:${resource.resourceId}`;
|
||||
|
||||
let allowedUserData: BasicUserData | null | undefined =
|
||||
cache.get(userAccessCacheKey);
|
||||
localCache.get(userAccessCacheKey);
|
||||
|
||||
if (allowedUserData === undefined) {
|
||||
allowedUserData = await isUserAllowedToAccessResource(
|
||||
@@ -671,7 +671,7 @@ export async function verifyResourceSession(
|
||||
resourceData.org
|
||||
);
|
||||
|
||||
cache.set(userAccessCacheKey, allowedUserData, 5);
|
||||
localCache.set(userAccessCacheKey, allowedUserData, 5);
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -797,7 +797,7 @@ async function notAllowed(
|
||||
) {
|
||||
let loginPage: LoginPage | null = null;
|
||||
if (orgId) {
|
||||
const subscribed = await isSubscribed(
|
||||
const subscribed = await isSubscribed( // this is fine because the org login page is only a saas feature
|
||||
orgId,
|
||||
tierMatrix.loginPageDomain
|
||||
);
|
||||
@@ -854,7 +854,7 @@ async function headerAuthChallenged(
|
||||
) {
|
||||
let loginPage: LoginPage | null = null;
|
||||
if (orgId) {
|
||||
const subscribed = await isSubscribed(orgId, tierMatrix.loginPageDomain);
|
||||
const subscribed = await isSubscribed(orgId, tierMatrix.loginPageDomain); // this is fine because the org login page is only a saas feature
|
||||
if (subscribed) {
|
||||
loginPage = await getOrgLoginPage(orgId);
|
||||
}
|
||||
@@ -974,11 +974,11 @@ async function checkRules(
|
||||
): Promise<"ACCEPT" | "DROP" | "PASS" | undefined> {
|
||||
const ruleCacheKey = `rules:${resourceId}`;
|
||||
|
||||
let rules: ResourceRule[] | undefined = cache.get(ruleCacheKey);
|
||||
let rules: ResourceRule[] | undefined = localCache.get(ruleCacheKey);
|
||||
|
||||
if (!rules) {
|
||||
rules = await getResourceRules(resourceId);
|
||||
cache.set(ruleCacheKey, rules, 5);
|
||||
localCache.set(ruleCacheKey, rules, 5);
|
||||
}
|
||||
|
||||
if (rules.length === 0) {
|
||||
@@ -1208,13 +1208,13 @@ async function isIpInAsn(
|
||||
async function getAsnFromIp(ip: string): Promise<number | undefined> {
|
||||
const asnCacheKey = `asn:${ip}`;
|
||||
|
||||
let cachedAsn: number | undefined = cache.get(asnCacheKey);
|
||||
let cachedAsn: number | undefined = localCache.get(asnCacheKey);
|
||||
|
||||
if (!cachedAsn) {
|
||||
cachedAsn = await getAsnForIp(ip); // do it locally
|
||||
// Cache for longer since IP ASN doesn't change frequently
|
||||
if (cachedAsn) {
|
||||
cache.set(asnCacheKey, cachedAsn, 300); // 5 minutes
|
||||
localCache.set(asnCacheKey, cachedAsn, 300); // 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1224,14 +1224,14 @@ async function getAsnFromIp(ip: string): Promise<number | undefined> {
|
||||
async function getCountryCodeFromIp(ip: string): Promise<string | undefined> {
|
||||
const geoIpCacheKey = `geoip:${ip}`;
|
||||
|
||||
let cachedCountryCode: string | undefined = cache.get(geoIpCacheKey);
|
||||
let cachedCountryCode: string | undefined = localCache.get(geoIpCacheKey);
|
||||
|
||||
if (!cachedCountryCode) {
|
||||
cachedCountryCode = await getCountryCodeForIp(ip); // do it locally
|
||||
// Only cache successful lookups to avoid filling cache with undefined values
|
||||
if (cachedCountryCode) {
|
||||
// Cache for longer since IP geolocation doesn't change frequently
|
||||
cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
|
||||
localCache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -99,25 +99,54 @@ const listClientsSchema = z.object({
|
||||
.positive()
|
||||
.optional()
|
||||
.catch(20)
|
||||
.default(20),
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.catch(1)
|
||||
.default(1),
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
sort_by: z
|
||||
.enum(["megabytesIn", "megabytesOut"])
|
||||
.enum(["name", "megabytesIn", "megabytesOut"])
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
order: z.enum(["asc", "desc"]).optional().default("asc").catch("asc"),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["name", "megabytesIn", "megabytesOut"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
}),
|
||||
online: z
|
||||
.enum(["true", "false"])
|
||||
.transform((v) => v === "true")
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "boolean",
|
||||
description: "Filter by online status"
|
||||
}),
|
||||
status: z.preprocess(
|
||||
(val: string | undefined) => {
|
||||
if (val) {
|
||||
@@ -130,6 +159,16 @@ const listClientsSchema = z.object({
|
||||
.optional()
|
||||
.default(["active"])
|
||||
.catch(["active"])
|
||||
.openapi({
|
||||
type: "array",
|
||||
items: {
|
||||
type: "string",
|
||||
enum: ["active", "blocked", "archived"]
|
||||
},
|
||||
default: ["active"],
|
||||
description:
|
||||
"Filter by client status. Can be a comma-separated list of values. Defaults to 'active'."
|
||||
})
|
||||
)
|
||||
});
|
||||
|
||||
@@ -324,7 +363,7 @@ export async function listClients(
|
||||
const countQuery = db.$count(baseQuery.as("filtered_clients"));
|
||||
|
||||
const listMachinesQuery = baseQuery
|
||||
.limit(page)
|
||||
.limit(pageSize)
|
||||
.offset(pageSize * (page - 1))
|
||||
.orderBy(
|
||||
sort_by
|
||||
|
||||
@@ -100,25 +100,54 @@ const listUserDevicesSchema = z.object({
|
||||
.positive()
|
||||
.optional()
|
||||
.catch(20)
|
||||
.default(20),
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.catch(1)
|
||||
.default(1),
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
sort_by: z
|
||||
.enum(["megabytesIn", "megabytesOut"])
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
order: z.enum(["asc", "desc"]).optional().default("asc").catch("asc"),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["megabytesIn", "megabytesOut"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
}),
|
||||
online: z
|
||||
.enum(["true", "false"])
|
||||
.transform((v) => v === "true")
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "boolean",
|
||||
description: "Filter by online status"
|
||||
}),
|
||||
agent: z
|
||||
.enum([
|
||||
"windows",
|
||||
@@ -131,7 +160,22 @@ const listUserDevicesSchema = z.object({
|
||||
"unknown"
|
||||
])
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: [
|
||||
"windows",
|
||||
"android",
|
||||
"cli",
|
||||
"olm",
|
||||
"macos",
|
||||
"ios",
|
||||
"ipados",
|
||||
"unknown"
|
||||
],
|
||||
description:
|
||||
"Filter by agent type. Use 'unknown' to filter clients with no agent detected."
|
||||
}),
|
||||
status: z.preprocess(
|
||||
(val: string | undefined) => {
|
||||
if (val) {
|
||||
@@ -146,6 +190,16 @@ const listUserDevicesSchema = z.object({
|
||||
.optional()
|
||||
.default(["active", "pending"])
|
||||
.catch(["active", "pending"])
|
||||
.openapi({
|
||||
type: "array",
|
||||
items: {
|
||||
type: "string",
|
||||
enum: ["active", "pending", "denied", "blocked", "archived"]
|
||||
},
|
||||
default: ["active", "pending"],
|
||||
description:
|
||||
"Filter by device status. Can include multiple values separated by commas. 'active' means not archived, not blocked, and if approval is enabled, approved. 'pending' and 'denied' are only applicable if approval is enabled."
|
||||
})
|
||||
)
|
||||
});
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { eq, and, ne } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
@@ -93,7 +93,8 @@ export async function updateClient(
|
||||
.where(
|
||||
and(
|
||||
eq(clients.niceId, niceId),
|
||||
eq(clients.orgId, clients.orgId)
|
||||
eq(clients.orgId, clients.orgId),
|
||||
ne(clients.clientId, clientId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
@@ -148,7 +148,6 @@ export async function createOrgDomain(
|
||||
}
|
||||
}
|
||||
|
||||
let numOrgDomains: OrgDomains[] | undefined;
|
||||
let aRecords: CreateDomainResponse["aRecords"];
|
||||
let cnameRecords: CreateDomainResponse["cnameRecords"];
|
||||
let txtRecords: CreateDomainResponse["txtRecords"];
|
||||
@@ -347,20 +346,9 @@ export async function createOrgDomain(
|
||||
await trx.insert(dnsRecords).values(recordsToInsert);
|
||||
}
|
||||
|
||||
numOrgDomains = await trx
|
||||
.select()
|
||||
.from(orgDomains)
|
||||
.where(eq(orgDomains.orgId, orgId));
|
||||
await usageService.add(orgId, FeatureId.DOMAINS, 1, trx);
|
||||
});
|
||||
|
||||
if (numOrgDomains) {
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.DOMAINS,
|
||||
numOrgDomains.length
|
||||
);
|
||||
}
|
||||
|
||||
if (!returned) {
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -36,8 +36,6 @@ export async function deleteAccountDomain(
|
||||
}
|
||||
const { domainId, orgId } = parsed.data;
|
||||
|
||||
let numOrgDomains: OrgDomains[] | undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const [existing] = await trx
|
||||
.select()
|
||||
@@ -79,20 +77,9 @@ export async function deleteAccountDomain(
|
||||
|
||||
await trx.delete(domains).where(eq(domains.domainId, domainId));
|
||||
|
||||
numOrgDomains = await trx
|
||||
.select()
|
||||
.from(orgDomains)
|
||||
.where(eq(orgDomains.orgId, orgId));
|
||||
await usageService.add(orgId, FeatureId.DOMAINS, -1, trx);
|
||||
});
|
||||
|
||||
if (numOrgDomains) {
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.DOMAINS,
|
||||
numOrgDomains.length
|
||||
);
|
||||
}
|
||||
|
||||
return response<DeleteAccountDomainResponse>(res, {
|
||||
data: { success: true },
|
||||
success: true,
|
||||
|
||||
@@ -52,6 +52,7 @@ import createHttpError from "http-errors";
|
||||
import { build } from "@server/build";
|
||||
import { createStore } from "#dynamic/lib/rateLimitStore";
|
||||
import { logActionAudit } from "#dynamic/middlewares";
|
||||
import { checkRoundTripMessage } from "./ws";
|
||||
|
||||
// Root routes
|
||||
export const unauthenticated = Router();
|
||||
@@ -66,9 +67,8 @@ authenticated.use(verifySessionUserMiddleware);
|
||||
|
||||
authenticated.get("/pick-org-defaults", org.pickOrgDefaults);
|
||||
authenticated.get("/org/checkId", org.checkId);
|
||||
if (build === "oss" || build === "enterprise") {
|
||||
authenticated.put("/org", getUserOrgs, org.createOrg);
|
||||
}
|
||||
|
||||
authenticated.put("/org", getUserOrgs, org.createOrg);
|
||||
|
||||
authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs);
|
||||
authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs);
|
||||
@@ -88,16 +88,14 @@ authenticated.post(
|
||||
org.updateOrg
|
||||
);
|
||||
|
||||
if (build !== "saas") {
|
||||
authenticated.delete(
|
||||
"/org/:orgId",
|
||||
verifyOrgAccess,
|
||||
verifyUserIsOrgOwner,
|
||||
verifyUserHasAction(ActionsEnum.deleteOrg),
|
||||
logActionAudit(ActionsEnum.deleteOrg),
|
||||
org.deleteOrg
|
||||
);
|
||||
}
|
||||
authenticated.delete(
|
||||
"/org/:orgId",
|
||||
verifyOrgAccess,
|
||||
verifyUserIsOrgOwner,
|
||||
verifyUserHasAction(ActionsEnum.deleteOrg),
|
||||
logActionAudit(ActionsEnum.deleteOrg),
|
||||
org.deleteOrg
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/site",
|
||||
@@ -1175,6 +1173,8 @@ authenticated.get(
|
||||
blueprints.getBlueprint
|
||||
);
|
||||
|
||||
authenticated.get("/ws/round-trip-message/:messageId", checkRoundTripMessage);
|
||||
|
||||
// Auth routes
|
||||
export const authRouter = Router();
|
||||
unauthenticated.use("/auth", authRouter);
|
||||
@@ -1223,6 +1223,7 @@ authRouter.post(
|
||||
auth.login
|
||||
);
|
||||
authRouter.post("/logout", auth.logout);
|
||||
authRouter.post("/delete-my-account", auth.deleteMyAccount);
|
||||
authRouter.post(
|
||||
"/lookup-user",
|
||||
rateLimit({
|
||||
|
||||
@@ -197,7 +197,6 @@ export async function updateSiteBandwidth(
|
||||
usageService
|
||||
.checkLimitSet(
|
||||
orgId,
|
||||
|
||||
FeatureId.EGRESS_DATA_MB,
|
||||
bandwidthUsage
|
||||
)
|
||||
|
||||
@@ -70,6 +70,15 @@ export async function createIdpOrgPolicy(
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
const { roleMapping, orgMapping } = parsedBody.data;
|
||||
|
||||
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(idp)
|
||||
|
||||
@@ -80,6 +80,17 @@ export async function createOidcIdp(
|
||||
tags
|
||||
} = parsedBody.data;
|
||||
|
||||
if (
|
||||
process.env.IDENTITY_PROVIDER_MODE === "org"
|
||||
) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const key = config.getRawConfig().server.secret!;
|
||||
|
||||
const encryptedSecret = encrypt(clientSecret, key);
|
||||
|
||||
@@ -69,6 +69,15 @@ export async function updateIdpOrgPolicy(
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
const { roleMapping, orgMapping } = parsedBody.data;
|
||||
|
||||
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if IDP and policy exist
|
||||
const [existing] = await db
|
||||
.select()
|
||||
|
||||
@@ -99,6 +99,15 @@ export async function updateOidcIdp(
|
||||
tags
|
||||
} = parsedBody.data;
|
||||
|
||||
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if IDP exists and is of type OIDC
|
||||
const [existingIdp] = await db
|
||||
.select()
|
||||
|
||||
@@ -36,6 +36,10 @@ import { build } from "@server/build";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { isSubscribed } from "#dynamic/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import {
|
||||
assignUserToOrg,
|
||||
removeUserFromOrg
|
||||
} from "@server/lib/userOrg";
|
||||
|
||||
const ensureTrailingSlash = (url: string): string => {
|
||||
return url;
|
||||
@@ -436,6 +440,7 @@ export async function validateOidcCallback(
|
||||
}
|
||||
}
|
||||
|
||||
// These are the orgs that the user should be provisioned into based on the IdP mappings and the token claims
|
||||
logger.debug("User org info", { userOrgInfo });
|
||||
|
||||
let existingUserId = existingUser?.userId;
|
||||
@@ -454,15 +459,32 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (!existingUserOrgs.length) {
|
||||
// delete all auto -provisioned user orgs
|
||||
await db
|
||||
.delete(userOrgs)
|
||||
// delete all auto-provisioned user orgs
|
||||
const autoProvisionedUserOrgs = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, existingUser.userId),
|
||||
eq(userOrgs.autoProvisioned, true)
|
||||
)
|
||||
);
|
||||
const orgIdsToRemove = autoProvisionedUserOrgs.map(
|
||||
(uo) => uo.orgId
|
||||
);
|
||||
if (orgIdsToRemove.length > 0) {
|
||||
const orgsToRemove = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(inArray(orgs.orgId, orgIdsToRemove));
|
||||
for (const org of orgsToRemove) {
|
||||
await removeUserFromOrg(
|
||||
org,
|
||||
existingUser.userId,
|
||||
db
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await calculateUserClientsForOrgs(existingUser.userId);
|
||||
|
||||
@@ -484,7 +506,7 @@ export async function validateOidcCallback(
|
||||
}
|
||||
}
|
||||
|
||||
const orgUserCounts: { orgId: string; userCount: number }[] = [];
|
||||
const orgUserCounts: { orgId: string; userCount: number }[] = [];
|
||||
|
||||
// sync the user with the orgs and roles
|
||||
await db.transaction(async (trx) => {
|
||||
@@ -538,15 +560,14 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (orgsToDelete.length > 0) {
|
||||
await trx.delete(userOrgs).where(
|
||||
and(
|
||||
eq(userOrgs.userId, userId!),
|
||||
inArray(
|
||||
userOrgs.orgId,
|
||||
orgsToDelete.map((org) => org.orgId)
|
||||
)
|
||||
)
|
||||
);
|
||||
const orgIdsToRemove = orgsToDelete.map((org) => org.orgId);
|
||||
const fullOrgsToRemove = await trx
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(inArray(orgs.orgId, orgIdsToRemove));
|
||||
for (const org of fullOrgsToRemove) {
|
||||
await removeUserFromOrg(org, userId!, trx);
|
||||
}
|
||||
}
|
||||
|
||||
// Update roles for existing auto-provisioned orgs where the role has changed
|
||||
@@ -587,15 +608,24 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (orgsToAdd.length > 0) {
|
||||
await trx.insert(userOrgs).values(
|
||||
orgsToAdd.map((org) => ({
|
||||
userId: userId!,
|
||||
orgId: org.orgId,
|
||||
roleId: org.roleId,
|
||||
autoProvisioned: true,
|
||||
dateCreated: new Date().toISOString()
|
||||
}))
|
||||
);
|
||||
for (const org of orgsToAdd) {
|
||||
const [fullOrg] = await trx
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, org.orgId));
|
||||
if (fullOrg) {
|
||||
await assignUserToOrg(
|
||||
fullOrg,
|
||||
{
|
||||
orgId: org.orgId,
|
||||
userId: userId!,
|
||||
roleId: org.roleId,
|
||||
autoProvisioned: true,
|
||||
},
|
||||
trx
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Loop through all the orgs and get the total number of users from the userOrgs table
|
||||
|
||||
@@ -705,6 +705,13 @@ authenticated.get(
|
||||
user.getOrgUser
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/user-by-username",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getOrgUser),
|
||||
user.getOrgUserByUsername
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/user/:userId/2fa",
|
||||
verifyApiKeyIsRoot,
|
||||
|
||||
@@ -2,7 +2,7 @@ import { MessageHandler } from "@server/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { Newt } from "@server/db";
|
||||
import { applyNewtDockerBlueprint } from "@server/lib/blueprints/applyNewtDockerBlueprint";
|
||||
import cache from "@server/lib/cache";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
|
||||
export const handleDockerStatusMessage: MessageHandler = async (context) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
@@ -24,8 +24,8 @@ export const handleDockerStatusMessage: MessageHandler = async (context) => {
|
||||
|
||||
if (available) {
|
||||
logger.info(`Newt ${newt.newtId} has Docker socket access`);
|
||||
cache.set(`${newt.newtId}:socketPath`, socketPath, 0);
|
||||
cache.set(`${newt.newtId}:isAvailable`, available, 0);
|
||||
await cache.set(`${newt.newtId}:socketPath`, socketPath, 0);
|
||||
await cache.set(`${newt.newtId}:isAvailable`, available, 0);
|
||||
} else {
|
||||
logger.warn(`Newt ${newt.newtId} does not have Docker socket access`);
|
||||
}
|
||||
@@ -54,7 +54,7 @@ export const handleDockerContainersMessage: MessageHandler = async (
|
||||
);
|
||||
|
||||
if (containers && containers.length > 0) {
|
||||
cache.set(`${newt.newtId}:dockerContainers`, containers, 0);
|
||||
await cache.set(`${newt.newtId}:dockerContainers`, containers, 0);
|
||||
} else {
|
||||
logger.warn(`Newt ${newt.newtId} does not have Docker containers`);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import {
|
||||
generateSessionToken,
|
||||
validateSessionToken
|
||||
} from "@server/auth/sessions/app";
|
||||
import {
|
||||
clients,
|
||||
db,
|
||||
@@ -26,8 +29,9 @@ import { APP_VERSION } from "@server/lib/consts";
|
||||
|
||||
export const olmGetTokenBodySchema = z.object({
|
||||
olmId: z.string(),
|
||||
secret: z.string(),
|
||||
token: z.string().optional(),
|
||||
secret: z.string().optional(),
|
||||
userToken: z.string().optional(),
|
||||
token: z.string().optional(), // this is the olm token
|
||||
orgId: z.string().optional()
|
||||
});
|
||||
|
||||
@@ -49,7 +53,7 @@ export async function getOlmToken(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, secret, token, orgId } = parsedBody.data;
|
||||
const { olmId, secret, token, orgId, userToken } = parsedBody.data;
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
@@ -84,19 +88,45 @@ export async function getOlmToken(
|
||||
);
|
||||
}
|
||||
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingOlm.secretHash
|
||||
);
|
||||
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
|
||||
if (userToken) {
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Invalid user token")
|
||||
);
|
||||
}
|
||||
if (user.userId !== existingOlm.userId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"User token does not match olm"
|
||||
)
|
||||
);
|
||||
}
|
||||
} else if (secret) {
|
||||
// this is for backward compatibility, we want to move towards userToken but some old clients may still be using secret so we will support both for now
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingOlm.secretHash
|
||||
);
|
||||
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Either secret or userToken is required"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { and, count, eq } from "drizzle-orm";
|
||||
import {
|
||||
domains,
|
||||
Org,
|
||||
@@ -24,13 +24,24 @@ import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { isValidCIDR } from "@server/lib/validators";
|
||||
import { createCustomer } from "#dynamic/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { FeatureId, limitsService, freeLimitSet } from "@server/lib/billing";
|
||||
import { build } from "@server/build";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { doCidrsOverlap } from "@server/lib/ip";
|
||||
import { generateCA } from "@server/lib/sshCA";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
|
||||
const validOrgIdRegex = /^[a-z0-9_]+(-[a-z0-9_]+)*$/;
|
||||
|
||||
const createOrgSchema = z.strictObject({
|
||||
orgId: z.string(),
|
||||
orgId: z
|
||||
.string()
|
||||
.min(1, "Organization ID is required")
|
||||
.max(32, "Organization ID must be at most 32 characters")
|
||||
.refine((val) => validOrgIdRegex.test(val), {
|
||||
message:
|
||||
"Organization ID must contain only lowercase letters, numbers, underscores, and single hyphens (no leading, trailing, or consecutive hyphens)"
|
||||
}),
|
||||
name: z.string().min(1).max(255),
|
||||
subnet: z
|
||||
// .union([z.cidrv4(), z.cidrv6()])
|
||||
@@ -108,6 +119,7 @@ export async function createOrg(
|
||||
// )
|
||||
// );
|
||||
// }
|
||||
//
|
||||
|
||||
// make sure the orgId is unique
|
||||
const orgExists = await db
|
||||
@@ -134,8 +146,74 @@ export async function createOrg(
|
||||
);
|
||||
}
|
||||
|
||||
let isFirstOrg: boolean | null = null;
|
||||
let billingOrgIdForNewOrg: string | null = null;
|
||||
if (build === "saas" && req.user) {
|
||||
const ownedOrgs = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, req.user.userId),
|
||||
eq(userOrgs.isOwner, true)
|
||||
)
|
||||
);
|
||||
if (ownedOrgs.length === 0) {
|
||||
isFirstOrg = true;
|
||||
} else {
|
||||
isFirstOrg = false;
|
||||
const [billingOrg] = await db
|
||||
.select({ orgId: orgs.orgId })
|
||||
.from(orgs)
|
||||
.innerJoin(userOrgs, eq(orgs.orgId, userOrgs.orgId))
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, req.user.userId),
|
||||
eq(userOrgs.isOwner, true),
|
||||
eq(orgs.isBillingOrg, true)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
if (billingOrg) {
|
||||
billingOrgIdForNewOrg = billingOrg.orgId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (build == "saas" && billingOrgIdForNewOrg) {
|
||||
const usage = await usageService.getUsage(
|
||||
billingOrgIdForNewOrg,
|
||||
FeatureId.ORGINIZATIONS
|
||||
);
|
||||
if (!usage) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"No usage data found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
const rejectOrgs = await usageService.checkLimitSet(
|
||||
billingOrgIdForNewOrg,
|
||||
FeatureId.ORGINIZATIONS,
|
||||
{
|
||||
...usage,
|
||||
instantaneousValue: (usage.instantaneousValue || 0) + 1
|
||||
} // We need to add one to know if we are violating the limit
|
||||
);
|
||||
if (rejectOrgs) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Organization limit exceeded. Please upgrade your plan."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let error = "";
|
||||
let org: Org | null = null;
|
||||
let numOrgs: number | null = null;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const allDomains = await trx
|
||||
@@ -143,6 +221,29 @@ export async function createOrg(
|
||||
.from(domains)
|
||||
.where(eq(domains.configManaged, true));
|
||||
|
||||
const saasBillingFields =
|
||||
build === "saas" && req.user && isFirstOrg !== null
|
||||
? isFirstOrg
|
||||
? { isBillingOrg: true as const, billingOrgId: orgId } // if this is the first org, it becomes the billing org for itself
|
||||
: {
|
||||
isBillingOrg: false as const,
|
||||
billingOrgId: billingOrgIdForNewOrg
|
||||
}
|
||||
: {};
|
||||
|
||||
const encryptionKey = config.getRawConfig().server.secret;
|
||||
let sshCaFields: {
|
||||
sshCaPrivateKey?: string;
|
||||
sshCaPublicKey?: string;
|
||||
} = {};
|
||||
if (encryptionKey) {
|
||||
const ca = generateCA(`pangolin-ssh-ca-${orgId}`);
|
||||
sshCaFields = {
|
||||
sshCaPrivateKey: encrypt(ca.privateKeyPem, encryptionKey),
|
||||
sshCaPublicKey: ca.publicKeyOpenSSH
|
||||
};
|
||||
}
|
||||
|
||||
const newOrg = await trx
|
||||
.insert(orgs)
|
||||
.values({
|
||||
@@ -150,7 +251,9 @@ export async function createOrg(
|
||||
name,
|
||||
subnet,
|
||||
utilitySubnet,
|
||||
createdAt: new Date().toISOString()
|
||||
createdAt: new Date().toISOString(),
|
||||
...sshCaFields,
|
||||
...saasBillingFields
|
||||
})
|
||||
.returning();
|
||||
|
||||
@@ -169,7 +272,8 @@ export async function createOrg(
|
||||
orgId: newOrg[0].orgId,
|
||||
isAdmin: true,
|
||||
name: "Admin",
|
||||
description: "Admin role with the most permissions"
|
||||
description: "Admin role with the most permissions",
|
||||
sshSudoMode: "full"
|
||||
})
|
||||
.returning({ roleId: roles.roleId });
|
||||
|
||||
@@ -252,6 +356,17 @@ export async function createOrg(
|
||||
);
|
||||
|
||||
await calculateUserClientsForOrgs(ownerUserId, trx);
|
||||
|
||||
if (billingOrgIdForNewOrg) {
|
||||
const [numOrgsResult] = await trx
|
||||
.select({ count: count() })
|
||||
.from(orgs)
|
||||
.where(eq(orgs.billingOrgId, billingOrgIdForNewOrg)); // all the billable orgs including the primary org that is the billing org itself
|
||||
|
||||
numOrgs = numOrgsResult.count;
|
||||
} else {
|
||||
numOrgs = 1; // we only have one org if there is no billing org found out
|
||||
}
|
||||
});
|
||||
|
||||
if (!org) {
|
||||
@@ -267,8 +382,8 @@ export async function createOrg(
|
||||
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error));
|
||||
}
|
||||
|
||||
if (build == "saas") {
|
||||
// make sure we have the stripe customer
|
||||
if (build === "saas" && isFirstOrg === true) {
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
const customerId = await createCustomer(orgId, req.user?.email);
|
||||
if (customerId) {
|
||||
await usageService.updateCount(
|
||||
@@ -280,6 +395,14 @@ export async function createOrg(
|
||||
}
|
||||
}
|
||||
|
||||
if (numOrgs) {
|
||||
usageService.updateCount(
|
||||
billingOrgIdForNewOrg || orgId,
|
||||
FeatureId.ORGINIZATIONS,
|
||||
numOrgs
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: org,
|
||||
success: true,
|
||||
|
||||
@@ -1,28 +1,14 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import {
|
||||
clients,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
clientSitesAssociationsCache,
|
||||
db,
|
||||
domains,
|
||||
olms,
|
||||
orgDomains,
|
||||
resources
|
||||
} from "@server/db";
|
||||
import { newts, newtSessions, orgs, sites, userActions } from "@server/db";
|
||||
import { eq, and, inArray, sql } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import { ActionsEnum, checkUserActionPermission } from "@server/auth/actions";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { deletePeer } from "../gerbil/peers";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { OlmErrorCodes } from "../olm/error";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg";
|
||||
import { db, userOrgs, orgs } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
|
||||
const deleteOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
@@ -56,16 +42,23 @@ export async function deleteOrg(
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const [org] = await db
|
||||
const [data] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
.from(userOrgs)
|
||||
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
eq(userOrgs.userId, req.user!.userId)
|
||||
)
|
||||
);
|
||||
|
||||
if (!org) {
|
||||
const org = data?.orgs;
|
||||
const userOrg = data?.userOrgs;
|
||||
|
||||
if (!org || !userOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
@@ -73,153 +66,27 @@ export async function deleteOrg(
|
||||
)
|
||||
);
|
||||
}
|
||||
// we need to handle deleting each site
|
||||
const orgSites = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
const orgClients = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.orgId, orgId));
|
||||
|
||||
const deletedNewtIds: string[] = [];
|
||||
const olmsToTerminate: string[] = [];
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
for (const site of orgSites) {
|
||||
if (site.pubKey) {
|
||||
if (site.type == "wireguard") {
|
||||
await deletePeer(site.exitNodeId!, site.pubKey);
|
||||
} else if (site.type == "newt") {
|
||||
// get the newt on the site by querying the newt table for siteId
|
||||
const [deletedNewt] = await trx
|
||||
.delete(newts)
|
||||
.where(eq(newts.siteId, site.siteId))
|
||||
.returning();
|
||||
if (deletedNewt) {
|
||||
deletedNewtIds.push(deletedNewt.newtId);
|
||||
|
||||
// delete all of the sessions for the newt
|
||||
await trx
|
||||
.delete(newtSessions)
|
||||
.where(
|
||||
eq(newtSessions.newtId, deletedNewt.newtId)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Deleting site ${site.siteId}`);
|
||||
await trx.delete(sites).where(eq(sites.siteId, site.siteId));
|
||||
}
|
||||
for (const client of orgClients) {
|
||||
const [olm] = await trx
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, client.clientId))
|
||||
.limit(1);
|
||||
|
||||
if (olm) {
|
||||
olmsToTerminate.push(olm.olmId);
|
||||
}
|
||||
|
||||
logger.info(`Deleting client ${client.clientId}`);
|
||||
await trx
|
||||
.delete(clients)
|
||||
.where(eq(clients.clientId, client.clientId));
|
||||
|
||||
// also delete the associations
|
||||
await trx
|
||||
.delete(clientSiteResourcesAssociationsCache)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
);
|
||||
|
||||
await trx
|
||||
.delete(clientSitesAssociationsCache)
|
||||
.where(
|
||||
eq(
|
||||
clientSitesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const allOrgDomains = await trx
|
||||
.select()
|
||||
.from(orgDomains)
|
||||
.innerJoin(domains, eq(domains.domainId, orgDomains.domainId))
|
||||
.where(
|
||||
and(
|
||||
eq(orgDomains.orgId, orgId),
|
||||
eq(domains.configManaged, false)
|
||||
)
|
||||
);
|
||||
|
||||
// For each domain, check if it belongs to multiple organizations
|
||||
const domainIdsToDelete: string[] = [];
|
||||
for (const orgDomain of allOrgDomains) {
|
||||
const domainId = orgDomain.domains.domainId;
|
||||
|
||||
// Count how many organizations this domain belongs to
|
||||
const orgCount = await trx
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(orgDomains)
|
||||
.where(eq(orgDomains.domainId, domainId));
|
||||
|
||||
// Only delete the domain if it belongs to exactly 1 organization (the one being deleted)
|
||||
if (orgCount[0].count === 1) {
|
||||
domainIdsToDelete.push(domainId);
|
||||
}
|
||||
}
|
||||
|
||||
// Delete domains that belong exclusively to this organization
|
||||
if (domainIdsToDelete.length > 0) {
|
||||
await trx
|
||||
.delete(domains)
|
||||
.where(inArray(domains.domainId, domainIdsToDelete));
|
||||
}
|
||||
|
||||
// Delete resources
|
||||
await trx.delete(resources).where(eq(resources.orgId, orgId));
|
||||
|
||||
await trx.delete(orgs).where(eq(orgs.orgId, orgId));
|
||||
});
|
||||
|
||||
// Send termination messages outside of transaction to prevent blocking
|
||||
for (const newtId of deletedNewtIds) {
|
||||
const payload = {
|
||||
type: `newt/wg/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
sendToClient(newtId, payload).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to newt:",
|
||||
error
|
||||
);
|
||||
});
|
||||
if (!userOrg.isOwner) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Only organization owners can delete the organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
for (const olmId of olmsToTerminate) {
|
||||
sendTerminateClient(
|
||||
0, // clientId not needed since we're passing olmId
|
||||
OlmErrorCodes.TERMINATED_REKEYED,
|
||||
olmId
|
||||
).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to olm:",
|
||||
error
|
||||
);
|
||||
});
|
||||
if (org.isBillingOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Cannot delete a primary organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const result = await deleteOrgById(orgId);
|
||||
sendTerminationMessages(result);
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
@@ -228,6 +95,9 @@ export async function deleteOrg(
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
if (createHttpError.isHttpError(error)) {
|
||||
return next(error);
|
||||
}
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -40,7 +40,11 @@ const listOrgsSchema = z.object({
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
type ResponseOrg = Org & { isOwner?: boolean; isAdmin?: boolean };
|
||||
type ResponseOrg = Org & {
|
||||
isOwner?: boolean;
|
||||
isAdmin?: boolean;
|
||||
isPrimaryOrg?: boolean;
|
||||
};
|
||||
|
||||
export type ListUserOrgsResponse = {
|
||||
orgs: ResponseOrg[];
|
||||
@@ -132,6 +136,9 @@ export async function listUserOrgs(
|
||||
if (val.roles && val.roles.isAdmin) {
|
||||
res.isAdmin = val.roles.isAdmin;
|
||||
}
|
||||
if (val.userOrgs?.isOwner && val.orgs?.isBillingOrg) {
|
||||
res.isPrimaryOrg = val.orgs.isBillingOrg;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { build } from "@server/build";
|
||||
import { cache } from "@server/lib/cache";
|
||||
import { cache } from "#dynamic/lib/cache";
|
||||
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
|
||||
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import { getOrgTierData } from "#dynamic/lib/billing";
|
||||
@@ -194,9 +194,9 @@ export async function updateOrg(
|
||||
}
|
||||
|
||||
// invalidate the cache for all of the orgs retention days
|
||||
cache.del(`org_${orgId}_retentionDays`);
|
||||
cache.del(`org_${orgId}_actionDays`);
|
||||
cache.del(`org_${orgId}_accessDays`);
|
||||
await cache.del(`org_${orgId}_retentionDays`);
|
||||
await cache.del(`org_${orgId}_actionDays`);
|
||||
await cache.del(`org_${orgId}_accessDays`);
|
||||
|
||||
return response(res, {
|
||||
data: updatedOrg[0],
|
||||
|
||||
@@ -8,7 +8,10 @@ import {
|
||||
userOrgs,
|
||||
resourcePassword,
|
||||
resourcePincode,
|
||||
resourceWhitelist
|
||||
resourceWhitelist,
|
||||
siteResources,
|
||||
userSiteResources,
|
||||
roleSiteResources
|
||||
} from "@server/db";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -57,9 +60,21 @@ export async function getUserResources(
|
||||
.from(roleResources)
|
||||
.where(eq(roleResources.roleId, userRoleId));
|
||||
|
||||
const [directResources, roleResourceResults] = await Promise.all([
|
||||
const directSiteResourcesQuery = db
|
||||
.select({ siteResourceId: userSiteResources.siteResourceId })
|
||||
.from(userSiteResources)
|
||||
.where(eq(userSiteResources.userId, userId));
|
||||
|
||||
const roleSiteResourcesQuery = db
|
||||
.select({ siteResourceId: roleSiteResources.siteResourceId })
|
||||
.from(roleSiteResources)
|
||||
.where(eq(roleSiteResources.roleId, userRoleId));
|
||||
|
||||
const [directResources, roleResourceResults, directSiteResourceResults, roleSiteResourceResults] = await Promise.all([
|
||||
directResourcesQuery,
|
||||
roleResourcesQuery
|
||||
roleResourcesQuery,
|
||||
directSiteResourcesQuery,
|
||||
roleSiteResourcesQuery
|
||||
]);
|
||||
|
||||
// Combine all accessible resource IDs
|
||||
@@ -68,18 +83,25 @@ export async function getUserResources(
|
||||
...roleResourceResults.map((r) => r.resourceId)
|
||||
];
|
||||
|
||||
if (accessibleResourceIds.length === 0) {
|
||||
return response(res, {
|
||||
data: { resources: [] },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "No resources found",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
// Combine all accessible site resource IDs
|
||||
const accessibleSiteResourceIds = [
|
||||
...directSiteResourceResults.map((r) => r.siteResourceId),
|
||||
...roleSiteResourceResults.map((r) => r.siteResourceId)
|
||||
];
|
||||
|
||||
// Get resource details for accessible resources
|
||||
const resourcesData = await db
|
||||
let resourcesData: Array<{
|
||||
resourceId: number;
|
||||
name: string;
|
||||
fullDomain: string | null;
|
||||
ssl: boolean;
|
||||
enabled: boolean;
|
||||
sso: boolean;
|
||||
protocol: string;
|
||||
emailWhitelistEnabled: boolean;
|
||||
}> = [];
|
||||
if (accessibleResourceIds.length > 0) {
|
||||
resourcesData = await db
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name,
|
||||
@@ -98,6 +120,40 @@ export async function getUserResources(
|
||||
eq(resources.enabled, true)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Get site resource details for accessible site resources
|
||||
let siteResourcesData: Array<{
|
||||
siteResourceId: number;
|
||||
name: string;
|
||||
destination: string;
|
||||
mode: string;
|
||||
protocol: string | null;
|
||||
enabled: boolean;
|
||||
alias: string | null;
|
||||
aliasAddress: string | null;
|
||||
}> = [];
|
||||
if (accessibleSiteResourceIds.length > 0) {
|
||||
siteResourcesData = await db
|
||||
.select({
|
||||
siteResourceId: siteResources.siteResourceId,
|
||||
name: siteResources.name,
|
||||
destination: siteResources.destination,
|
||||
mode: siteResources.mode,
|
||||
protocol: siteResources.protocol,
|
||||
enabled: siteResources.enabled,
|
||||
alias: siteResources.alias,
|
||||
aliasAddress: siteResources.aliasAddress
|
||||
})
|
||||
.from(siteResources)
|
||||
.where(
|
||||
and(
|
||||
inArray(siteResources.siteResourceId, accessibleSiteResourceIds),
|
||||
eq(siteResources.orgId, orgId),
|
||||
eq(siteResources.enabled, true)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check for password, pincode, and whitelist protection for each resource
|
||||
const resourcesWithAuth = await Promise.all(
|
||||
@@ -161,8 +217,26 @@ export async function getUserResources(
|
||||
})
|
||||
);
|
||||
|
||||
// Format site resources
|
||||
const siteResourcesFormatted = siteResourcesData.map((siteResource) => {
|
||||
return {
|
||||
siteResourceId: siteResource.siteResourceId,
|
||||
name: siteResource.name,
|
||||
destination: siteResource.destination,
|
||||
mode: siteResource.mode,
|
||||
protocol: siteResource.protocol,
|
||||
enabled: siteResource.enabled,
|
||||
alias: siteResource.alias,
|
||||
aliasAddress: siteResource.aliasAddress,
|
||||
type: 'site' as const
|
||||
};
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: { resources: resourcesWithAuth },
|
||||
data: {
|
||||
resources: resourcesWithAuth,
|
||||
siteResources: siteResourcesFormatted
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "User resources retrieved successfully",
|
||||
@@ -190,5 +264,16 @@ export type GetUserResourcesResponse = {
|
||||
protected: boolean;
|
||||
protocol: string;
|
||||
}>;
|
||||
siteResources: Array<{
|
||||
siteResourceId: number;
|
||||
name: string;
|
||||
destination: string;
|
||||
mode: string;
|
||||
protocol: string | null;
|
||||
enabled: boolean;
|
||||
alias: string | null;
|
||||
aliasAddress: string | null;
|
||||
type: 'site';
|
||||
}>;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
and,
|
||||
asc,
|
||||
count,
|
||||
desc,
|
||||
eq,
|
||||
inArray,
|
||||
isNull,
|
||||
@@ -44,28 +45,74 @@ const listResourcesSchema = z.object({
|
||||
.positive()
|
||||
.optional()
|
||||
.catch(20)
|
||||
.default(20),
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.catch(1)
|
||||
.default(1),
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
sort_by: z
|
||||
.enum(["name"])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["name"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
}),
|
||||
enabled: z
|
||||
.enum(["true", "false"])
|
||||
.transform((v) => v === "true")
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "boolean",
|
||||
description: "Filter resources based on enabled status"
|
||||
}),
|
||||
authState: z
|
||||
.enum(["protected", "not_protected", "none"])
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["protected", "not_protected", "none"],
|
||||
description:
|
||||
"Filter resources based on authentication state. `protected` means the resource has at least one auth mechanism (password, pincode, header auth, SSO, or email whitelist). `not_protected` means the resource has no auth mechanisms. `none` means the resource is not protected by HTTP (i.e. it has no auth mechanisms and http is false)."
|
||||
}),
|
||||
healthStatus: z
|
||||
.enum(["no_targets", "healthy", "degraded", "offline", "unknown"])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["no_targets", "healthy", "degraded", "offline", "unknown"],
|
||||
description:
|
||||
"Filter resources based on health status of their targets. `healthy` means all targets are healthy. `degraded` means at least one target is unhealthy, but not all are unhealthy. `offline` means all targets are unhealthy. `unknown` means all targets have unknown health status. `no_targets` means the resource has no targets."
|
||||
})
|
||||
});
|
||||
|
||||
// grouped by resource with targets[])
|
||||
@@ -203,8 +250,16 @@ export async function listResources(
|
||||
)
|
||||
);
|
||||
}
|
||||
const { page, pageSize, authState, enabled, query, healthStatus } =
|
||||
parsedQuery.data;
|
||||
const {
|
||||
page,
|
||||
pageSize,
|
||||
authState,
|
||||
enabled,
|
||||
query,
|
||||
healthStatus,
|
||||
sort_by,
|
||||
order
|
||||
} = parsedQuery.data;
|
||||
|
||||
const parsedParams = listResourcesParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -369,7 +424,13 @@ export async function listResources(
|
||||
baseQuery
|
||||
.limit(pageSize)
|
||||
.offset(pageSize * (page - 1))
|
||||
.orderBy(asc(resources.resourceId)),
|
||||
.orderBy(
|
||||
sort_by
|
||||
? order === "asc"
|
||||
? asc(resources[sort_by])
|
||||
: desc(resources[sort_by])
|
||||
: asc(resources.resourceId)
|
||||
),
|
||||
countQuery
|
||||
]);
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
Resource,
|
||||
resources
|
||||
} from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { eq, and, ne } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -33,7 +33,15 @@ const updateResourceParamsSchema = z.strictObject({
|
||||
const updateHttpResourceBodySchema = z
|
||||
.strictObject({
|
||||
name: z.string().min(1).max(255).optional(),
|
||||
niceId: z.string().min(1).max(255).optional(),
|
||||
niceId: z
|
||||
.string()
|
||||
.min(1)
|
||||
.max(255)
|
||||
.regex(
|
||||
/^[a-zA-Z0-9-]+$/,
|
||||
"niceId can only contain letters, numbers, and dashes"
|
||||
)
|
||||
.optional(),
|
||||
subdomain: subdomainSchema.nullable().optional(),
|
||||
ssl: z.boolean().optional(),
|
||||
sso: z.boolean().optional(),
|
||||
@@ -248,14 +256,13 @@ async function updateHttpResource(
|
||||
.where(
|
||||
and(
|
||||
eq(resources.niceId, updateData.niceId),
|
||||
eq(resources.orgId, resource.orgId)
|
||||
eq(resources.orgId, resource.orgId),
|
||||
ne(resources.resourceId, resource.resourceId) // exclude the current resource from the search
|
||||
)
|
||||
);
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (
|
||||
existingResource &&
|
||||
existingResource.resourceId !== resource.resourceId
|
||||
) {
|
||||
if (existingResource) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
@@ -343,7 +350,10 @@ async function updateHttpResource(
|
||||
headers = null;
|
||||
}
|
||||
|
||||
const isLicensed = await isLicensedOrSubscribed(resource.orgId, tierMatrix.maintencePage);
|
||||
const isLicensed = await isLicensedOrSubscribed(
|
||||
resource.orgId,
|
||||
tierMatrix.maintencePage
|
||||
);
|
||||
if (!isLicensed) {
|
||||
updateData.maintenanceModeEnabled = undefined;
|
||||
updateData.maintenanceModeType = undefined;
|
||||
|
||||
@@ -18,10 +18,17 @@ const createRoleParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const sshSudoModeSchema = z.enum(["none", "full", "commands"]);
|
||||
|
||||
const createRoleSchema = z.strictObject({
|
||||
name: z.string().min(1).max(255),
|
||||
description: z.string().optional(),
|
||||
requireDeviceApproval: z.boolean().optional()
|
||||
requireDeviceApproval: z.boolean().optional(),
|
||||
allowSsh: z.boolean().optional(),
|
||||
sshSudoMode: sshSudoModeSchema.optional(),
|
||||
sshSudoCommands: z.array(z.string()).optional(),
|
||||
sshCreateHomeDir: z.boolean().optional(),
|
||||
sshUnixGroups: z.array(z.string()).optional()
|
||||
});
|
||||
|
||||
export const defaultRoleAllowedActions: ActionsEnum[] = [
|
||||
@@ -101,24 +108,40 @@ export async function createRole(
|
||||
);
|
||||
}
|
||||
|
||||
const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.deviceApprovals);
|
||||
if (!isLicensed) {
|
||||
const isLicensedDeviceApprovals = await isLicensedOrSubscribed(orgId, tierMatrix.deviceApprovals);
|
||||
if (!isLicensedDeviceApprovals) {
|
||||
roleData.requireDeviceApproval = undefined;
|
||||
}
|
||||
|
||||
const isLicensedSshPam = await isLicensedOrSubscribed(orgId, tierMatrix.sshPam);
|
||||
const roleInsertValues: Record<string, unknown> = {
|
||||
name: roleData.name,
|
||||
orgId
|
||||
};
|
||||
if (roleData.description !== undefined) roleInsertValues.description = roleData.description;
|
||||
if (roleData.requireDeviceApproval !== undefined) roleInsertValues.requireDeviceApproval = roleData.requireDeviceApproval;
|
||||
if (isLicensedSshPam) {
|
||||
if (roleData.sshSudoMode !== undefined) roleInsertValues.sshSudoMode = roleData.sshSudoMode;
|
||||
if (roleData.sshSudoCommands !== undefined) roleInsertValues.sshSudoCommands = JSON.stringify(roleData.sshSudoCommands);
|
||||
if (roleData.sshCreateHomeDir !== undefined) roleInsertValues.sshCreateHomeDir = roleData.sshCreateHomeDir;
|
||||
if (roleData.sshUnixGroups !== undefined) roleInsertValues.sshUnixGroups = JSON.stringify(roleData.sshUnixGroups);
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const newRole = await trx
|
||||
.insert(roles)
|
||||
.values({
|
||||
...roleData,
|
||||
orgId
|
||||
})
|
||||
.values(roleInsertValues as typeof roles.$inferInsert)
|
||||
.returning();
|
||||
|
||||
const actionsToInsert = [...defaultRoleAllowedActions];
|
||||
if (roleData.allowSsh) {
|
||||
actionsToInsert.push(ActionsEnum.signSshKey);
|
||||
}
|
||||
|
||||
await trx
|
||||
.insert(roleActions)
|
||||
.values(
|
||||
defaultRoleAllowedActions.map((action) => ({
|
||||
actionsToInsert.map((action) => ({
|
||||
roleId: newRole[0].roleId,
|
||||
actionId: action,
|
||||
orgId
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { db, orgs, roles } from "@server/db";
|
||||
import { db, orgs, roleActions, roles } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { eq, sql } from "drizzle-orm";
|
||||
import { and, eq, inArray, sql } from "drizzle-orm";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
@@ -37,7 +38,11 @@ async function queryRoles(orgId: string, limit: number, offset: number) {
|
||||
name: roles.name,
|
||||
description: roles.description,
|
||||
orgName: orgs.name,
|
||||
requireDeviceApproval: roles.requireDeviceApproval
|
||||
requireDeviceApproval: roles.requireDeviceApproval,
|
||||
sshSudoMode: roles.sshSudoMode,
|
||||
sshSudoCommands: roles.sshSudoCommands,
|
||||
sshCreateHomeDir: roles.sshCreateHomeDir,
|
||||
sshUnixGroups: roles.sshUnixGroups
|
||||
})
|
||||
.from(roles)
|
||||
.leftJoin(orgs, eq(roles.orgId, orgs.orgId))
|
||||
@@ -106,9 +111,28 @@ export async function listRoles(
|
||||
const totalCountResult = await countQuery;
|
||||
const totalCount = totalCountResult[0].count;
|
||||
|
||||
let rolesWithAllowSsh = rolesList;
|
||||
if (rolesList.length > 0) {
|
||||
const roleIds = rolesList.map((r) => r.roleId);
|
||||
const signSshKeyRows = await db
|
||||
.select({ roleId: roleActions.roleId })
|
||||
.from(roleActions)
|
||||
.where(
|
||||
and(
|
||||
inArray(roleActions.roleId, roleIds),
|
||||
eq(roleActions.actionId, ActionsEnum.signSshKey)
|
||||
)
|
||||
);
|
||||
const roleIdsWithSsh = new Set(signSshKeyRows.map((r) => r.roleId));
|
||||
rolesWithAllowSsh = rolesList.map((r) => ({
|
||||
...r,
|
||||
allowSsh: roleIdsWithSsh.has(r.roleId)
|
||||
}));
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: {
|
||||
roles: rolesList,
|
||||
roles: rolesWithAllowSsh,
|
||||
pagination: {
|
||||
total: totalCount,
|
||||
limit,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, type Role } from "@server/db";
|
||||
import { roles } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { roleActions, roles } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -16,11 +17,18 @@ const updateRoleParamsSchema = z.strictObject({
|
||||
roleId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const sshSudoModeSchema = z.enum(["none", "full", "commands"]);
|
||||
|
||||
const updateRoleBodySchema = z
|
||||
.strictObject({
|
||||
name: z.string().min(1).max(255).optional(),
|
||||
description: z.string().optional(),
|
||||
requireDeviceApproval: z.boolean().optional()
|
||||
requireDeviceApproval: z.boolean().optional(),
|
||||
allowSsh: z.boolean().optional(),
|
||||
sshSudoMode: sshSudoModeSchema.optional(),
|
||||
sshSudoCommands: z.array(z.string()).optional(),
|
||||
sshCreateHomeDir: z.boolean().optional(),
|
||||
sshUnixGroups: z.array(z.string()).optional()
|
||||
})
|
||||
.refine((data) => Object.keys(data).length > 0, {
|
||||
error: "At least one field must be provided for update"
|
||||
@@ -75,7 +83,9 @@ export async function updateRole(
|
||||
}
|
||||
|
||||
const { roleId } = parsedParams.data;
|
||||
const updateData = parsedBody.data;
|
||||
const body = parsedBody.data;
|
||||
const { allowSsh, ...restBody } = body;
|
||||
const updateData: Record<string, unknown> = { ...restBody };
|
||||
|
||||
const role = await db
|
||||
.select()
|
||||
@@ -92,16 +102,14 @@ export async function updateRole(
|
||||
);
|
||||
}
|
||||
|
||||
if (role[0].isAdmin) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
`Cannot update a Admin role`
|
||||
)
|
||||
);
|
||||
const orgId = role[0].orgId;
|
||||
const isAdminRole = role[0].isAdmin;
|
||||
|
||||
if (isAdminRole) {
|
||||
delete updateData.name;
|
||||
delete updateData.description;
|
||||
}
|
||||
|
||||
const orgId = role[0].orgId;
|
||||
if (!orgId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
@@ -111,18 +119,70 @@ export async function updateRole(
|
||||
);
|
||||
}
|
||||
|
||||
const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.deviceApprovals);
|
||||
if (!isLicensed) {
|
||||
const isLicensedDeviceApprovals = await isLicensedOrSubscribed(orgId, tierMatrix.deviceApprovals);
|
||||
if (!isLicensedDeviceApprovals) {
|
||||
updateData.requireDeviceApproval = undefined;
|
||||
}
|
||||
|
||||
const updatedRole = await db
|
||||
.update(roles)
|
||||
.set(updateData)
|
||||
.where(eq(roles.roleId, roleId))
|
||||
.returning();
|
||||
const isLicensedSshPam = await isLicensedOrSubscribed(orgId, tierMatrix.sshPam);
|
||||
if (!isLicensedSshPam) {
|
||||
delete updateData.sshSudoMode;
|
||||
delete updateData.sshSudoCommands;
|
||||
delete updateData.sshCreateHomeDir;
|
||||
delete updateData.sshUnixGroups;
|
||||
} else {
|
||||
if (Array.isArray(updateData.sshSudoCommands)) {
|
||||
updateData.sshSudoCommands = JSON.stringify(updateData.sshSudoCommands);
|
||||
}
|
||||
if (Array.isArray(updateData.sshUnixGroups)) {
|
||||
updateData.sshUnixGroups = JSON.stringify(updateData.sshUnixGroups);
|
||||
}
|
||||
}
|
||||
|
||||
if (updatedRole.length === 0) {
|
||||
const updatedRole = await db.transaction(async (trx) => {
|
||||
const result = await trx
|
||||
.update(roles)
|
||||
.set(updateData as typeof roles.$inferInsert)
|
||||
.where(eq(roles.roleId, roleId))
|
||||
.returning();
|
||||
|
||||
if (result.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (allowSsh === true) {
|
||||
const existing = await trx
|
||||
.select()
|
||||
.from(roleActions)
|
||||
.where(
|
||||
and(
|
||||
eq(roleActions.roleId, roleId),
|
||||
eq(roleActions.actionId, ActionsEnum.signSshKey)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
if (existing.length === 0) {
|
||||
await trx.insert(roleActions).values({
|
||||
roleId,
|
||||
actionId: ActionsEnum.signSshKey,
|
||||
orgId: orgId!
|
||||
});
|
||||
}
|
||||
} else if (allowSsh === false) {
|
||||
await trx
|
||||
.delete(roleActions)
|
||||
.where(
|
||||
and(
|
||||
eq(roleActions.roleId, roleId),
|
||||
eq(roleActions.actionId, ActionsEnum.signSshKey)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return result[0];
|
||||
});
|
||||
|
||||
if (!updatedRole) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
@@ -132,7 +192,7 @@ export async function updateRole(
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: updatedRole[0],
|
||||
data: updatedRole,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Role updated successfully",
|
||||
|
||||
@@ -6,7 +6,7 @@ import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { eq, and, count } from "drizzle-orm";
|
||||
import { getUniqueSiteName } from "../../db/names";
|
||||
import { addPeer } from "../gerbil/peers";
|
||||
import { fromError } from "zod-validation-error";
|
||||
@@ -288,7 +288,6 @@ export async function createSite(
|
||||
const niceId = await getUniqueSiteName(orgId);
|
||||
|
||||
let newSite: Site | undefined;
|
||||
let numSites: Site[] | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
if (type == "newt") {
|
||||
[newSite] = await trx
|
||||
@@ -443,20 +442,9 @@ export async function createSite(
|
||||
});
|
||||
}
|
||||
|
||||
numSites = await trx
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.orgId, orgId));
|
||||
await usageService.add(orgId, FeatureId.SITES, 1, trx);
|
||||
});
|
||||
|
||||
if (numSites) {
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.SITES,
|
||||
numSites.length
|
||||
);
|
||||
}
|
||||
|
||||
if (!newSite) {
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -64,7 +64,6 @@ export async function deleteSite(
|
||||
}
|
||||
|
||||
let deletedNewtId: string | null = null;
|
||||
let numSites: Site[] | undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
if (site.type == "wireguard") {
|
||||
@@ -103,19 +102,9 @@ export async function deleteSite(
|
||||
|
||||
await trx.delete(sites).where(eq(sites.siteId, siteId));
|
||||
|
||||
numSites = await trx
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.orgId, site.orgId));
|
||||
await usageService.add(site.orgId, FeatureId.SITES, -1, trx);
|
||||
});
|
||||
|
||||
if (numSites) {
|
||||
await usageService.updateCount(
|
||||
site.orgId,
|
||||
FeatureId.SITES,
|
||||
numSites.length
|
||||
);
|
||||
}
|
||||
// Send termination message outside of transaction to prevent blocking
|
||||
if (deletedNewtId) {
|
||||
const payload = {
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
sites,
|
||||
userSites
|
||||
} from "@server/db";
|
||||
import cache from "@server/lib/cache";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
@@ -23,7 +23,7 @@ import { fromError } from "zod-validation-error";
|
||||
|
||||
async function getLatestNewtVersion(): Promise<string | null> {
|
||||
try {
|
||||
const cachedVersion = cache.get<string>("latestNewtVersion");
|
||||
const cachedVersion = await cache.get<string>("latestNewtVersion");
|
||||
if (cachedVersion) {
|
||||
return cachedVersion;
|
||||
}
|
||||
@@ -55,7 +55,7 @@ async function getLatestNewtVersion(): Promise<string | null> {
|
||||
tags = tags.filter((version) => !version.name.includes("rc"));
|
||||
const latestVersion = tags[0].name;
|
||||
|
||||
cache.set("latestNewtVersion", latestVersion);
|
||||
await cache.set("latestNewtVersion", latestVersion);
|
||||
|
||||
return latestVersion;
|
||||
} catch (error: any) {
|
||||
@@ -88,25 +88,54 @@ const listSitesSchema = z.object({
|
||||
.positive()
|
||||
.optional()
|
||||
.catch(20)
|
||||
.default(20),
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.catch(1)
|
||||
.default(1),
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
sort_by: z
|
||||
.enum(["megabytesIn", "megabytesOut"])
|
||||
.enum(["name", "megabytesIn", "megabytesOut"])
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
order: z.enum(["asc", "desc"]).optional().default("asc").catch("asc"),
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["name", "megabytesIn", "megabytesOut"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
}),
|
||||
online: z
|
||||
.enum(["true", "false"])
|
||||
.transform((v) => v === "true")
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "boolean",
|
||||
description: "Filter by online status"
|
||||
})
|
||||
});
|
||||
|
||||
function querySitesBase() {
|
||||
|
||||
@@ -11,7 +11,7 @@ import { fromError } from "zod-validation-error";
|
||||
import stoi from "@server/lib/stoi";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { fetchContainers, dockerSocket } from "../newt/dockerSocket";
|
||||
import cache from "@server/lib/cache";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
|
||||
export interface ContainerNetwork {
|
||||
networkId: string;
|
||||
@@ -150,7 +150,7 @@ async function triggerFetch(siteId: number) {
|
||||
|
||||
// clear the cache for this Newt ID so that the site has to keep asking for the containers
|
||||
// this is to ensure that the site always gets the latest data
|
||||
cache.del(`${newt.newtId}:dockerContainers`);
|
||||
await cache.del(`${newt.newtId}:dockerContainers`);
|
||||
|
||||
return { siteId, newtId: newt.newtId };
|
||||
}
|
||||
@@ -158,7 +158,7 @@ async function triggerFetch(siteId: number) {
|
||||
async function queryContainers(siteId: number) {
|
||||
const { newt } = await getSiteAndNewt(siteId);
|
||||
|
||||
const result = cache.get(`${newt.newtId}:dockerContainers`) as Container[];
|
||||
const result = await cache.get<Container[]>(`${newt.newtId}:dockerContainers`);
|
||||
if (!result) {
|
||||
throw createHttpError(
|
||||
HttpCode.TOO_EARLY,
|
||||
@@ -173,7 +173,7 @@ async function isDockerAvailable(siteId: number): Promise<boolean> {
|
||||
const { newt } = await getSiteAndNewt(siteId);
|
||||
|
||||
const key = `${newt.newtId}:isAvailable`;
|
||||
const isAvailable = cache.get(key);
|
||||
const isAvailable = await cache.get(key);
|
||||
|
||||
return !!isAvailable;
|
||||
}
|
||||
@@ -186,9 +186,11 @@ async function getDockerStatus(
|
||||
const keys = ["isAvailable", "socketPath"];
|
||||
const mappedKeys = keys.map((x) => `${newt.newtId}:${x}`);
|
||||
|
||||
const values = await cache.mget<boolean | string>(mappedKeys);
|
||||
|
||||
const result = {
|
||||
isAvailable: cache.get(mappedKeys[0]) as boolean,
|
||||
socketPath: cache.get(mappedKeys[1]) as string | undefined
|
||||
isAvailable: values[0] as boolean,
|
||||
socketPath: values[1] as string | undefined
|
||||
};
|
||||
|
||||
return result;
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { sites } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { eq, and, ne } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -19,8 +19,8 @@ const updateSiteBodySchema = z
|
||||
.strictObject({
|
||||
name: z.string().min(1).max(255).optional(),
|
||||
niceId: z.string().min(1).max(255).optional(),
|
||||
dockerSocketEnabled: z.boolean().optional(),
|
||||
remoteSubnets: z.string().optional()
|
||||
dockerSocketEnabled: z.boolean().optional()
|
||||
// remoteSubnets: z.string().optional()
|
||||
// subdomain: z
|
||||
// .string()
|
||||
// .min(1)
|
||||
@@ -86,18 +86,19 @@ export async function updateSite(
|
||||
|
||||
// if niceId is provided, check if it's already in use by another site
|
||||
if (updateData.niceId) {
|
||||
const existingSite = await db
|
||||
const [existingSite] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.niceId, updateData.niceId),
|
||||
eq(sites.orgId, sites.orgId)
|
||||
eq(sites.orgId, sites.orgId),
|
||||
ne(sites.siteId, siteId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (existingSite.length > 0 && existingSite[0].siteId !== siteId) {
|
||||
if (existingSite) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
@@ -107,22 +108,22 @@ export async function updateSite(
|
||||
}
|
||||
}
|
||||
|
||||
// if remoteSubnets is provided, ensure it's a valid comma-separated list of cidrs
|
||||
if (updateData.remoteSubnets) {
|
||||
const subnets = updateData.remoteSubnets
|
||||
.split(",")
|
||||
.map((s) => s.trim());
|
||||
for (const subnet of subnets) {
|
||||
if (!isValidCIDR(subnet)) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Invalid CIDR format: ${subnet}`
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// // if remoteSubnets is provided, ensure it's a valid comma-separated list of cidrs
|
||||
// if (updateData.remoteSubnets) {
|
||||
// const subnets = updateData.remoteSubnets
|
||||
// .split(",")
|
||||
// .map((s) => s.trim());
|
||||
// for (const subnet of subnets) {
|
||||
// if (!isValidCIDR(subnet)) {
|
||||
// return next(
|
||||
// createHttpError(
|
||||
// HttpCode.BAD_REQUEST,
|
||||
// `Invalid CIDR format: ${subnet}`
|
||||
// )
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
const updatedSite = await db
|
||||
.update(sites)
|
||||
|
||||
@@ -16,6 +16,8 @@ import {
|
||||
isIpInCidr,
|
||||
portRangeStringSchema
|
||||
} from "@server/lib/ip";
|
||||
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
@@ -53,7 +55,9 @@ const createSiteResourceSchema = z
|
||||
clientIds: z.array(z.int()),
|
||||
tcpPortRangeString: portRangeStringSchema,
|
||||
udpPortRangeString: portRangeStringSchema,
|
||||
disableIcmp: z.boolean().optional()
|
||||
disableIcmp: z.boolean().optional(),
|
||||
authDaemonPort: z.int().positive().optional(),
|
||||
authDaemonMode: z.enum(["site", "remote"]).optional()
|
||||
})
|
||||
.strict()
|
||||
.refine(
|
||||
@@ -168,7 +172,9 @@ export async function createSiteResource(
|
||||
clientIds,
|
||||
tcpPortRangeString,
|
||||
udpPortRangeString,
|
||||
disableIcmp
|
||||
disableIcmp,
|
||||
authDaemonPort,
|
||||
authDaemonMode
|
||||
} = parsedBody.data;
|
||||
|
||||
// Verify the site exists and belongs to the org
|
||||
@@ -267,6 +273,11 @@ export async function createSiteResource(
|
||||
}
|
||||
}
|
||||
|
||||
const isLicensedSshPam = await isLicensedOrSubscribed(
|
||||
orgId,
|
||||
tierMatrix.sshPam
|
||||
);
|
||||
|
||||
const niceId = await getUniqueSiteResourceName(orgId);
|
||||
let aliasAddress: string | null = null;
|
||||
if (mode == "host") {
|
||||
@@ -277,25 +288,29 @@ export async function createSiteResource(
|
||||
let newSiteResource: SiteResource | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
// Create the site resource
|
||||
const insertValues: typeof siteResources.$inferInsert = {
|
||||
siteId,
|
||||
niceId,
|
||||
orgId,
|
||||
name,
|
||||
mode: mode as "host" | "cidr",
|
||||
destination,
|
||||
enabled,
|
||||
alias,
|
||||
aliasAddress,
|
||||
tcpPortRangeString,
|
||||
udpPortRangeString,
|
||||
disableIcmp
|
||||
};
|
||||
if (isLicensedSshPam) {
|
||||
if (authDaemonPort !== undefined)
|
||||
insertValues.authDaemonPort = authDaemonPort;
|
||||
if (authDaemonMode !== undefined)
|
||||
insertValues.authDaemonMode = authDaemonMode;
|
||||
}
|
||||
[newSiteResource] = await trx
|
||||
.insert(siteResources)
|
||||
.values({
|
||||
siteId,
|
||||
niceId,
|
||||
orgId,
|
||||
name,
|
||||
mode: mode as "host" | "cidr",
|
||||
// protocol: mode === "port" ? protocol : null,
|
||||
// proxyPort: mode === "port" ? proxyPort : null,
|
||||
// destinationPort: mode === "port" ? destinationPort : null,
|
||||
destination,
|
||||
enabled,
|
||||
alias,
|
||||
aliasAddress,
|
||||
tcpPortRangeString,
|
||||
udpPortRangeString,
|
||||
disableIcmp
|
||||
})
|
||||
.values(insertValues)
|
||||
.returning();
|
||||
|
||||
const siteResourceId = newSiteResource.siteResourceId;
|
||||
|
||||
@@ -4,7 +4,7 @@ import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import type { PaginatedResponse } from "@server/types/Pagination";
|
||||
import { and, asc, eq, like, or, sql } from "drizzle-orm";
|
||||
import { and, asc, desc, eq, like, or, sql } from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
@@ -21,16 +21,54 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
|
||||
.positive()
|
||||
.optional()
|
||||
.catch(20)
|
||||
.default(20),
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.catch(1)
|
||||
.default(1),
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
mode: z.enum(["host", "cidr"]).optional().catch(undefined)
|
||||
mode: z
|
||||
.enum(["host", "cidr"])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["host", "cidr"],
|
||||
description: "Filter site resources by mode"
|
||||
}),
|
||||
sort_by: z
|
||||
.enum(["name"])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["name"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
})
|
||||
});
|
||||
|
||||
export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
|
||||
@@ -60,6 +98,8 @@ function querySiteResourcesBase() {
|
||||
tcpPortRangeString: siteResources.tcpPortRangeString,
|
||||
udpPortRangeString: siteResources.udpPortRangeString,
|
||||
disableIcmp: siteResources.disableIcmp,
|
||||
authDaemonMode: siteResources.authDaemonMode,
|
||||
authDaemonPort: siteResources.authDaemonPort,
|
||||
siteName: sites.name,
|
||||
siteNiceId: sites.niceId,
|
||||
siteAddress: sites.address
|
||||
@@ -111,7 +151,8 @@ export async function listAllSiteResourcesByOrg(
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
const { page, pageSize, query, mode } = parsedQuery.data;
|
||||
const { page, pageSize, query, mode, sort_by, order } =
|
||||
parsedQuery.data;
|
||||
|
||||
const conditions = [and(eq(siteResources.orgId, orgId))];
|
||||
if (query) {
|
||||
@@ -159,7 +200,13 @@ export async function listAllSiteResourcesByOrg(
|
||||
baseQuery
|
||||
.limit(pageSize)
|
||||
.offset(pageSize * (page - 1))
|
||||
.orderBy(asc(siteResources.siteResourceId)),
|
||||
.orderBy(
|
||||
sort_by
|
||||
? order === "asc"
|
||||
? asc(siteResources[sort_by])
|
||||
: desc(siteResources[sort_by])
|
||||
: asc(siteResources.siteResourceId)
|
||||
),
|
||||
countQuery
|
||||
]);
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import { siteResources, sites, SiteResource } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { and, asc, desc, eq } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
@@ -27,7 +27,16 @@ const listSiteResourcesQuerySchema = z.object({
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
.pipe(z.int().nonnegative()),
|
||||
sort_by: z
|
||||
.enum(["name"])
|
||||
.optional()
|
||||
.catch(undefined),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
});
|
||||
|
||||
export type ListSiteResourcesResponse = {
|
||||
@@ -75,7 +84,7 @@ export async function listSiteResources(
|
||||
}
|
||||
|
||||
const { siteId, orgId } = parsedParams.data;
|
||||
const { limit, offset } = parsedQuery.data;
|
||||
const { limit, offset, sort_by, order } = parsedQuery.data;
|
||||
|
||||
// Verify the site exists and belongs to the org
|
||||
const site = await db
|
||||
@@ -98,6 +107,13 @@ export async function listSiteResources(
|
||||
eq(siteResources.orgId, orgId)
|
||||
)
|
||||
)
|
||||
.orderBy(
|
||||
sort_by
|
||||
? order === "asc"
|
||||
? asc(siteResources[sort_by])
|
||||
: desc(siteResources[sort_by])
|
||||
: asc(siteResources.siteResourceId)
|
||||
)
|
||||
.limit(limit)
|
||||
.offset(offset);
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ import {
|
||||
getClientSiteResourceAccess,
|
||||
rebuildClientAssociationsFromSiteResource
|
||||
} from "@server/lib/rebuildClientAssociations";
|
||||
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
|
||||
const updateSiteResourceParamsSchema = z.strictObject({
|
||||
siteResourceId: z.string().transform(Number).pipe(z.int().positive())
|
||||
@@ -41,6 +43,7 @@ const updateSiteResourceSchema = z
|
||||
.strictObject({
|
||||
name: z.string().min(1).max(255).optional(),
|
||||
siteId: z.int(),
|
||||
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
|
||||
// mode: z.enum(["host", "cidr", "port"]).optional(),
|
||||
mode: z.enum(["host", "cidr"]).optional(),
|
||||
// protocol: z.enum(["tcp", "udp"]).nullish(),
|
||||
@@ -60,7 +63,9 @@ const updateSiteResourceSchema = z
|
||||
clientIds: z.array(z.int()),
|
||||
tcpPortRangeString: portRangeStringSchema,
|
||||
udpPortRangeString: portRangeStringSchema,
|
||||
disableIcmp: z.boolean().optional()
|
||||
disableIcmp: z.boolean().optional(),
|
||||
authDaemonPort: z.int().positive().nullish(),
|
||||
authDaemonMode: z.enum(["site", "remote"]).optional()
|
||||
})
|
||||
.strict()
|
||||
.refine(
|
||||
@@ -171,7 +176,9 @@ export async function updateSiteResource(
|
||||
clientIds,
|
||||
tcpPortRangeString,
|
||||
udpPortRangeString,
|
||||
disableIcmp
|
||||
disableIcmp,
|
||||
authDaemonPort,
|
||||
authDaemonMode
|
||||
} = parsedBody.data;
|
||||
|
||||
const [site] = await db
|
||||
@@ -197,6 +204,11 @@ export async function updateSiteResource(
|
||||
);
|
||||
}
|
||||
|
||||
const isLicensedSshPam = await isLicensedOrSubscribed(
|
||||
existingSiteResource.orgId,
|
||||
tierMatrix.sshPam
|
||||
);
|
||||
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
@@ -307,6 +319,18 @@ export async function updateSiteResource(
|
||||
// wait some time to allow for messages to be handled
|
||||
await new Promise((resolve) => setTimeout(resolve, 750));
|
||||
|
||||
const sshPamSet =
|
||||
isLicensedSshPam &&
|
||||
(authDaemonPort !== undefined || authDaemonMode !== undefined)
|
||||
? {
|
||||
...(authDaemonPort !== undefined && {
|
||||
authDaemonPort
|
||||
}),
|
||||
...(authDaemonMode !== undefined && {
|
||||
authDaemonMode
|
||||
})
|
||||
}
|
||||
: {};
|
||||
[updatedSiteResource] = await trx
|
||||
.update(siteResources)
|
||||
.set({
|
||||
@@ -318,7 +342,8 @@ export async function updateSiteResource(
|
||||
alias: alias && alias.trim() ? alias : null,
|
||||
tcpPortRangeString: tcpPortRangeString,
|
||||
udpPortRangeString: udpPortRangeString,
|
||||
disableIcmp: disableIcmp
|
||||
disableIcmp: disableIcmp,
|
||||
...sshPamSet
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
@@ -396,6 +421,18 @@ export async function updateSiteResource(
|
||||
);
|
||||
} else {
|
||||
// Update the site resource
|
||||
const sshPamSet =
|
||||
isLicensedSshPam &&
|
||||
(authDaemonPort !== undefined || authDaemonMode !== undefined)
|
||||
? {
|
||||
...(authDaemonPort !== undefined && {
|
||||
authDaemonPort
|
||||
}),
|
||||
...(authDaemonMode !== undefined && {
|
||||
authDaemonMode
|
||||
})
|
||||
}
|
||||
: {};
|
||||
[updatedSiteResource] = await trx
|
||||
.update(siteResources)
|
||||
.set({
|
||||
@@ -407,7 +444,8 @@ export async function updateSiteResource(
|
||||
alias: alias && alias.trim() ? alias : null,
|
||||
tcpPortRangeString: tcpPortRangeString,
|
||||
udpPortRangeString: udpPortRangeString,
|
||||
disableIcmp: disableIcmp
|
||||
disableIcmp: disableIcmp,
|
||||
...sshPamSet
|
||||
})
|
||||
.where(
|
||||
and(eq(siteResources.siteResourceId, siteResourceId))
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, UserOrg } from "@server/db";
|
||||
import { db, orgs, UserOrg } from "@server/db";
|
||||
import { roles, userInvites, userOrgs, users } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { eq, and, inArray, ne } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -14,6 +14,7 @@ import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { build } from "@server/build";
|
||||
import { assignUserToOrg } from "@server/lib/userOrg";
|
||||
|
||||
const acceptInviteBodySchema = z.strictObject({
|
||||
token: z.string(),
|
||||
@@ -125,8 +126,22 @@ export async function acceptInvite(
|
||||
}
|
||||
}
|
||||
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, existingInvite.orgId))
|
||||
.limit(1);
|
||||
|
||||
if (!org) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Organization does not exist. Please contact an admin."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let roleId: number;
|
||||
let totalUsers: UserOrg[] | undefined;
|
||||
// get the role to make sure it exists
|
||||
const existingRole = await db
|
||||
.select()
|
||||
@@ -146,12 +161,15 @@ export async function acceptInvite(
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// add the user to the org
|
||||
await trx.insert(userOrgs).values({
|
||||
userId: existingUser[0].userId,
|
||||
orgId: existingInvite.orgId,
|
||||
roleId: existingInvite.roleId
|
||||
});
|
||||
await assignUserToOrg(
|
||||
org,
|
||||
{
|
||||
userId: existingUser[0].userId,
|
||||
orgId: existingInvite.orgId,
|
||||
roleId: existingInvite.roleId
|
||||
},
|
||||
trx
|
||||
);
|
||||
|
||||
// delete the invite
|
||||
await trx
|
||||
@@ -160,25 +178,11 @@ export async function acceptInvite(
|
||||
|
||||
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
|
||||
|
||||
// Get the total number of users in the org now
|
||||
totalUsers = await trx
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(eq(userOrgs.orgId, existingInvite.orgId));
|
||||
|
||||
logger.debug(
|
||||
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}`
|
||||
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
|
||||
);
|
||||
});
|
||||
|
||||
if (totalUsers) {
|
||||
await usageService.updateCount(
|
||||
existingInvite.orgId,
|
||||
FeatureId.USERS,
|
||||
totalUsers.length
|
||||
);
|
||||
}
|
||||
|
||||
return response<AcceptInviteResponse>(res, {
|
||||
data: { accepted: true, orgId: existingInvite.orgId },
|
||||
success: true,
|
||||
|
||||
@@ -6,8 +6,8 @@ import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { db, UserOrg } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { db, orgs, UserOrg } from "@server/db";
|
||||
import { and, eq, inArray, ne } from "drizzle-orm";
|
||||
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
@@ -16,6 +16,7 @@ import { build } from "@server/build";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { isSubscribed } from "#dynamic/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import { assignUserToOrg } from "@server/lib/userOrg";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().nonempty()
|
||||
@@ -151,6 +152,21 @@ export async function createOrgUser(
|
||||
);
|
||||
}
|
||||
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (!org) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"Organization not found"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [idpRes] = await db
|
||||
.select()
|
||||
.from(idp)
|
||||
@@ -172,8 +188,6 @@ export async function createOrgUser(
|
||||
);
|
||||
}
|
||||
|
||||
let orgUsers: UserOrg[] | undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const [existingUser] = await trx
|
||||
.select()
|
||||
@@ -207,15 +221,12 @@ export async function createOrgUser(
|
||||
);
|
||||
}
|
||||
|
||||
await trx
|
||||
.insert(userOrgs)
|
||||
.values({
|
||||
orgId,
|
||||
userId: existingUser.userId,
|
||||
roleId: role.roleId,
|
||||
autoProvisioned: false
|
||||
})
|
||||
.returning();
|
||||
await assignUserToOrg(org, {
|
||||
orgId,
|
||||
userId: existingUser.userId,
|
||||
roleId: role.roleId,
|
||||
autoProvisioned: false
|
||||
}, trx);
|
||||
} else {
|
||||
userId = generateId(15);
|
||||
|
||||
@@ -233,33 +244,16 @@ export async function createOrgUser(
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx
|
||||
.insert(userOrgs)
|
||||
.values({
|
||||
await assignUserToOrg(org, {
|
||||
orgId,
|
||||
userId: newUser.userId,
|
||||
roleId: role.roleId,
|
||||
autoProvisioned: false
|
||||
})
|
||||
.returning();
|
||||
}, trx);
|
||||
}
|
||||
|
||||
// List all of the users in the org
|
||||
orgUsers = await trx
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(eq(userOrgs.orgId, orgId));
|
||||
|
||||
await calculateUserClientsForOrgs(userId, trx);
|
||||
});
|
||||
|
||||
if (orgUsers) {
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.USERS,
|
||||
orgUsers.length
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "User type is required")
|
||||
|
||||
@@ -11,7 +11,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { ActionsEnum, checkUserActionPermission } from "@server/auth/actions";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
async function queryUser(orgId: string, userId: string) {
|
||||
export async function queryUser(orgId: string, userId: string) {
|
||||
const [user] = await db
|
||||
.select({
|
||||
orgId: userOrgs.orgId,
|
||||
|
||||
136
server/routers/user/getOrgUserByUsername.ts
Normal file
136
server/routers/user/getOrgUserByUsername.ts
Normal file
@@ -0,0 +1,136 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { userOrgs, users } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { queryUser, type GetOrgUserResponse } from "./getOrgUser";
|
||||
|
||||
const getOrgUserByUsernameParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const getOrgUserByUsernameQuerySchema = z.strictObject({
|
||||
username: z.string().min(1, "username is required"),
|
||||
idpId: z
|
||||
.string()
|
||||
.optional()
|
||||
.transform((v) =>
|
||||
v === undefined || v === "" ? undefined : parseInt(v, 10)
|
||||
)
|
||||
.refine(
|
||||
(v) =>
|
||||
v === undefined || (Number.isInteger(v) && (v as number) > 0),
|
||||
{ message: "idpId must be a positive integer" }
|
||||
)
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/user-by-username",
|
||||
description:
|
||||
"Get a user in an organization by username. When idpId is not passed, only internal users are searched (username is globally unique for them). For external (OIDC) users, pass idpId to search by username within that identity provider.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.User],
|
||||
request: {
|
||||
params: getOrgUserByUsernameParamsSchema,
|
||||
query: getOrgUserByUsernameQuerySchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function getOrgUserByUsername(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = getOrgUserByUsernameParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsedQuery = getOrgUserByUsernameQuerySchema.safeParse(
|
||||
req.query
|
||||
);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
const { username, idpId } = parsedQuery.data;
|
||||
|
||||
const conditions = [
|
||||
eq(userOrgs.orgId, orgId),
|
||||
eq(users.username, username)
|
||||
];
|
||||
if (idpId !== undefined) {
|
||||
conditions.push(eq(users.idpId, idpId));
|
||||
} else {
|
||||
conditions.push(eq(users.type, "internal"));
|
||||
}
|
||||
|
||||
const candidates = await db
|
||||
.select({ userId: users.userId })
|
||||
.from(userOrgs)
|
||||
.innerJoin(users, eq(userOrgs.userId, users.userId))
|
||||
.where(and(...conditions));
|
||||
|
||||
if (candidates.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`User with username '${username}' not found in organization`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (candidates.length > 1) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Multiple users with this username (external users from different identity providers). Specify idpId (identity provider ID) to disambiguate. When not specified, this searches for internal users only."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const user = await queryUser(orgId, candidates[0].userId);
|
||||
if (!user) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`User with username '${username}' not found in organization`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return response<GetOrgUserResponse>(res, {
|
||||
data: user,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "User retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ export * from "./addUserRole";
|
||||
export * from "./inviteUser";
|
||||
export * from "./acceptInvite";
|
||||
export * from "./getOrgUser";
|
||||
export * from "./getOrgUserByUsername";
|
||||
export * from "./adminListUsers";
|
||||
export * from "./adminRemoveUser";
|
||||
export * from "./adminGetUser";
|
||||
|
||||
@@ -19,7 +19,7 @@ import { UserType } from "@server/types/UserTypes";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { build } from "@server/build";
|
||||
import cache from "@server/lib/cache";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
|
||||
const inviteUserParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
@@ -191,7 +191,7 @@ export async function inviteUser(
|
||||
}
|
||||
|
||||
if (existingInvite.length) {
|
||||
const attempts = cache.get<number>(email) || 0;
|
||||
const attempts = (await cache.get<number>(email)) || 0;
|
||||
if (attempts >= 3) {
|
||||
return next(
|
||||
createHttpError(
|
||||
@@ -201,7 +201,7 @@ export async function inviteUser(
|
||||
);
|
||||
}
|
||||
|
||||
cache.set(email, attempts + 1);
|
||||
await cache.set(email, attempts + 1);
|
||||
|
||||
const inviteId = existingInvite[0].inviteId; // Retrieve the original inviteId
|
||||
const token = generateRandomString(
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, resources, sites, UserOrg } from "@server/db";
|
||||
import {
|
||||
db,
|
||||
orgs,
|
||||
resources,
|
||||
siteResources,
|
||||
sites,
|
||||
UserOrg,
|
||||
userSiteResources
|
||||
} from "@server/db";
|
||||
import { userOrgs, userResources, users, userSites } from "@server/db";
|
||||
import { and, count, eq, exists } from "drizzle-orm";
|
||||
import { and, count, eq, exists, inArray } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -14,6 +22,7 @@ import { FeatureId } from "@server/lib/billing";
|
||||
import { build } from "@server/build";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { removeUserFromOrg } from "@server/lib/userOrg";
|
||||
|
||||
const removeUserSchema = z.strictObject({
|
||||
userId: z.string(),
|
||||
@@ -50,16 +59,16 @@ export async function removeUserOrg(
|
||||
const { userId, orgId } = parsedParams.data;
|
||||
|
||||
// get the user first
|
||||
const user = await db
|
||||
const [user] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)));
|
||||
|
||||
if (!user || user.length === 0) {
|
||||
if (!user) {
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "User not found"));
|
||||
}
|
||||
|
||||
if (user[0].isOwner) {
|
||||
if (user.isOwner) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
@@ -68,56 +77,20 @@ export async function removeUserOrg(
|
||||
);
|
||||
}
|
||||
|
||||
let userCount: UserOrg[] | undefined;
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (!org) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
|
||||
);
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx
|
||||
.delete(userOrgs)
|
||||
.where(
|
||||
and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))
|
||||
);
|
||||
|
||||
await db.delete(userResources).where(
|
||||
and(
|
||||
eq(userResources.userId, userId),
|
||||
exists(
|
||||
db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(
|
||||
and(
|
||||
eq(
|
||||
resources.resourceId,
|
||||
userResources.resourceId
|
||||
),
|
||||
eq(resources.orgId, orgId)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
await db.delete(userSites).where(
|
||||
and(
|
||||
eq(userSites.userId, userId),
|
||||
exists(
|
||||
db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.siteId, userSites.siteId),
|
||||
eq(sites.orgId, orgId)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
userCount = await trx
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(eq(userOrgs.orgId, orgId));
|
||||
await removeUserFromOrg(org, userId, trx);
|
||||
|
||||
// if (build === "saas") {
|
||||
// const [rootUser] = await trx
|
||||
@@ -139,14 +112,6 @@ export async function removeUserOrg(
|
||||
await calculateUserClientsForOrgs(userId, trx);
|
||||
});
|
||||
|
||||
if (userCount) {
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.USERS,
|
||||
userCount.length
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
|
||||
85
server/routers/ws/checkRoundTripMessage.ts
Normal file
85
server/routers/ws/checkRoundTripMessage.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, roundTripMessageTracker } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const checkRoundTripMessageParamsSchema = z
|
||||
.object({
|
||||
messageId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().positive())
|
||||
})
|
||||
.strict();
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/ws/round-trip-message/{messageId}",
|
||||
// description:
|
||||
// "Check if a round trip message has been completed by checking the roundTripMessageTracker table",
|
||||
// tags: [OpenAPITags.WebSocket],
|
||||
// request: {
|
||||
// params: checkRoundTripMessageParamsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function checkRoundTripMessage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = checkRoundTripMessageParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { messageId } = parsedParams.data;
|
||||
|
||||
// Get the round trip message from the tracker
|
||||
const [message] = await db
|
||||
.select()
|
||||
.from(roundTripMessageTracker)
|
||||
.where(eq(roundTripMessageTracker.messageId, messageId))
|
||||
.limit(1);
|
||||
|
||||
if (!message) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Message not found")
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: {
|
||||
messageId: message.messageId,
|
||||
complete: message.complete,
|
||||
sentAt: message.sentAt,
|
||||
receivedAt: message.receivedAt,
|
||||
error: message.error,
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Round trip message status retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
49
server/routers/ws/handleRoundTripMessage.ts
Normal file
49
server/routers/ws/handleRoundTripMessage.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { db, roundTripMessageTracker } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
interface RoundTripCompleteMessage {
|
||||
messageId: number;
|
||||
complete: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export const handleRoundTripMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client: c } = context;
|
||||
|
||||
logger.info("Handling round trip message");
|
||||
|
||||
const data = message.data as RoundTripCompleteMessage;
|
||||
|
||||
try {
|
||||
const { messageId, complete, error } = data;
|
||||
|
||||
if (!messageId) {
|
||||
logger.error("Round trip message missing messageId");
|
||||
return;
|
||||
}
|
||||
|
||||
// Update the roundTripMessageTracker with completion status
|
||||
await db
|
||||
.update(roundTripMessageTracker)
|
||||
.set({
|
||||
complete: complete,
|
||||
receivedAt: Math.floor(Date.now() / 1000),
|
||||
error: error || null
|
||||
})
|
||||
.where(eq(roundTripMessageTracker.messageId, messageId));
|
||||
|
||||
logger.info(`Round trip message ${messageId} marked as complete: ${complete}`);
|
||||
|
||||
if (error) {
|
||||
logger.warn(`Round trip message ${messageId} completed with error: ${error}`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error processing round trip message:", error);
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
@@ -1,2 +1,3 @@
|
||||
export * from "./ws";
|
||||
export * from "./types";
|
||||
export * from "./checkRoundTripMessage";
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
handleOlmDisconnecingMessage
|
||||
} from "../olm";
|
||||
import { handleHealthcheckStatusMessage } from "../target";
|
||||
import { handleRoundTripMessage } from "./handleRoundTripMessage";
|
||||
import { MessageHandler } from "./types";
|
||||
|
||||
export const messageHandlers: Record<string, MessageHandler> = {
|
||||
@@ -35,7 +36,8 @@ export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"newt/socket/containers": handleDockerContainersMessage,
|
||||
"newt/ping/request": handleNewtPingRequestMessage,
|
||||
"newt/blueprint/apply": handleApplyBlueprintMessage,
|
||||
"newt/healthcheck/status": handleHealthcheckStatusMessage
|
||||
"newt/healthcheck/status": handleHealthcheckStatusMessage,
|
||||
"ws/round-trip/complete": handleRoundTripMessage
|
||||
};
|
||||
|
||||
startOlmOfflineChecker(); // this is to handle the offline check for olms
|
||||
|
||||
Reference in New Issue
Block a user