Format all files

This commit is contained in:
Owen
2025-12-09 10:56:14 -05:00
parent fa839a811f
commit f9b03943c3
535 changed files with 7670 additions and 5626 deletions

View File

@@ -1,6 +1,3 @@
{
"extends": [
"next/core-web-vitals",
"next/typescript"
]
"extends": ["next/core-web-vitals", "next/typescript"]
}

View File

@@ -1,3 +1,3 @@
{
"recommendations": ["esbenp.prettier-vscode"]
}
}

View File

@@ -19,4 +19,4 @@
"editor.defaultFormatter": "esbenp.prettier-vscode"
},
"editor.formatOnSave": true
}
}

View File

@@ -17,4 +17,4 @@
"lib": "@/lib",
"hooks": "@/hooks"
}
}
}

View File

@@ -1,9 +1,7 @@
import { defineConfig } from "drizzle-kit";
import path from "path";
const schema = [
path.join("server", "db", "pg", "schema"),
];
const schema = [path.join("server", "db", "pg", "schema")];
export default defineConfig({
dialect: "postgresql",

View File

@@ -2,9 +2,7 @@ import { APP_PATH } from "@server/lib/consts";
import { defineConfig } from "drizzle-kit";
import path from "path";
const schema = [
path.join("server", "db", "sqlite", "schema"),
];
const schema = [path.join("server", "db", "sqlite", "schema")];
export default defineConfig({
dialect: "sqlite",

View File

@@ -24,20 +24,20 @@ const argv = yargs(hideBin(process.argv))
alias: "e",
describe: "Entry point file",
type: "string",
demandOption: true,
demandOption: true
})
.option("out", {
alias: "o",
describe: "Output file path",
type: "string",
demandOption: true,
demandOption: true
})
.option("build", {
alias: "b",
describe: "Build type (oss, saas, enterprise)",
type: "string",
choices: ["oss", "saas", "enterprise"],
default: "oss",
default: "oss"
})
.help()
.alias("help", "h").argv;
@@ -66,7 +66,9 @@ function privateImportGuardPlugin() {
// Check if the importing file is NOT in server/private
const normalizedImporter = path.normalize(importingFile);
const isInServerPrivate = normalizedImporter.includes(path.normalize("server/private"));
const isInServerPrivate = normalizedImporter.includes(
path.normalize("server/private")
);
if (!isInServerPrivate) {
const violation = {
@@ -79,8 +81,8 @@ function privateImportGuardPlugin() {
console.log(`PRIVATE IMPORT VIOLATION:`);
console.log(` File: ${importingFile}`);
console.log(` Import: ${args.path}`);
console.log(` Resolve dir: ${args.resolveDir || 'N/A'}`);
console.log('');
console.log(` Resolve dir: ${args.resolveDir || "N/A"}`);
console.log("");
}
// Return null to let the default resolver handle it
@@ -89,16 +91,20 @@ function privateImportGuardPlugin() {
build.onEnd((result) => {
if (violations.length > 0) {
console.log(`\nSUMMARY: Found ${violations.length} private import violation(s):`);
console.log(
`\nSUMMARY: Found ${violations.length} private import violation(s):`
);
violations.forEach((v, i) => {
console.log(` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`);
console.log(
` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`
);
});
console.log('');
console.log("");
result.errors.push({
text: `Private import violations detected: ${violations.length} violation(s) found`,
location: null,
notes: violations.map(v => ({
notes: violations.map((v) => ({
text: `${path.relative(process.cwd(), v.file)} imports ${v.importPath}`,
location: null
}))
@@ -121,7 +127,9 @@ function dynamicImportGuardPlugin() {
// Check if the importing file is NOT in server/private
const normalizedImporter = path.normalize(importingFile);
const isInServerPrivate = normalizedImporter.includes(path.normalize("server/private"));
const isInServerPrivate = normalizedImporter.includes(
path.normalize("server/private")
);
if (isInServerPrivate) {
const violation = {
@@ -134,8 +142,8 @@ function dynamicImportGuardPlugin() {
console.log(`DYNAMIC IMPORT VIOLATION:`);
console.log(` File: ${importingFile}`);
console.log(` Import: ${args.path}`);
console.log(` Resolve dir: ${args.resolveDir || 'N/A'}`);
console.log('');
console.log(` Resolve dir: ${args.resolveDir || "N/A"}`);
console.log("");
}
// Return null to let the default resolver handle it
@@ -144,16 +152,20 @@ function dynamicImportGuardPlugin() {
build.onEnd((result) => {
if (violations.length > 0) {
console.log(`\nSUMMARY: Found ${violations.length} dynamic import violation(s):`);
console.log(
`\nSUMMARY: Found ${violations.length} dynamic import violation(s):`
);
violations.forEach((v, i) => {
console.log(` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`);
console.log(
` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`
);
});
console.log('');
console.log("");
result.errors.push({
text: `Dynamic import violations detected: ${violations.length} violation(s) found`,
location: null,
notes: violations.map(v => ({
notes: violations.map((v) => ({
text: `${path.relative(process.cwd(), v.file)} imports ${v.importPath}`,
location: null
}))
@@ -172,21 +184,28 @@ function dynamicImportSwitcherPlugin(buildValue) {
const switches = [];
build.onStart(() => {
console.log(`Dynamic import switcher using build type: ${buildValue}`);
console.log(
`Dynamic import switcher using build type: ${buildValue}`
);
});
build.onResolve({ filter: /^#dynamic\// }, (args) => {
// Extract the path after #dynamic/
const dynamicPath = args.path.replace(/^#dynamic\//, '');
const dynamicPath = args.path.replace(/^#dynamic\//, "");
// Determine the replacement based on build type
let replacement;
if (buildValue === "oss") {
replacement = `#open/${dynamicPath}`;
} else if (buildValue === "saas" || buildValue === "enterprise") {
} else if (
buildValue === "saas" ||
buildValue === "enterprise"
) {
replacement = `#closed/${dynamicPath}`; // We use #closed here so that the route guards dont complain after its been changed but this is the same as #private
} else {
console.warn(`Unknown build type '${buildValue}', defaulting to #open/`);
console.warn(
`Unknown build type '${buildValue}', defaulting to #open/`
);
replacement = `#open/${dynamicPath}`;
}
@@ -201,8 +220,10 @@ function dynamicImportSwitcherPlugin(buildValue) {
console.log(`DYNAMIC IMPORT SWITCH:`);
console.log(` File: ${args.importer}`);
console.log(` Original: ${args.path}`);
console.log(` Switched to: ${replacement} (build: ${buildValue})`);
console.log('');
console.log(
` Switched to: ${replacement} (build: ${buildValue})`
);
console.log("");
// Rewrite the import path and let the normal resolution continue
return build.resolve(replacement, {
@@ -215,12 +236,18 @@ function dynamicImportSwitcherPlugin(buildValue) {
build.onEnd((result) => {
if (switches.length > 0) {
console.log(`\nDYNAMIC IMPORT SUMMARY: Switched ${switches.length} import(s) for build type '${buildValue}':`);
console.log(
`\nDYNAMIC IMPORT SUMMARY: Switched ${switches.length} import(s) for build type '${buildValue}':`
);
switches.forEach((s, i) => {
console.log(` ${i + 1}. ${path.relative(process.cwd(), s.file)}`);
console.log(` ${s.originalPath} ${s.replacementPath}`);
console.log(
` ${i + 1}. ${path.relative(process.cwd(), s.file)}`
);
console.log(
` ${s.originalPath}${s.replacementPath}`
);
});
console.log('');
console.log("");
}
});
}
@@ -235,7 +262,7 @@ esbuild
format: "esm",
minify: false,
banner: {
js: banner,
js: banner
},
platform: "node",
external: ["body-parser"],
@@ -244,20 +271,22 @@ esbuild
dynamicImportGuardPlugin(),
dynamicImportSwitcherPlugin(argv.build),
nodeExternalsPlugin({
packagePath: getPackagePaths(),
}),
packagePath: getPackagePaths()
})
],
sourcemap: "inline",
target: "node22",
target: "node22"
})
.then((result) => {
// Check if there were any errors in the build result
if (result.errors && result.errors.length > 0) {
console.error(`Build failed with ${result.errors.length} error(s):`);
console.error(
`Build failed with ${result.errors.length} error(s):`
);
result.errors.forEach((error, i) => {
console.error(`${i + 1}. ${error.text}`);
if (error.notes) {
error.notes.forEach(note => {
error.notes.forEach((note) => {
console.error(` - ${note.text}`);
});
}

View File

@@ -1,19 +1,19 @@
import tseslint from 'typescript-eslint';
import tseslint from "typescript-eslint";
export default tseslint.config({
files: ["**/*.{ts,tsx,js,jsx}"],
languageOptions: {
parser: tseslint.parser,
parserOptions: {
ecmaVersion: "latest",
sourceType: "module",
ecmaFeatures: {
jsx: true
}
files: ["**/*.{ts,tsx,js,jsx}"],
languageOptions: {
parser: tseslint.parser,
parserOptions: {
ecmaVersion: "latest",
sourceType: "module",
ecmaFeatures: {
jsx: true
}
}
},
rules: {
semi: "error",
"prefer-const": "warn"
}
},
rules: {
"semi": "error",
"prefer-const": "warn"
}
});
});

View File

@@ -1,8 +1,8 @@
/** @type {import('postcss-load-config').Config} */
const config = {
plugins: {
"@tailwindcss/postcss": {},
},
"@tailwindcss/postcss": {}
}
};
export default config;

View File

@@ -2,13 +2,13 @@ import { hash, verify } from "@node-rs/argon2";
export async function verifyPassword(
password: string,
hash: string,
hash: string
): Promise<boolean> {
const validPassword = await verify(hash, password, {
memoryCost: 19456,
timeCost: 2,
outputLen: 32,
parallelism: 1,
parallelism: 1
});
return validPassword;
}
@@ -18,7 +18,7 @@ export async function hashPassword(password: string): Promise<string> {
memoryCost: 19456,
timeCost: 2,
outputLen: 32,
parallelism: 1,
parallelism: 1
});
return passwordHash;

View File

@@ -4,10 +4,13 @@ export const passwordSchema = z
.string()
.min(8, { message: "Password must be at least 8 characters long" })
.max(128, { message: "Password must be at most 128 characters long" })
.regex(/^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9])(?=.*?[~!`@#$%^&*()_\-+={}[\]|\\:;"'<>,.\/?]).*$/, {
message: `Your password must meet the following conditions:
.regex(
/^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9])(?=.*?[~!`@#$%^&*()_\-+={}[\]|\\:;"'<>,.\/?]).*$/,
{
message: `Your password must meet the following conditions:
at least one uppercase English letter,
at least one lowercase English letter,
at least one digit,
at least one special character.`
});
}
);

View File

@@ -1,6 +1,4 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Newt, newts, newtSessions, NewtSession } from "@server/db";
import { db } from "@server/db";
@@ -10,25 +8,25 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createNewtSession(
token: string,
newtId: string,
newtId: string
): Promise<NewtSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: NewtSession = {
sessionId: sessionId,
newtId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(newtSessions).values(session);
return session;
}
export async function validateNewtSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ newt: newts, session: newtSessions })
@@ -45,14 +43,12 @@ export async function validateNewtSessionToken(
.where(eq(newtSessions.sessionId, session.sessionId));
return { session: null, newt: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(newtSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(newtSessions.sessionId, session.sessionId));
}

View File

@@ -1,6 +1,4 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Olm, olms, olmSessions, OlmSession } from "@server/db";
import { db } from "@server/db";
@@ -10,25 +8,25 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createOlmSession(
token: string,
olmId: string,
olmId: string
): Promise<OlmSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: OlmSession = {
sessionId: sessionId,
olmId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(olmSessions).values(session);
return session;
}
export async function validateOlmSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ olm: olms, session: olmSessions })
@@ -45,14 +43,12 @@ export async function validateOlmSessionToken(
.where(eq(olmSessions.sessionId, session.sessionId));
return { session: null, olm: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(olmSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(olmSessions.sessionId, session.sessionId));
}

View File

@@ -10,4 +10,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1708,4 +1708,4 @@
"Desert Box Turtle",
"African Striped Weasel"
]
}
}

View File

@@ -215,42 +215,56 @@ export const sessionTransferToken = pgTable("sessionTransferToken", {
expiresAt: bigint("expiresAt", { mode: "number" }).notNull()
});
export const actionAuditLog = pgTable("actionAuditLog", {
id: serial("id").primaryKey(),
timestamp: bigint("timestamp", { mode: "number" }).notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: varchar("actorType", { length: 50 }).notNull(),
actor: varchar("actor", { length: 255 }).notNull(),
actorId: varchar("actorId", { length: 255 }).notNull(),
action: varchar("action", { length: 100 }).notNull(),
metadata: text("metadata")
}, (table) => ([
index("idx_actionAuditLog_timestamp").on(table.timestamp),
index("idx_actionAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
export const actionAuditLog = pgTable(
"actionAuditLog",
{
id: serial("id").primaryKey(),
timestamp: bigint("timestamp", { mode: "number" }).notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: varchar("actorType", { length: 50 }).notNull(),
actor: varchar("actor", { length: 255 }).notNull(),
actorId: varchar("actorId", { length: 255 }).notNull(),
action: varchar("action", { length: 100 }).notNull(),
metadata: text("metadata")
},
(table) => [
index("idx_actionAuditLog_timestamp").on(table.timestamp),
index("idx_actionAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export const accessAuditLog = pgTable("accessAuditLog", {
id: serial("id").primaryKey(),
timestamp: bigint("timestamp", { mode: "number" }).notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: varchar("actorType", { length: 50 }),
actor: varchar("actor", { length: 255 }),
actorId: varchar("actorId", { length: 255 }),
resourceId: integer("resourceId"),
ip: varchar("ip", { length: 45 }),
type: varchar("type", { length: 100 }).notNull(),
action: boolean("action").notNull(),
location: text("location"),
userAgent: text("userAgent"),
metadata: text("metadata")
}, (table) => ([
index("idx_identityAuditLog_timestamp").on(table.timestamp),
index("idx_identityAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
export const accessAuditLog = pgTable(
"accessAuditLog",
{
id: serial("id").primaryKey(),
timestamp: bigint("timestamp", { mode: "number" }).notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: varchar("actorType", { length: 50 }),
actor: varchar("actor", { length: 255 }),
actorId: varchar("actorId", { length: 255 }),
resourceId: integer("resourceId"),
ip: varchar("ip", { length: 45 }),
type: varchar("type", { length: 100 }).notNull(),
action: boolean("action").notNull(),
location: text("location"),
userAgent: text("userAgent"),
metadata: text("metadata")
},
(table) => [
index("idx_identityAuditLog_timestamp").on(table.timestamp),
index("idx_identityAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
@@ -270,4 +284,4 @@ export type RemoteExitNodeSession = InferSelectModel<
export type ExitNodeOrg = InferSelectModel<typeof exitNodeOrgs>;
export type LoginPage = InferSelectModel<typeof loginPage>;
export type ActionAuditLog = InferSelectModel<typeof actionAuditLog>;
export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>;
export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>;

View File

@@ -177,7 +177,7 @@ export const targetHealthCheck = pgTable("targetHealthCheck", {
hcMethod: varchar("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName"),
hcTlsServerName: text("hcTlsServerName")
});
export const exitNodes = pgTable("exitNodes", {

View File

@@ -52,10 +52,7 @@ export async function getResourceByDomain(
resourceHeaderAuth,
eq(resourceHeaderAuth.resourceId, resources.resourceId)
)
.innerJoin(
orgs,
eq(orgs.orgId, resources.orgId)
)
.innerJoin(orgs, eq(orgs.orgId, resources.orgId))
.where(eq(resources.fullDomain, domain))
.limit(1);

View File

@@ -8,7 +8,7 @@ const runMigrations = async () => {
console.log("Running migrations...");
try {
migrate(db as any, {
migrationsFolder: migrationsFolder,
migrationsFolder: migrationsFolder
});
console.log("Migrations completed successfully.");
} catch (error) {

View File

@@ -29,7 +29,9 @@ export const certificates = sqliteTable("certificates", {
});
export const dnsChallenge = sqliteTable("dnsChallenges", {
dnsChallengeId: integer("dnsChallengeId").primaryKey({ autoIncrement: true }),
dnsChallengeId: integer("dnsChallengeId").primaryKey({
autoIncrement: true
}),
domain: text("domain").notNull(),
token: text("token").notNull(),
keyAuthorization: text("keyAuthorization").notNull(),
@@ -61,9 +63,7 @@ export const customers = sqliteTable("customers", {
});
export const subscriptions = sqliteTable("subscriptions", {
subscriptionId: text("subscriptionId")
.primaryKey()
.notNull(),
subscriptionId: text("subscriptionId").primaryKey().notNull(),
customerId: text("customerId")
.notNull()
.references(() => customers.customerId, { onDelete: "cascade" }),
@@ -75,7 +75,9 @@ export const subscriptions = sqliteTable("subscriptions", {
});
export const subscriptionItems = sqliteTable("subscriptionItems", {
subscriptionItemId: integer("subscriptionItemId").primaryKey({ autoIncrement: true }),
subscriptionItemId: integer("subscriptionItemId").primaryKey({
autoIncrement: true
}),
subscriptionId: text("subscriptionId")
.notNull()
.references(() => subscriptions.subscriptionId, {
@@ -129,7 +131,9 @@ export const limits = sqliteTable("limits", {
});
export const usageNotifications = sqliteTable("usageNotifications", {
notificationId: integer("notificationId").primaryKey({ autoIncrement: true }),
notificationId: integer("notificationId").primaryKey({
autoIncrement: true
}),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
@@ -210,42 +214,56 @@ export const sessionTransferToken = sqliteTable("sessionTransferToken", {
expiresAt: integer("expiresAt").notNull()
});
export const actionAuditLog = sqliteTable("actionAuditLog", {
id: integer("id").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: text("actorType").notNull(),
actor: text("actor").notNull(),
actorId: text("actorId").notNull(),
action: text("action").notNull(),
metadata: text("metadata")
}, (table) => ([
index("idx_actionAuditLog_timestamp").on(table.timestamp),
index("idx_actionAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
export const actionAuditLog = sqliteTable(
"actionAuditLog",
{
id: integer("id").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: text("actorType").notNull(),
actor: text("actor").notNull(),
actorId: text("actorId").notNull(),
action: text("action").notNull(),
metadata: text("metadata")
},
(table) => [
index("idx_actionAuditLog_timestamp").on(table.timestamp),
index("idx_actionAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export const accessAuditLog = sqliteTable("accessAuditLog", {
id: integer("id").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: text("actorType"),
actor: text("actor"),
actorId: text("actorId"),
resourceId: integer("resourceId"),
ip: text("ip"),
location: text("location"),
type: text("type").notNull(),
action: integer("action", { mode: "boolean" }).notNull(),
userAgent: text("userAgent"),
metadata: text("metadata")
}, (table) => ([
index("idx_identityAuditLog_timestamp").on(table.timestamp),
index("idx_identityAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
export const accessAuditLog = sqliteTable(
"accessAuditLog",
{
id: integer("id").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
actorType: text("actorType"),
actor: text("actor"),
actorId: text("actorId"),
resourceId: integer("resourceId"),
ip: text("ip"),
location: text("location"),
type: text("type").notNull(),
action: integer("action", { mode: "boolean" }).notNull(),
userAgent: text("userAgent"),
metadata: text("metadata")
},
(table) => [
index("idx_identityAuditLog_timestamp").on(table.timestamp),
index("idx_identityAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
@@ -265,4 +283,4 @@ export type RemoteExitNodeSession = InferSelectModel<
export type ExitNodeOrg = InferSelectModel<typeof exitNodeOrgs>;
export type LoginPage = InferSelectModel<typeof loginPage>;
export type ActionAuditLog = InferSelectModel<typeof actionAuditLog>;
export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>;
export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>;

View File

@@ -18,10 +18,13 @@ function createEmailClient() {
host: emailConfig.smtp_host,
port: emailConfig.smtp_port,
secure: emailConfig.smtp_secure || false,
auth: (emailConfig.smtp_user && emailConfig.smtp_pass) ? {
user: emailConfig.smtp_user,
pass: emailConfig.smtp_pass
} : null
auth:
emailConfig.smtp_user && emailConfig.smtp_pass
? {
user: emailConfig.smtp_user,
pass: emailConfig.smtp_pass
}
: null
} as SMTPTransport.Options;
if (emailConfig.smtp_tls_reject_unauthorized !== undefined) {

View File

@@ -19,7 +19,13 @@ interface Props {
billingLink: string; // Link to billing page
}
export const NotifyUsageLimitApproaching = ({ email, limitName, currentUsage, usageLimit, billingLink }: Props) => {
export const NotifyUsageLimitApproaching = ({
email,
limitName,
currentUsage,
usageLimit,
billingLink
}: Props) => {
const previewText = `Your usage for ${limitName} is approaching the limit.`;
const usagePercentage = Math.round((currentUsage / usageLimit) * 100);
@@ -37,23 +43,32 @@ export const NotifyUsageLimitApproaching = ({ email, limitName, currentUsage, us
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
We wanted to let you know that your usage for <strong>{limitName}</strong> is approaching your plan limit.
We wanted to let you know that your usage for{" "}
<strong>{limitName}</strong> is approaching your
plan limit.
</EmailText>
<EmailText>
<strong>Current Usage:</strong> {currentUsage} of {usageLimit} ({usagePercentage}%)
<strong>Current Usage:</strong> {currentUsage} of{" "}
{usageLimit} ({usagePercentage}%)
</EmailText>
<EmailText>
Once you reach your limit, some functionality may be restricted or your sites may disconnect until you upgrade your plan or your usage resets.
Once you reach your limit, some functionality may be
restricted or your sites may disconnect until you
upgrade your plan or your usage resets.
</EmailText>
<EmailText>
To avoid any interruption to your service, we recommend upgrading your plan or monitoring your usage closely. You can <a href={billingLink}>upgrade your plan here</a>.
To avoid any interruption to your service, we
recommend upgrading your plan or monitoring your
usage closely. You can{" "}
<a href={billingLink}>upgrade your plan here</a>.
</EmailText>
<EmailText>
If you have any questions or need assistance, please don't hesitate to reach out to our support team.
If you have any questions or need assistance, please
don't hesitate to reach out to our support team.
</EmailText>
<EmailFooter>

View File

@@ -19,7 +19,13 @@ interface Props {
billingLink: string; // Link to billing page
}
export const NotifyUsageLimitReached = ({ email, limitName, currentUsage, usageLimit, billingLink }: Props) => {
export const NotifyUsageLimitReached = ({
email,
limitName,
currentUsage,
usageLimit,
billingLink
}: Props) => {
const previewText = `You've reached your ${limitName} usage limit - Action required`;
const usagePercentage = Math.round((currentUsage / usageLimit) * 100);
@@ -32,30 +38,48 @@ export const NotifyUsageLimitReached = ({ email, limitName, currentUsage, usageL
<EmailContainer>
<EmailLetterHead />
<EmailHeading>Usage Limit Reached - Action Required</EmailHeading>
<EmailHeading>
Usage Limit Reached - Action Required
</EmailHeading>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
You have reached your usage limit for <strong>{limitName}</strong>.
You have reached your usage limit for{" "}
<strong>{limitName}</strong>.
</EmailText>
<EmailText>
<strong>Current Usage:</strong> {currentUsage} of {usageLimit} ({usagePercentage}%)
<strong>Current Usage:</strong> {currentUsage} of{" "}
{usageLimit} ({usagePercentage}%)
</EmailText>
<EmailText>
<strong>Important:</strong> Your functionality may now be restricted and your sites may disconnect until you either upgrade your plan or your usage resets. To prevent any service interruption, immediate action is recommended.
<strong>Important:</strong> Your functionality may
now be restricted and your sites may disconnect
until you either upgrade your plan or your usage
resets. To prevent any service interruption,
immediate action is recommended.
</EmailText>
<EmailText>
<strong>What you can do:</strong>
<br /> <a href={billingLink} style={{ color: '#2563eb', fontWeight: 'bold' }}>Upgrade your plan immediately</a> to restore full functionality
<br /> Monitor your usage to stay within limits in the future
<br />{" "}
<a
href={billingLink}
style={{ color: "#2563eb", fontWeight: "bold" }}
>
Upgrade your plan immediately
</a>{" "}
to restore full functionality
<br /> Monitor your usage to stay within limits in
the future
</EmailText>
<EmailText>
If you have any questions or need immediate assistance, please contact our support team right away.
If you have any questions or need immediate
assistance, please contact our support team right
away.
</EmailText>
<EmailFooter>

View File

@@ -5,7 +5,7 @@ import config from "@server/lib/config";
import logger from "@server/logger";
import {
errorHandlerMiddleware,
notFoundMiddleware,
notFoundMiddleware
} from "@server/middlewares";
import { authenticated, unauthenticated } from "#dynamic/routers/integration";
import { logIncomingMiddleware } from "./middlewares/logIncoming";

View File

@@ -25,16 +25,22 @@ export const FeatureMeterIdsSandbox: Record<FeatureId, string> = {
};
export function getFeatureMeterId(featureId: FeatureId): string {
if (process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return FeatureMeterIds[featureId];
} else {
return FeatureMeterIdsSandbox[featureId];
}
}
export function getFeatureIdByMetricId(metricId: string): FeatureId | undefined {
return (Object.entries(FeatureMeterIds) as [FeatureId, string][])
.find(([_, v]) => v === metricId)?.[0];
export function getFeatureIdByMetricId(
metricId: string
): FeatureId | undefined {
return (Object.entries(FeatureMeterIds) as [FeatureId, string][]).find(
([_, v]) => v === metricId
)?.[0];
}
export type FeaturePriceSet = {
@@ -43,7 +49,8 @@ export type FeaturePriceSet = {
[FeatureId.DOMAINS]?: string; // Optional since domains are not billed
};
export const standardFeaturePriceSet: FeaturePriceSet = { // Free tier matches the freeLimitSet
export const standardFeaturePriceSet: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RrQc4D3Ee2Ir7WmaJGZ3MtF",
[FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea",
[FeatureId.EGRESS_DATA_MB]: "price_1RrQXFD3Ee2Ir7WmvGDlgxQk",
@@ -51,7 +58,8 @@ export const standardFeaturePriceSet: FeaturePriceSet = { // Free tier matches t
[FeatureId.REMOTE_EXIT_NODES]: "price_1S46weD3Ee2Ir7Wm94KEHI4h"
};
export const standardFeaturePriceSetSandbox: FeaturePriceSet = { // Free tier matches the freeLimitSet
export const standardFeaturePriceSetSandbox: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RefFBDCpkOb237BPrKZ8IEU",
[FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF",
[FeatureId.EGRESS_DATA_MB]: "price_1Rfp9LDCpkOb237BwuN5Oiu0",
@@ -60,15 +68,20 @@ export const standardFeaturePriceSetSandbox: FeaturePriceSet = { // Free tier ma
};
export function getStandardFeaturePriceSet(): FeaturePriceSet {
if (process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return standardFeaturePriceSet;
} else {
return standardFeaturePriceSetSandbox;
}
}
export function getLineItems(featurePriceSet: FeaturePriceSet): Stripe.Checkout.SessionCreateParams.LineItem[] {
export function getLineItems(
featurePriceSet: FeaturePriceSet
): Stripe.Checkout.SessionCreateParams.LineItem[] {
return Object.entries(featurePriceSet).map(([featureId, priceId]) => ({
price: priceId,
price: priceId
}));
}
}

View File

@@ -2,4 +2,4 @@ export * from "./limitSet";
export * from "./features";
export * from "./limitsService";
export * from "./getOrgTierData";
export * from "./createCustomer";
export * from "./createCustomer";

View File

@@ -12,7 +12,7 @@ export const sandboxLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.EGRESS_DATA_MB]: { value: 1000, description: "Sandbox limit" }, // 1 GB
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }
};
export const freeLimitSet: LimitSet = {
@@ -29,7 +29,7 @@ export const freeLimitSet: LimitSet = {
export const subscribedLimitSet: LimitSet = {
[FeatureId.SITE_UPTIME]: {
value: 2232000,
description: "Contact us to increase soft limit.",
description: "Contact us to increase soft limit."
}, // 50 sites up for 31 days
[FeatureId.USERS]: {
value: 150,
@@ -38,7 +38,7 @@ export const subscribedLimitSet: LimitSet = {
[FeatureId.EGRESS_DATA_MB]: {
value: 12000000,
description: "Contact us to increase soft limit."
}, // 12000 GB
}, // 12000 GB
[FeatureId.DOMAINS]: {
value: 25,
description: "Contact us to increase soft limit."

View File

@@ -1,22 +1,32 @@
export enum TierId {
STANDARD = "standard",
STANDARD = "standard"
}
export type TierPriceSet = {
[key in TierId]: string;
};
export const tierPriceSet: TierPriceSet = { // Free tier matches the freeLimitSet
[TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0",
export const tierPriceSet: TierPriceSet = {
// Free tier matches the freeLimitSet
[TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0"
};
export const tierPriceSetSandbox: TierPriceSet = { // Free tier matches the freeLimitSet
export const tierPriceSetSandbox: TierPriceSet = {
// Free tier matches the freeLimitSet
// when matching tier the keys closer to 0 index are matched first so list the tiers in descending order of value
[TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m",
[TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m"
};
export function getTierPriceSet(environment?: string, sandbox_mode?: boolean): TierPriceSet {
if ((process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") || (environment === "prod" && sandbox_mode !== true)) { // THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
export function getTierPriceSet(
environment?: string,
sandbox_mode?: boolean
): TierPriceSet {
if (
(process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true") ||
(environment === "prod" && sandbox_mode !== true)
) {
// THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
return tierPriceSet;
} else {
return tierPriceSetSandbox;

View File

@@ -19,7 +19,7 @@ import logger from "@server/logger";
import { sendToClient } from "#dynamic/routers/ws";
import { build } from "@server/build";
import { s3Client } from "@server/lib/s3";
import cache from "@server/lib/cache";
import cache from "@server/lib/cache";
interface StripeEvent {
identifier?: string;

View File

@@ -34,7 +34,10 @@ export async function applyNewtDockerBlueprint(
return;
}
if (isEmptyObject(blueprint["proxy-resources"]) && isEmptyObject(blueprint["client-resources"])) {
if (
isEmptyObject(blueprint["proxy-resources"]) &&
isEmptyObject(blueprint["client-resources"])
) {
return;
}

View File

@@ -84,12 +84,20 @@ export function processContainerLabels(containers: Container[]): {
// Process proxy resources
if (Object.keys(proxyResourceLabels).length > 0) {
processResourceLabels(proxyResourceLabels, container, result["proxy-resources"]);
processResourceLabels(
proxyResourceLabels,
container,
result["proxy-resources"]
);
}
// Process client resources
if (Object.keys(clientResourceLabels).length > 0) {
processResourceLabels(clientResourceLabels, container, result["client-resources"]);
processResourceLabels(
clientResourceLabels,
container,
result["client-resources"]
);
}
});
@@ -161,8 +169,7 @@ function processResourceLabels(
const finalTarget = { ...target };
if (!finalTarget.hostname) {
finalTarget.hostname =
container.name ||
container.hostname;
container.name || container.hostname;
}
if (!finalTarget.port) {
const containerPort =

View File

@@ -312,7 +312,7 @@ export const ConfigSchema = z
};
delete (data as any)["public-resources"];
}
// Merge private-resources into client-resources
if (data["private-resources"]) {
data["client-resources"] = {
@@ -321,10 +321,13 @@ export const ConfigSchema = z
};
delete (data as any)["private-resources"];
}
return data as {
"proxy-resources": Record<string, z.infer<typeof ResourceSchema>>;
"client-resources": Record<string, z.infer<typeof ClientResourceSchema>>;
"client-resources": Record<
string,
z.infer<typeof ClientResourceSchema>
>;
sites: Record<string, z.infer<typeof SiteSchema>>;
};
})

View File

@@ -2,4 +2,4 @@ import NodeCache from "node-cache";
export const cache = new NodeCache({ stdTTL: 3600, checkperiod: 120 });
export default cache;
export default cache;

View File

@@ -166,7 +166,10 @@ export async function calculateUserClientsForOrgs(
];
// Get next available subnet
const newSubnet = await getNextAvailableClientSubnet(orgId, transaction);
const newSubnet = await getNextAvailableClientSubnet(
orgId,
transaction
);
if (!newSubnet) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no available subnet found`

View File

@@ -1,4 +1,6 @@
export async function getValidCertificatesForDomains(domains: Set<string>): Promise<
export async function getValidCertificatesForDomains(
domains: Set<string>
): Promise<
Array<{
id: number;
domain: string;
@@ -10,4 +12,4 @@ export async function getValidCertificatesForDomains(domains: Set<string>): Prom
}>
> {
return []; // stub
}
}

View File

@@ -7,7 +7,10 @@ function dateToTimestamp(dateStr: string): number {
// Testable version of calculateCutoffTimestamp that accepts a "now" timestamp
// This matches the logic in cleanupLogs.ts but allows injecting the current time
function calculateCutoffTimestampWithNow(retentionDays: number, nowTimestamp: number): number {
function calculateCutoffTimestampWithNow(
retentionDays: number,
nowTimestamp: number
): number {
if (retentionDays === 9001) {
// Special case: data is erased at the end of the year following the year it was generated
// This means we delete logs from 2 years ago or older (logs from year Y are deleted after Dec 31 of year Y+1)
@@ -28,7 +31,7 @@ function testCalculateCutoffTimestamp() {
{
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(30, now);
const expected = now - (30 * 24 * 60 * 60);
const expected = now - 30 * 24 * 60 * 60;
assertEquals(result, expected, "30 days retention calculation failed");
}
@@ -36,7 +39,7 @@ function testCalculateCutoffTimestamp() {
{
const now = dateToTimestamp("2025-06-15T00:00:00Z");
const result = calculateCutoffTimestampWithNow(90, now);
const expected = now - (90 * 24 * 60 * 60);
const expected = now - 90 * 24 * 60 * 60;
assertEquals(result, expected, "90 days retention calculation failed");
}
@@ -48,7 +51,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Dec 2025) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (Dec 2025) - should cutoff at Jan 1, 2024"
);
}
// Test 4: Special case 9001 - January 2026
@@ -58,7 +65,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2026-01-15T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2025-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Jan 2026) - should cutoff at Jan 1, 2025");
assertEquals(
result,
expected,
"9001 retention (Jan 2026) - should cutoff at Jan 1, 2025"
);
}
// Test 5: Special case 9001 - December 31, 2025 at 23:59:59 UTC
@@ -68,7 +79,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-12-31T23:59:59Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Dec 31, 2025 23:59:59) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (Dec 31, 2025 23:59:59) - should cutoff at Jan 1, 2024"
);
}
// Test 6: Special case 9001 - January 1, 2026 at 00:00:01 UTC
@@ -78,7 +93,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2026-01-01T00:00:01Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2025-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Jan 1, 2026 00:00:01) - should cutoff at Jan 1, 2025");
assertEquals(
result,
expected,
"9001 retention (Jan 1, 2026 00:00:01) - should cutoff at Jan 1, 2025"
);
}
// Test 7: Special case 9001 - Mid year 2025
@@ -87,7 +106,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-06-15T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (mid 2025) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (mid 2025) - should cutoff at Jan 1, 2024"
);
}
// Test 8: Special case 9001 - Early 2024
@@ -96,14 +119,18 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2024-02-01T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2023-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (early 2024) - should cutoff at Jan 1, 2023");
assertEquals(
result,
expected,
"9001 retention (early 2024) - should cutoff at Jan 1, 2023"
);
}
// Test 9: 1 day retention
{
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(1, now);
const expected = now - (1 * 24 * 60 * 60);
const expected = now - 1 * 24 * 60 * 60;
assertEquals(result, expected, "1 day retention calculation failed");
}
@@ -111,7 +138,7 @@ function testCalculateCutoffTimestamp() {
{
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(365, now);
const expected = now - (365 * 24 * 60 * 60);
const expected = now - 365 * 24 * 60 * 60;
assertEquals(result, expected, "365 days retention calculation failed");
}
@@ -123,11 +150,19 @@ function testCalculateCutoffTimestamp() {
const cutoff = calculateCutoffTimestampWithNow(9001, now);
const logFromDec2023 = dateToTimestamp("2023-12-31T23:59:59Z");
const logFromJan2024 = dateToTimestamp("2024-01-01T00:00:00Z");
// Log from Dec 2023 should be before cutoff (deleted)
assertEquals(logFromDec2023 < cutoff, true, "Log from Dec 2023 should be deleted");
assertEquals(
logFromDec2023 < cutoff,
true,
"Log from Dec 2023 should be deleted"
);
// Log from Jan 2024 should be at or after cutoff (kept)
assertEquals(logFromJan2024 >= cutoff, true, "Log from Jan 2024 should be kept");
assertEquals(
logFromJan2024 >= cutoff,
true,
"Log from Jan 2024 should be kept"
);
}
// Test 12: Verify 9001 in 2026 - logs from 2024 should now be deleted
@@ -136,11 +171,19 @@ function testCalculateCutoffTimestamp() {
const cutoff = calculateCutoffTimestampWithNow(9001, now);
const logFromDec2024 = dateToTimestamp("2024-12-31T23:59:59Z");
const logFromJan2025 = dateToTimestamp("2025-01-01T00:00:00Z");
// Log from Dec 2024 should be before cutoff (deleted)
assertEquals(logFromDec2024 < cutoff, true, "Log from Dec 2024 should be deleted in 2026");
assertEquals(
logFromDec2024 < cutoff,
true,
"Log from Dec 2024 should be deleted in 2026"
);
// Log from Jan 2025 should be at or after cutoff (kept)
assertEquals(logFromJan2025 >= cutoff, true, "Log from Jan 2025 should be kept in 2026");
assertEquals(
logFromJan2025 >= cutoff,
true,
"Log from Jan 2025 should be kept in 2026"
);
}
// Test 13: Edge case - exactly at year boundary for 9001
@@ -149,7 +192,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-01-01T00:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Jan 1, 2025 00:00:00) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (Jan 1, 2025 00:00:00) - should cutoff at Jan 1, 2024"
);
}
// Test 14: Verify data from 2024 is kept throughout 2025 when using 9001
@@ -157,18 +204,29 @@ function testCalculateCutoffTimestamp() {
{
// Running in June 2025
const nowJune2025 = dateToTimestamp("2025-06-15T12:00:00Z");
const cutoffJune2025 = calculateCutoffTimestampWithNow(9001, nowJune2025);
const cutoffJune2025 = calculateCutoffTimestampWithNow(
9001,
nowJune2025
);
const logFromJuly2024 = dateToTimestamp("2024-07-15T12:00:00Z");
// Log from July 2024 should be KEPT in June 2025
assertEquals(logFromJuly2024 >= cutoffJune2025, true, "Log from July 2024 should be kept in June 2025");
assertEquals(
logFromJuly2024 >= cutoffJune2025,
true,
"Log from July 2024 should be kept in June 2025"
);
// Running in January 2026
const nowJan2026 = dateToTimestamp("2026-01-15T12:00:00Z");
const cutoffJan2026 = calculateCutoffTimestampWithNow(9001, nowJan2026);
// Log from July 2024 should be DELETED in January 2026
assertEquals(logFromJuly2024 < cutoffJan2026, true, "Log from July 2024 should be deleted in Jan 2026");
assertEquals(
logFromJuly2024 < cutoffJan2026,
true,
"Log from July 2024 should be deleted in Jan 2026"
);
}
// Test 15: Verify the exact requirement - data from 2024 must be purged on December 31, 2025
@@ -176,16 +234,27 @@ function testCalculateCutoffTimestamp() {
// On Jan 1, 2026 (now 2026), data from 2024 can be deleted
{
const logFromMid2024 = dateToTimestamp("2024-06-15T12:00:00Z");
// Dec 31, 2025 23:59:59 - still 2025, log should be kept
const nowDec31_2025 = dateToTimestamp("2025-12-31T23:59:59Z");
const cutoffDec31 = calculateCutoffTimestampWithNow(9001, nowDec31_2025);
assertEquals(logFromMid2024 >= cutoffDec31, true, "Log from mid-2024 should be kept on Dec 31, 2025");
const cutoffDec31 = calculateCutoffTimestampWithNow(
9001,
nowDec31_2025
);
assertEquals(
logFromMid2024 >= cutoffDec31,
true,
"Log from mid-2024 should be kept on Dec 31, 2025"
);
// Jan 1, 2026 00:00:00 - now 2026, log can be deleted
const nowJan1_2026 = dateToTimestamp("2026-01-01T00:00:00Z");
const cutoffJan1 = calculateCutoffTimestampWithNow(9001, nowJan1_2026);
assertEquals(logFromMid2024 < cutoffJan1, true, "Log from mid-2024 should be deleted on Jan 1, 2026");
assertEquals(
logFromMid2024 < cutoffJan1,
true,
"Log from mid-2024 should be deleted on Jan 1, 2026"
);
}
console.log("All calculateCutoffTimestamp tests passed!");

View File

@@ -4,18 +4,20 @@ import { eq, and } from "drizzle-orm";
import { subdomainSchema } from "@server/lib/schemas";
import { fromError } from "zod-validation-error";
export type DomainValidationResult = {
success: true;
fullDomain: string;
subdomain: string | null;
} | {
success: false;
error: string;
};
export type DomainValidationResult =
| {
success: true;
fullDomain: string;
subdomain: string | null;
}
| {
success: false;
error: string;
};
/**
* Validates a domain and constructs the full domain based on domain type and subdomain.
*
*
* @param domainId - The ID of the domain to validate
* @param orgId - The organization ID to check domain access
* @param subdomain - Optional subdomain to append (for ns and wildcard domains)
@@ -34,7 +36,10 @@ export async function validateAndConstructDomain(
.where(eq(domains.domainId, domainId))
.leftJoin(
orgDomains,
and(eq(orgDomains.orgId, orgId), eq(orgDomains.domainId, domainId))
and(
eq(orgDomains.orgId, orgId),
eq(orgDomains.domainId, domainId)
)
);
// Check if domain exists
@@ -106,7 +111,7 @@ export async function validateAndConstructDomain(
} catch (error) {
return {
success: false,
error: `An error occurred while validating domain: ${error instanceof Error ? error.message : 'Unknown error'}`
error: `An error occurred while validating domain: ${error instanceof Error ? error.message : "Unknown error"}`
};
}
}

View File

@@ -1,39 +1,39 @@
import crypto from 'crypto';
import crypto from "crypto";
export function encryptData(data: string, key: Buffer): string {
const algorithm = 'aes-256-gcm';
const iv = crypto.randomBytes(16);
const cipher = crypto.createCipheriv(algorithm, key, iv);
let encrypted = cipher.update(data, 'utf8', 'hex');
encrypted += cipher.final('hex');
const authTag = cipher.getAuthTag();
// Combine IV, auth tag, and encrypted data
return iv.toString('hex') + ':' + authTag.toString('hex') + ':' + encrypted;
const algorithm = "aes-256-gcm";
const iv = crypto.randomBytes(16);
const cipher = crypto.createCipheriv(algorithm, key, iv);
let encrypted = cipher.update(data, "utf8", "hex");
encrypted += cipher.final("hex");
const authTag = cipher.getAuthTag();
// Combine IV, auth tag, and encrypted data
return iv.toString("hex") + ":" + authTag.toString("hex") + ":" + encrypted;
}
// Helper function to decrypt data (you'll need this to read certificates)
export function decryptData(encryptedData: string, key: Buffer): string {
const algorithm = 'aes-256-gcm';
const parts = encryptedData.split(':');
if (parts.length !== 3) {
throw new Error('Invalid encrypted data format');
}
const iv = Buffer.from(parts[0], 'hex');
const authTag = Buffer.from(parts[1], 'hex');
const encrypted = parts[2];
const decipher = crypto.createDecipheriv(algorithm, key, iv);
decipher.setAuthTag(authTag);
let decrypted = decipher.update(encrypted, 'hex', 'utf8');
decrypted += decipher.final('utf8');
return decrypted;
const algorithm = "aes-256-gcm";
const parts = encryptedData.split(":");
if (parts.length !== 3) {
throw new Error("Invalid encrypted data format");
}
const iv = Buffer.from(parts[0], "hex");
const authTag = Buffer.from(parts[1], "hex");
const encrypted = parts[2];
const decipher = crypto.createDecipheriv(algorithm, key, iv);
decipher.setAuthTag(authTag);
let decrypted = decipher.update(encrypted, "hex", "utf8");
decrypted += decipher.final("utf8");
return decrypted;
}
// openssl rand -hex 32 > config/encryption.key
// openssl rand -hex 32 > config/encryption.key

View File

@@ -30,4 +30,4 @@ export async function getCurrentExitNodeId(): Promise<number> {
}
}
return currentExitNodeId;
}
}

View File

@@ -1,4 +1,4 @@
export * from "./exitNodes";
export * from "./exitNodeComms";
export * from "./subnet";
export * from "./getCurrentExitNodeId";
export * from "./getCurrentExitNodeId";

View File

@@ -27,4 +27,4 @@ export async function getNextAvailableSubnet(): Promise<string> {
"/" +
subnet.split("/")[1];
return subnet;
}
}

View File

@@ -30,4 +30,4 @@ export async function getCountryCodeForIp(
}
return;
}
}

View File

@@ -33,7 +33,11 @@ export async function generateOidcRedirectUrl(
)
.limit(1);
if (res?.loginPage && res.loginPage.domainId && res.loginPage.fullDomain) {
if (
res?.loginPage &&
res.loginPage.domainId &&
res.loginPage.fullDomain
) {
baseUrl = `${method}://${res.loginPage.fullDomain}`;
}
}

View File

@@ -4,7 +4,7 @@ import { assertEquals } from "@test/assert";
// Test cases
function testFindNextAvailableCidr() {
console.log("Running findNextAvailableCidr tests...");
// Test 0: Basic IPv4 allocation with a subnet in the wrong range
{
const existing = ["100.90.130.1/30", "100.90.128.4/30"];
@@ -23,7 +23,11 @@ function testFindNextAvailableCidr() {
{
const existing = ["10.0.0.0/16", "10.2.0.0/16"];
const result = findNextAvailableCidr(existing, 16, "10.0.0.0/8");
assertEquals(result, "10.1.0.0/16", "Finding gap between allocations failed");
assertEquals(
result,
"10.1.0.0/16",
"Finding gap between allocations failed"
);
}
// Test 3: No available space
@@ -33,7 +37,7 @@ function testFindNextAvailableCidr() {
assertEquals(result, null, "No available space test failed");
}
// Test 4: Empty existing
// Test 4: Empty existing
{
const existing: string[] = [];
const result = findNextAvailableCidr(existing, 30, "10.0.0.0/8");
@@ -137,4 +141,4 @@ try {
} catch (error) {
console.error("Test failed:", error);
process.exit(1);
}
}

View File

@@ -247,7 +247,10 @@ export async function getNextAvailableClientSubnet(
orgId: string,
transaction: Transaction | typeof db = db
): Promise<string> {
const [org] = await transaction.select().from(orgs).where(eq(orgs.orgId, orgId));
const [org] = await transaction
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
if (!org) {
throw new Error(`Organization with ID ${orgId} not found`);
@@ -360,7 +363,9 @@ export async function getNextAvailableOrgSubnet(): Promise<string> {
return subnet;
}
export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[] {
export function generateRemoteSubnets(
allSiteResources: SiteResource[]
): string[] {
const remoteSubnets = allSiteResources
.filter((sr) => {
if (sr.mode === "cidr") return true;

View File

@@ -14,4 +14,4 @@ export async function logAccessAudit(data: {
requestIp?: string;
}) {
return;
}
}

View File

@@ -14,7 +14,8 @@ export const configSchema = z
.object({
app: z
.object({
dashboard_url: z.url()
dashboard_url: z
.url()
.pipe(z.url())
.transform((url) => url.toLowerCase())
.optional(),
@@ -255,7 +256,10 @@ export const configSchema = z
.object({
block_size: z.number().positive().gt(0).optional().default(24),
subnet_group: z.string().optional().default("100.90.128.0/24"),
utility_subnet_group: z.string().optional().default("100.96.128.0/24") //just hardcode this for now as well
utility_subnet_group: z
.string()
.optional()
.default("100.96.128.0/24") //just hardcode this for now as well
})
.optional()
.default({

View File

@@ -32,7 +32,7 @@ import logger from "@server/logger";
import {
generateAliasConfig,
generateRemoteSubnets,
generateSubnetProxyTargets,
generateSubnetProxyTargets
} from "@server/lib/ip";
import {
addPeerData,
@@ -109,21 +109,22 @@ export async function getClientSiteResourceAccess(
const directClientIds = allClientSiteResources.map((row) => row.clientId);
// Get full client details for directly associated clients
const directClients = directClientIds.length > 0
? await trx
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.where(
and(
inArray(clients.clientId, directClientIds),
eq(clients.orgId, siteResource.orgId) // filter by org to prevent cross-org associations
const directClients =
directClientIds.length > 0
? await trx
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.where(
and(
inArray(clients.clientId, directClientIds),
eq(clients.orgId, siteResource.orgId) // filter by org to prevent cross-org associations
)
)
)
: [];
: [];
// Merge user-based clients with directly associated clients
const allClientsMap = new Map(
@@ -731,9 +732,10 @@ async function handleSubnetProxyTargetUpdates(
);
// Only remove remote subnet if no other resource uses the same destination
const remoteSubnetsToRemove = destinationStillInUse.length > 0
? []
: generateRemoteSubnets([siteResource]);
const remoteSubnetsToRemove =
destinationStillInUse.length > 0
? []
: generateRemoteSubnets([siteResource]);
olmJobs.push(
removePeerData(
@@ -817,7 +819,10 @@ export async function rebuildClientAssociationsFromClient(
.from(roleSiteResources)
.innerJoin(
siteResources,
eq(siteResources.siteResourceId, roleSiteResources.siteResourceId)
eq(
siteResources.siteResourceId,
roleSiteResources.siteResourceId
)
)
.where(
and(
@@ -1277,9 +1282,10 @@ async function handleMessagesForClientResources(
);
// Only remove remote subnet if no other resource uses the same destination
const remoteSubnetsToRemove = destinationStillInUse.length > 0
? []
: generateRemoteSubnets([resource]);
const remoteSubnetsToRemove =
destinationStillInUse.length > 0
? []
: generateRemoteSubnets([resource]);
// Remove peer data from olm
olmJobs.push(

View File

@@ -1,8 +1,8 @@
export enum AudienceIds {
SignUps = "",
Subscribed = "",
Churned = "",
Newsletter = ""
SignUps = "",
Subscribed = "",
Churned = "",
Newsletter = ""
}
let resend;

View File

@@ -3,14 +3,14 @@ import { Response } from "express";
export const response = <T>(
res: Response,
{ data, success, error, message, status }: ResponseT<T>,
{ data, success, error, message, status }: ResponseT<T>
) => {
return res.status(status).send({
data,
success,
error,
message,
status,
status
});
};

View File

@@ -1,5 +1,5 @@
import { S3Client } from "@aws-sdk/client-s3";
export const s3Client = new S3Client({
region: process.env.S3_REGION || "us-east-1",
region: process.env.S3_REGION || "us-east-1"
});

View File

@@ -6,7 +6,7 @@ let serverIp: string | null = null;
const services = [
"https://checkip.amazonaws.com",
"https://ifconfig.io/ip",
"https://api.ipify.org",
"https://api.ipify.org"
];
export async function fetchServerIp() {
@@ -17,7 +17,9 @@ export async function fetchServerIp() {
logger.debug("Detected public IP: " + serverIp);
return;
} catch (err: any) {
console.warn(`Failed to fetch server IP from ${url}: ${err.message || err.code}`);
console.warn(
`Failed to fetch server IP from ${url}: ${err.message || err.code}`
);
}
}

View File

@@ -1,8 +1,7 @@
export default function stoi(val: any) {
if (typeof val === "string") {
return parseInt(val);
return parseInt(val);
} else {
return val;
}
else {
return val;
}
}
}

View File

@@ -195,7 +195,9 @@ export class TraefikConfigManager {
state.set(domain, {
exists: certExists && keyExists,
lastModified: lastModified ? Math.floor(lastModified.getTime() / 1000) : null,
lastModified: lastModified
? Math.floor(lastModified.getTime() / 1000)
: null,
expiresAt,
wildcard
});
@@ -464,7 +466,9 @@ export class TraefikConfigManager {
config.getRawConfig().traefik.site_types,
build == "oss", // filter out the namespace domains in open source
build != "oss", // generate the login pages on the cloud and hybrid,
build == "saas" ? false : config.getRawConfig().traefik.allow_raw_resources // dont allow raw resources on saas otherwise use config
build == "saas"
? false
: config.getRawConfig().traefik.allow_raw_resources // dont allow raw resources on saas otherwise use config
);
const domains = new Set<string>();
@@ -788,7 +792,10 @@ export class TraefikConfigManager {
// Store the certificate expiry time
if (cert.expiresAt) {
const expiresAtPath = path.join(domainDir, ".expires_at");
const expiresAtPath = path.join(
domainDir,
".expires_at"
);
fs.writeFileSync(
expiresAtPath,
cert.expiresAt.toString(),

View File

@@ -1 +1 @@
export * from "./getTraefikConfig";
export * from "./getTraefikConfig";

View File

@@ -2,234 +2,249 @@ import { assertEquals } from "@test/assert";
import { isDomainCoveredByWildcard } from "./TraefikConfigManager";
function runTests() {
console.log('Running wildcard domain coverage tests...');
console.log("Running wildcard domain coverage tests...");
// Test case 1: Basic wildcard certificate at example.com
const basicWildcardCerts = new Map([
['example.com', { exists: true, wildcard: true }]
["example.com", { exists: true, wildcard: true }]
]);
// Should match first-level subdomains
assertEquals(
isDomainCoveredByWildcard('level1.example.com', basicWildcardCerts),
isDomainCoveredByWildcard("level1.example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match level1.example.com'
"Wildcard cert at example.com should match level1.example.com"
);
assertEquals(
isDomainCoveredByWildcard('api.example.com', basicWildcardCerts),
isDomainCoveredByWildcard("api.example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match api.example.com'
"Wildcard cert at example.com should match api.example.com"
);
assertEquals(
isDomainCoveredByWildcard('www.example.com', basicWildcardCerts),
isDomainCoveredByWildcard("www.example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match www.example.com'
"Wildcard cert at example.com should match www.example.com"
);
// Should match the root domain (exact match)
assertEquals(
isDomainCoveredByWildcard('example.com', basicWildcardCerts),
isDomainCoveredByWildcard("example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match example.com itself'
"Wildcard cert at example.com should match example.com itself"
);
// Should NOT match second-level subdomains
assertEquals(
isDomainCoveredByWildcard('level2.level1.example.com', basicWildcardCerts),
isDomainCoveredByWildcard(
"level2.level1.example.com",
basicWildcardCerts
),
false,
'Wildcard cert at example.com should NOT match level2.level1.example.com'
"Wildcard cert at example.com should NOT match level2.level1.example.com"
);
assertEquals(
isDomainCoveredByWildcard('deep.nested.subdomain.example.com', basicWildcardCerts),
isDomainCoveredByWildcard(
"deep.nested.subdomain.example.com",
basicWildcardCerts
),
false,
'Wildcard cert at example.com should NOT match deep.nested.subdomain.example.com'
"Wildcard cert at example.com should NOT match deep.nested.subdomain.example.com"
);
// Should NOT match different domains
assertEquals(
isDomainCoveredByWildcard('test.otherdomain.com', basicWildcardCerts),
isDomainCoveredByWildcard("test.otherdomain.com", basicWildcardCerts),
false,
'Wildcard cert at example.com should NOT match test.otherdomain.com'
"Wildcard cert at example.com should NOT match test.otherdomain.com"
);
assertEquals(
isDomainCoveredByWildcard('notexample.com', basicWildcardCerts),
isDomainCoveredByWildcard("notexample.com", basicWildcardCerts),
false,
'Wildcard cert at example.com should NOT match notexample.com'
"Wildcard cert at example.com should NOT match notexample.com"
);
// Test case 2: Multiple wildcard certificates
const multipleWildcardCerts = new Map([
['example.com', { exists: true, wildcard: true }],
['test.org', { exists: true, wildcard: true }],
['api.service.net', { exists: true, wildcard: true }]
["example.com", { exists: true, wildcard: true }],
["test.org", { exists: true, wildcard: true }],
["api.service.net", { exists: true, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('app.example.com', multipleWildcardCerts),
isDomainCoveredByWildcard("app.example.com", multipleWildcardCerts),
true,
'Should match subdomain of first wildcard cert'
"Should match subdomain of first wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('staging.test.org', multipleWildcardCerts),
isDomainCoveredByWildcard("staging.test.org", multipleWildcardCerts),
true,
'Should match subdomain of second wildcard cert'
"Should match subdomain of second wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('v1.api.service.net', multipleWildcardCerts),
isDomainCoveredByWildcard("v1.api.service.net", multipleWildcardCerts),
true,
'Should match subdomain of third wildcard cert'
"Should match subdomain of third wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('deep.nested.api.service.net', multipleWildcardCerts),
isDomainCoveredByWildcard(
"deep.nested.api.service.net",
multipleWildcardCerts
),
false,
'Should NOT match multi-level subdomain of third wildcard cert'
"Should NOT match multi-level subdomain of third wildcard cert"
);
// Test exact domain matches for multiple certs
assertEquals(
isDomainCoveredByWildcard('example.com', multipleWildcardCerts),
isDomainCoveredByWildcard("example.com", multipleWildcardCerts),
true,
'Should match exact domain of first wildcard cert'
"Should match exact domain of first wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('test.org', multipleWildcardCerts),
isDomainCoveredByWildcard("test.org", multipleWildcardCerts),
true,
'Should match exact domain of second wildcard cert'
"Should match exact domain of second wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('api.service.net', multipleWildcardCerts),
isDomainCoveredByWildcard("api.service.net", multipleWildcardCerts),
true,
'Should match exact domain of third wildcard cert'
"Should match exact domain of third wildcard cert"
);
// Test case 3: Non-wildcard certificates (should not match anything)
const nonWildcardCerts = new Map([
['example.com', { exists: true, wildcard: false }],
['specific.domain.com', { exists: true, wildcard: false }]
["example.com", { exists: true, wildcard: false }],
["specific.domain.com", { exists: true, wildcard: false }]
]);
assertEquals(
isDomainCoveredByWildcard('sub.example.com', nonWildcardCerts),
isDomainCoveredByWildcard("sub.example.com", nonWildcardCerts),
false,
'Non-wildcard cert should not match subdomains'
"Non-wildcard cert should not match subdomains"
);
assertEquals(
isDomainCoveredByWildcard('example.com', nonWildcardCerts),
isDomainCoveredByWildcard("example.com", nonWildcardCerts),
false,
'Non-wildcard cert should not match even exact domain via this function'
"Non-wildcard cert should not match even exact domain via this function"
);
// Test case 4: Non-existent certificates (should not match)
const nonExistentCerts = new Map([
['example.com', { exists: false, wildcard: true }],
['missing.com', { exists: false, wildcard: true }]
["example.com", { exists: false, wildcard: true }],
["missing.com", { exists: false, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('sub.example.com', nonExistentCerts),
isDomainCoveredByWildcard("sub.example.com", nonExistentCerts),
false,
'Non-existent wildcard cert should not match'
"Non-existent wildcard cert should not match"
);
// Test case 5: Edge cases with special domain names
const specialDomainCerts = new Map([
['localhost', { exists: true, wildcard: true }],
['127-0-0-1.nip.io', { exists: true, wildcard: true }],
['xn--e1afmkfd.xn--p1ai', { exists: true, wildcard: true }] // IDN domain
["localhost", { exists: true, wildcard: true }],
["127-0-0-1.nip.io", { exists: true, wildcard: true }],
["xn--e1afmkfd.xn--p1ai", { exists: true, wildcard: true }] // IDN domain
]);
assertEquals(
isDomainCoveredByWildcard('app.localhost', specialDomainCerts),
isDomainCoveredByWildcard("app.localhost", specialDomainCerts),
true,
'Should match subdomain of localhost wildcard'
"Should match subdomain of localhost wildcard"
);
assertEquals(
isDomainCoveredByWildcard('test.127-0-0-1.nip.io', specialDomainCerts),
isDomainCoveredByWildcard("test.127-0-0-1.nip.io", specialDomainCerts),
true,
'Should match subdomain of nip.io wildcard'
"Should match subdomain of nip.io wildcard"
);
assertEquals(
isDomainCoveredByWildcard('sub.xn--e1afmkfd.xn--p1ai', specialDomainCerts),
isDomainCoveredByWildcard(
"sub.xn--e1afmkfd.xn--p1ai",
specialDomainCerts
),
true,
'Should match subdomain of IDN wildcard'
"Should match subdomain of IDN wildcard"
);
// Test case 6: Empty input and edge cases
const emptyCerts = new Map();
assertEquals(
isDomainCoveredByWildcard('any.domain.com', emptyCerts),
isDomainCoveredByWildcard("any.domain.com", emptyCerts),
false,
'Empty certificate map should not match any domain'
"Empty certificate map should not match any domain"
);
// Test case 7: Domains with single character components
const singleCharCerts = new Map([
['a.com', { exists: true, wildcard: true }],
['x.y.z', { exists: true, wildcard: true }]
["a.com", { exists: true, wildcard: true }],
["x.y.z", { exists: true, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('b.a.com', singleCharCerts),
isDomainCoveredByWildcard("b.a.com", singleCharCerts),
true,
'Should match single character subdomain'
"Should match single character subdomain"
);
assertEquals(
isDomainCoveredByWildcard('w.x.y.z', singleCharCerts),
isDomainCoveredByWildcard("w.x.y.z", singleCharCerts),
true,
'Should match single character subdomain of multi-part domain'
"Should match single character subdomain of multi-part domain"
);
assertEquals(
isDomainCoveredByWildcard('v.w.x.y.z', singleCharCerts),
isDomainCoveredByWildcard("v.w.x.y.z", singleCharCerts),
false,
'Should NOT match multi-level subdomain of single char domain'
"Should NOT match multi-level subdomain of single char domain"
);
// Test case 8: Domains with numbers and hyphens
const numericCerts = new Map([
['api-v2.service-1.com', { exists: true, wildcard: true }],
['123.456.net', { exists: true, wildcard: true }]
["api-v2.service-1.com", { exists: true, wildcard: true }],
["123.456.net", { exists: true, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('staging.api-v2.service-1.com', numericCerts),
isDomainCoveredByWildcard("staging.api-v2.service-1.com", numericCerts),
true,
'Should match subdomain with hyphens and numbers'
"Should match subdomain with hyphens and numbers"
);
assertEquals(
isDomainCoveredByWildcard('test.123.456.net', numericCerts),
isDomainCoveredByWildcard("test.123.456.net", numericCerts),
true,
'Should match subdomain with numeric components'
"Should match subdomain with numeric components"
);
assertEquals(
isDomainCoveredByWildcard('deep.staging.api-v2.service-1.com', numericCerts),
isDomainCoveredByWildcard(
"deep.staging.api-v2.service-1.com",
numericCerts
),
false,
'Should NOT match multi-level subdomain with hyphens and numbers'
"Should NOT match multi-level subdomain with hyphens and numbers"
);
console.log('All wildcard domain coverage tests passed!');
console.log("All wildcard domain coverage tests passed!");
}
// Run all tests
try {
runTests();
} catch (error) {
console.error('Test failed:', error);
console.error("Test failed:", error);
process.exit(1);
}

View File

@@ -31,12 +31,17 @@ export function validatePathRewriteConfig(
}
if (rewritePathType !== "stripPrefix") {
if ((rewritePath && !rewritePathType) || (!rewritePath && rewritePathType)) {
return { isValid: false, error: "Both rewritePath and rewritePathType must be specified together" };
if (
(rewritePath && !rewritePathType) ||
(!rewritePath && rewritePathType)
) {
return {
isValid: false,
error: "Both rewritePath and rewritePathType must be specified together"
};
}
}
if (!rewritePath || !rewritePathType) {
return { isValid: true };
}
@@ -68,14 +73,14 @@ export function validatePathRewriteConfig(
}
}
// Additional validation for stripPrefix
if (rewritePathType === "stripPrefix") {
if (pathMatchType !== "prefix") {
logger.warn(`stripPrefix rewrite type is most effective with prefix path matching. Current match type: ${pathMatchType}`);
logger.warn(
`stripPrefix rewrite type is most effective with prefix path matching. Current match type: ${pathMatchType}`
);
}
}
return { isValid: true };
}

View File

@@ -1,71 +1,247 @@
import { isValidUrlGlobPattern } from "./validators";
import { isValidUrlGlobPattern } from "./validators";
import { assertEquals } from "@test/assert";
function runTests() {
console.log('Running URL pattern validation tests...');
console.log("Running URL pattern validation tests...");
// Test valid patterns
assertEquals(isValidUrlGlobPattern('simple'), true, 'Simple path segment should be valid');
assertEquals(isValidUrlGlobPattern('simple/path'), true, 'Simple path with slash should be valid');
assertEquals(isValidUrlGlobPattern('/leading/slash'), true, 'Path with leading slash should be valid');
assertEquals(isValidUrlGlobPattern('path/'), true, 'Path with trailing slash should be valid');
assertEquals(isValidUrlGlobPattern('path/*'), true, 'Path with wildcard segment should be valid');
assertEquals(isValidUrlGlobPattern('*'), true, 'Single wildcard should be valid');
assertEquals(isValidUrlGlobPattern('*/subpath'), true, 'Wildcard with subpath should be valid');
assertEquals(isValidUrlGlobPattern('path/*/more'), true, 'Path with wildcard in the middle should be valid');
assertEquals(
isValidUrlGlobPattern("simple"),
true,
"Simple path segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("simple/path"),
true,
"Simple path with slash should be valid"
);
assertEquals(
isValidUrlGlobPattern("/leading/slash"),
true,
"Path with leading slash should be valid"
);
assertEquals(
isValidUrlGlobPattern("path/"),
true,
"Path with trailing slash should be valid"
);
assertEquals(
isValidUrlGlobPattern("path/*"),
true,
"Path with wildcard segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("*"),
true,
"Single wildcard should be valid"
);
assertEquals(
isValidUrlGlobPattern("*/subpath"),
true,
"Wildcard with subpath should be valid"
);
assertEquals(
isValidUrlGlobPattern("path/*/more"),
true,
"Path with wildcard in the middle should be valid"
);
// Test with special characters
assertEquals(isValidUrlGlobPattern('path-with-dash'), true, 'Path with dash should be valid');
assertEquals(isValidUrlGlobPattern('path_with_underscore'), true, 'Path with underscore should be valid');
assertEquals(isValidUrlGlobPattern('path.with.dots'), true, 'Path with dots should be valid');
assertEquals(isValidUrlGlobPattern('path~with~tilde'), true, 'Path with tilde should be valid');
assertEquals(isValidUrlGlobPattern('path!with!exclamation'), true, 'Path with exclamation should be valid');
assertEquals(isValidUrlGlobPattern('path$with$dollar'), true, 'Path with dollar should be valid');
assertEquals(isValidUrlGlobPattern('path&with&ampersand'), true, 'Path with ampersand should be valid');
assertEquals(isValidUrlGlobPattern("path'with'quote"), true, "Path with quote should be valid");
assertEquals(isValidUrlGlobPattern('path(with)parentheses'), true, 'Path with parentheses should be valid');
assertEquals(isValidUrlGlobPattern('path+with+plus'), true, 'Path with plus should be valid');
assertEquals(isValidUrlGlobPattern('path,with,comma'), true, 'Path with comma should be valid');
assertEquals(isValidUrlGlobPattern('path;with;semicolon'), true, 'Path with semicolon should be valid');
assertEquals(isValidUrlGlobPattern('path=with=equals'), true, 'Path with equals should be valid');
assertEquals(isValidUrlGlobPattern('path:with:colon'), true, 'Path with colon should be valid');
assertEquals(isValidUrlGlobPattern('path@with@at'), true, 'Path with at should be valid');
assertEquals(
isValidUrlGlobPattern("path-with-dash"),
true,
"Path with dash should be valid"
);
assertEquals(
isValidUrlGlobPattern("path_with_underscore"),
true,
"Path with underscore should be valid"
);
assertEquals(
isValidUrlGlobPattern("path.with.dots"),
true,
"Path with dots should be valid"
);
assertEquals(
isValidUrlGlobPattern("path~with~tilde"),
true,
"Path with tilde should be valid"
);
assertEquals(
isValidUrlGlobPattern("path!with!exclamation"),
true,
"Path with exclamation should be valid"
);
assertEquals(
isValidUrlGlobPattern("path$with$dollar"),
true,
"Path with dollar should be valid"
);
assertEquals(
isValidUrlGlobPattern("path&with&ampersand"),
true,
"Path with ampersand should be valid"
);
assertEquals(
isValidUrlGlobPattern("path'with'quote"),
true,
"Path with quote should be valid"
);
assertEquals(
isValidUrlGlobPattern("path(with)parentheses"),
true,
"Path with parentheses should be valid"
);
assertEquals(
isValidUrlGlobPattern("path+with+plus"),
true,
"Path with plus should be valid"
);
assertEquals(
isValidUrlGlobPattern("path,with,comma"),
true,
"Path with comma should be valid"
);
assertEquals(
isValidUrlGlobPattern("path;with;semicolon"),
true,
"Path with semicolon should be valid"
);
assertEquals(
isValidUrlGlobPattern("path=with=equals"),
true,
"Path with equals should be valid"
);
assertEquals(
isValidUrlGlobPattern("path:with:colon"),
true,
"Path with colon should be valid"
);
assertEquals(
isValidUrlGlobPattern("path@with@at"),
true,
"Path with at should be valid"
);
// Test with percent encoding
assertEquals(isValidUrlGlobPattern('path%20with%20spaces'), true, 'Path with percent-encoded spaces should be valid');
assertEquals(isValidUrlGlobPattern('path%2Fwith%2Fencoded%2Fslashes'), true, 'Path with percent-encoded slashes should be valid');
assertEquals(
isValidUrlGlobPattern("path%20with%20spaces"),
true,
"Path with percent-encoded spaces should be valid"
);
assertEquals(
isValidUrlGlobPattern("path%2Fwith%2Fencoded%2Fslashes"),
true,
"Path with percent-encoded slashes should be valid"
);
// Test with wildcards in segments (the fixed functionality)
assertEquals(isValidUrlGlobPattern('padbootstrap*'), true, 'Path with wildcard at the end of segment should be valid');
assertEquals(isValidUrlGlobPattern('pad*bootstrap'), true, 'Path with wildcard in the middle of segment should be valid');
assertEquals(isValidUrlGlobPattern('*bootstrap'), true, 'Path with wildcard at the start of segment should be valid');
assertEquals(isValidUrlGlobPattern('multiple*wildcards*in*segment'), true, 'Path with multiple wildcards in segment should be valid');
assertEquals(isValidUrlGlobPattern('wild*/cards/in*/different/seg*ments'), true, 'Path with wildcards in different segments should be valid');
assertEquals(
isValidUrlGlobPattern("padbootstrap*"),
true,
"Path with wildcard at the end of segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("pad*bootstrap"),
true,
"Path with wildcard in the middle of segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("*bootstrap"),
true,
"Path with wildcard at the start of segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("multiple*wildcards*in*segment"),
true,
"Path with multiple wildcards in segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("wild*/cards/in*/different/seg*ments"),
true,
"Path with wildcards in different segments should be valid"
);
// Test invalid patterns
assertEquals(isValidUrlGlobPattern(''), false, 'Empty string should be invalid');
assertEquals(isValidUrlGlobPattern('//double/slash'), false, 'Path with double slash should be invalid');
assertEquals(isValidUrlGlobPattern('path//end'), false, 'Path with double slash in the middle should be invalid');
assertEquals(isValidUrlGlobPattern('invalid<char>'), false, 'Path with invalid characters should be invalid');
assertEquals(isValidUrlGlobPattern('invalid|char'), false, 'Path with invalid pipe character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid"char'), false, 'Path with invalid quote character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid`char'), false, 'Path with invalid backtick character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid^char'), false, 'Path with invalid caret character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid\\char'), false, 'Path with invalid backslash character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid[char]'), false, 'Path with invalid square brackets should be invalid');
assertEquals(isValidUrlGlobPattern('invalid{char}'), false, 'Path with invalid curly braces should be invalid');
assertEquals(
isValidUrlGlobPattern(""),
false,
"Empty string should be invalid"
);
assertEquals(
isValidUrlGlobPattern("//double/slash"),
false,
"Path with double slash should be invalid"
);
assertEquals(
isValidUrlGlobPattern("path//end"),
false,
"Path with double slash in the middle should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid<char>"),
false,
"Path with invalid characters should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid|char"),
false,
"Path with invalid pipe character should be invalid"
);
assertEquals(
isValidUrlGlobPattern('invalid"char'),
false,
"Path with invalid quote character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid`char"),
false,
"Path with invalid backtick character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid^char"),
false,
"Path with invalid caret character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid\\char"),
false,
"Path with invalid backslash character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid[char]"),
false,
"Path with invalid square brackets should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid{char}"),
false,
"Path with invalid curly braces should be invalid"
);
// Test invalid percent encoding
assertEquals(isValidUrlGlobPattern('invalid%2'), false, 'Path with incomplete percent encoding should be invalid');
assertEquals(isValidUrlGlobPattern('invalid%GZ'), false, 'Path with invalid hex in percent encoding should be invalid');
assertEquals(isValidUrlGlobPattern('invalid%'), false, 'Path with isolated percent sign should be invalid');
console.log('All tests passed!');
assertEquals(
isValidUrlGlobPattern("invalid%2"),
false,
"Path with incomplete percent encoding should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid%GZ"),
false,
"Path with invalid hex in percent encoding should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid%"),
false,
"Path with isolated percent sign should be invalid"
);
console.log("All tests passed!");
}
// Run all tests
try {
runTests();
} catch (error) {
console.error('Test failed:', error);
}
console.error("Test failed:", error);
}

View File

@@ -2,7 +2,9 @@ import z from "zod";
import ipaddr from "ipaddr.js";
export function isValidCIDR(cidr: string): boolean {
return z.cidrv4().safeParse(cidr).success || z.cidrv6().safeParse(cidr).success;
return (
z.cidrv4().safeParse(cidr).success || z.cidrv6().safeParse(cidr).success
);
}
export function isValidIP(ip: string): boolean {
@@ -69,11 +71,11 @@ export function isUrlValid(url: string | undefined) {
if (!url) return true; // the link is optional in the schema so if it's empty it's valid
var pattern = new RegExp(
"^(https?:\\/\\/)?" + // protocol
"((([a-z\\d]([a-z\\d-]*[a-z\\d])*)\\.)+[a-z]{2,}|" + // domain name
"((\\d{1,3}\\.){3}\\d{1,3}))" + // OR ip (v4) address
"(\\:\\d+)?(\\/[-a-z\\d%_.~+]*)*" + // port and path
"(\\?[;&a-z\\d%_.~+=-]*)?" + // query string
"(\\#[-a-z\\d_]*)?$",
"((([a-z\\d]([a-z\\d-]*[a-z\\d])*)\\.)+[a-z]{2,}|" + // domain name
"((\\d{1,3}\\.){3}\\d{1,3}))" + // OR ip (v4) address
"(\\:\\d+)?(\\/[-a-z\\d%_.~+]*)*" + // port and path
"(\\?[;&a-z\\d%_.~+=-]*)?" + // query string
"(\\#[-a-z\\d_]*)?$",
"i"
);
return !!pattern.test(url);
@@ -168,14 +170,14 @@ export function validateHeaders(headers: string): boolean {
}
export function isSecondLevelDomain(domain: string): boolean {
if (!domain || typeof domain !== 'string') {
if (!domain || typeof domain !== "string") {
return false;
}
const trimmedDomain = domain.trim().toLowerCase();
// Split into parts
const parts = trimmedDomain.split('.');
const parts = trimmedDomain.split(".");
// Should have exactly 2 parts for a second-level domain (e.g., "example.com")
if (parts.length !== 2) {

View File

@@ -20,6 +20,6 @@ export const errorHandlerMiddleware: ErrorRequestHandler = (
error: true,
message: error.message || "Internal Server Error",
status: statusCode,
stack: process.env.ENVIRONMENT === "prod" ? null : error.stack,
stack: process.env.ENVIRONMENT === "prod" ? null : error.stack
});
};

View File

@@ -8,13 +8,13 @@ import HttpCode from "@server/types/HttpCode";
export async function getUserOrgs(
req: Request,
res: Response,
next: NextFunction,
next: NextFunction
) {
const userId = req.user?.userId; // Assuming you have user information in the request
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated"),
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
@@ -22,7 +22,7 @@ export async function getUserOrgs(
const userOrganizations = await db
.select({
orgId: userOrgs.orgId,
roleId: userOrgs.roleId,
roleId: userOrgs.roleId
})
.from(userOrgs)
.where(eq(userOrgs.userId, userId));
@@ -38,8 +38,8 @@ export async function getUserOrgs(
next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error retrieving user organizations",
),
"Error retrieving user organizations"
)
);
}
}

View File

@@ -12,4 +12,4 @@ export * from "./verifyAccessTokenAccess";
export * from "./verifyApiKeyIsRoot";
export * from "./verifyApiKeyApiKeyAccess";
export * from "./verifyApiKeyClientAccess";
export * from "./verifyApiKeySiteResourceAccess";
export * from "./verifyApiKeySiteResourceAccess";

View File

@@ -97,7 +97,6 @@ export async function verifyApiKeyAccessTokenAccess(
);
}
return next();
} catch (e) {
return next(

View File

@@ -11,7 +11,7 @@ export async function verifyApiKeyApiKeyAccess(
next: NextFunction
) {
try {
const {apiKey: callerApiKey } = req;
const { apiKey: callerApiKey } = req;
const apiKeyId =
req.params.apiKeyId || req.body.apiKeyId || req.query.apiKeyId;
@@ -44,7 +44,10 @@ export async function verifyApiKeyApiKeyAccess(
.select()
.from(apiKeyOrg)
.where(
and(eq(apiKeys.apiKeyId, callerApiKey.apiKeyId), eq(apiKeyOrg.orgId, orgId))
and(
eq(apiKeys.apiKeyId, callerApiKey.apiKeyId),
eq(apiKeyOrg.orgId, orgId)
)
)
.limit(1);

View File

@@ -11,9 +11,12 @@ export async function verifyApiKeySetResourceClients(
next: NextFunction
) {
const apiKey = req.apiKey;
const singleClientId = req.params.clientId || req.body.clientId || req.query.clientId;
const singleClientId =
req.params.clientId || req.body.clientId || req.query.clientId;
const { clientIds } = req.body;
const allClientIds = clientIds || (singleClientId ? [parseInt(singleClientId as string)] : []);
const allClientIds =
clientIds ||
(singleClientId ? [parseInt(singleClientId as string)] : []);
if (!apiKey) {
return next(
@@ -70,4 +73,3 @@ export async function verifyApiKeySetResourceClients(
);
}
}

View File

@@ -11,7 +11,8 @@ export async function verifyApiKeySetResourceUsers(
next: NextFunction
) {
const apiKey = req.apiKey;
const singleUserId = req.params.userId || req.body.userId || req.query.userId;
const singleUserId =
req.params.userId || req.body.userId || req.query.userId;
const { userIds } = req.body;
const allUserIds = userIds || (singleUserId ? [singleUserId] : []);

View File

@@ -38,17 +38,12 @@ export async function verifyApiKeySiteResourceAccess(
const [siteResource] = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId)
))
.where(and(eq(siteResources.siteResourceId, siteResourceId)))
.limit(1);
if (!siteResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site resource not found"
)
createHttpError(HttpCode.NOT_FOUND, "Site resource not found")
);
}

View File

@@ -5,7 +5,7 @@ import HttpCode from "@server/types/HttpCode";
export function notFoundMiddleware(
req: Request,
res: Response,
next: NextFunction,
next: NextFunction
) {
if (req.path.startsWith("/api")) {
const message = `The requests url is not found - ${req.originalUrl}`;

View File

@@ -1,30 +1,32 @@
import { Request, Response, NextFunction } from 'express';
import logger from '@server/logger';
import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode';
import { Request, Response, NextFunction } from "express";
import logger from "@server/logger";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export function requestTimeoutMiddleware(timeoutMs: number = 30000) {
return (req: Request, res: Response, next: NextFunction) => {
// Set a timeout for the request
const timeout = setTimeout(() => {
if (!res.headersSent) {
logger.error(`Request timeout: ${req.method} ${req.url} from ${req.ip}`);
logger.error(
`Request timeout: ${req.method} ${req.url} from ${req.ip}`
);
return next(
createHttpError(
HttpCode.REQUEST_TIMEOUT,
'Request timeout - operation took too long to complete'
"Request timeout - operation took too long to complete"
)
);
}
}, timeoutMs);
// Clear timeout when response finishes
res.on('finish', () => {
res.on("finish", () => {
clearTimeout(timeout);
});
// Clear timeout when response closes
res.on('close', () => {
res.on("close", () => {
clearTimeout(timeout);
});

View File

@@ -76,7 +76,10 @@ export async function verifySiteAccess(
.select()
.from(userOrgs)
.where(
and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, site.orgId))
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, site.orgId)
)
)
.limit(1);
req.userOrg = userOrgRole[0];

View File

@@ -9,7 +9,10 @@ const nextPort = config.getRawConfig().server.next_port;
export async function createNextServer() {
// const app = next({ dev });
const app = next({ dev: process.env.ENVIRONMENT !== "prod", turbopack: true });
const app = next({
dev: process.env.ENVIRONMENT !== "prod",
turbopack: true
});
const handle = app.getRequestHandler();
await app.prepare();

View File

@@ -11,11 +11,14 @@
* This file is not licensed under the AGPLv3.
*/
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { RemoteExitNode, remoteExitNodes, remoteExitNodeSessions, RemoteExitNodeSession } from "@server/db";
import {
RemoteExitNode,
remoteExitNodes,
remoteExitNodeSessions,
RemoteExitNodeSession
} from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
@@ -23,30 +26,39 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createRemoteExitNodeSession(
token: string,
remoteExitNodeId: string,
remoteExitNodeId: string
): Promise<RemoteExitNodeSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: RemoteExitNodeSession = {
sessionId: sessionId,
remoteExitNodeId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(remoteExitNodeSessions).values(session);
return session;
}
export async function validateRemoteExitNodeSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ remoteExitNode: remoteExitNodes, session: remoteExitNodeSessions })
.select({
remoteExitNode: remoteExitNodes,
session: remoteExitNodeSessions
})
.from(remoteExitNodeSessions)
.innerJoin(remoteExitNodes, eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodes.remoteExitNodeId))
.innerJoin(
remoteExitNodes,
eq(
remoteExitNodeSessions.remoteExitNodeId,
remoteExitNodes.remoteExitNodeId
)
)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, remoteExitNode: null };
@@ -58,26 +70,32 @@ export async function validateRemoteExitNodeSessionToken(
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
return { session: null, remoteExitNode: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(remoteExitNodeSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
}
return { session, remoteExitNode };
}
export async function invalidateRemoteExitNodeSession(sessionId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.sessionId, sessionId));
export async function invalidateRemoteExitNodeSession(
sessionId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
}
export async function invalidateAllRemoteExitNodeSessions(remoteExitNodeId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
export async function invalidateAllRemoteExitNodeSessions(
remoteExitNodeId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
}
export type SessionValidationResult =

View File

@@ -25,4 +25,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

View File

@@ -12,4 +12,4 @@
*/
export * from "./getOrgTierData";
export * from "./createCustomer";
export * from "./createCustomer";

View File

@@ -55,7 +55,6 @@ export async function getValidCertificatesForDomains(
domains: Set<string>,
useCache: boolean = true
): Promise<Array<CertificateResult>> {
loadEncryptData(); // Ensure encryption key is loaded
const finalResults: CertificateResult[] = [];

View File

@@ -12,14 +12,7 @@
*/
import { build } from "@server/build";
import {
db,
Org,
orgs,
ResourceSession,
sessions,
users
} from "@server/db";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license";

View File

@@ -66,7 +66,9 @@ export async function sendToExitNode(
// logger.debug(`Configured local exit node name: ${config.getRawConfig().gerbil.exit_node_name}`);
if (exitNode.name == config.getRawConfig().gerbil.exit_node_name) {
hostname = privateConfig.getRawPrivateConfig().gerbil.local_exit_node_reachable_at;
hostname =
privateConfig.getRawPrivateConfig().gerbil
.local_exit_node_reachable_at;
}
if (!hostname) {

View File

@@ -44,43 +44,53 @@ async function checkExitNodeOnlineStatus(
const delayBetweenAttempts = 100; // 100ms delay between starting each attempt
// Create promises for all attempts with staggered delays
const attemptPromises = Array.from({ length: maxAttempts }, async (_, index) => {
const attemptNumber = index + 1;
// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) => setTimeout(resolve, delayBetweenAttempts * index));
}
const attemptPromises = Array.from(
{ length: maxAttempts },
async (_, index) => {
const attemptNumber = index + 1;
try {
const response = await axios.get(`http://${endpoint}/ping`, {
timeout: timeoutMs,
validateStatus: (status) => status === 200
});
if (response.status === 200) {
logger.debug(
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) =>
setTimeout(resolve, delayBetweenAttempts * index)
);
return { success: true, attemptNumber };
}
return { success: false, attemptNumber, error: 'Non-200 status' };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };
try {
const response = await axios.get(`http://${endpoint}/ping`, {
timeout: timeoutMs,
validateStatus: (status) => status === 200
});
if (response.status === 200) {
logger.debug(
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
);
return { success: true, attemptNumber };
}
return {
success: false,
attemptNumber,
error: "Non-200 status"
};
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };
}
}
});
);
try {
// Wait for the first successful response or all to fail
const results = await Promise.allSettled(attemptPromises);
// Check if any attempt succeeded
for (const result of results) {
if (result.status === 'fulfilled' && result.value.success) {
if (result.status === "fulfilled" && result.value.success) {
return true;
}
}
@@ -137,7 +147,11 @@ export async function verifyExitNodeOrgAccess(
return { hasAccess: false, exitNode };
}
export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
export async function listExitNodes(
orgId: string,
filterOnline = false,
noCloud = false
) {
const allExitNodes = await db
.select({
exitNodeId: exitNodes.exitNodeId,
@@ -166,7 +180,10 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
eq(exitNodes.type, "gerbil"),
or(
// only choose nodes that are in the same region
eq(exitNodes.region, config.getRawPrivateConfig().app.region),
eq(
exitNodes.region,
config.getRawPrivateConfig().app.region
),
isNull(exitNodes.region) // or for enterprise where region is not set
)
),
@@ -191,7 +208,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
// let online: boolean;
// if (filterOnline && node.type == "remoteExitNode") {
// try {
// const isActuallyOnline = await checkExitNodeOnlineStatus(
// const isActuallyOnline = await checkExitNodeOnlineStatus(
// node.endpoint
// );
@@ -225,7 +242,8 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
node.type === "remoteExitNode" && (!filterOnline || node.online)
);
const gerbilExitNodes = allExitNodes.filter(
(node) => node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
(node) =>
node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
);
// THIS PROVIDES THE FALL
@@ -334,7 +352,11 @@ export function selectBestExitNode(
return fallbackNode;
}
export async function checkExitNodeOrg(exitNodeId: number, orgId: string, trx: Transaction | typeof db = db) {
export async function checkExitNodeOrg(
exitNodeId: number,
orgId: string,
trx: Transaction | typeof db = db
) {
const [exitNodeOrg] = await trx
.select()
.from(exitNodeOrgs)

View File

@@ -12,4 +12,4 @@
*/
export * from "./exitNodeComms";
export * from "./exitNodes";
export * from "./exitNodes";

View File

@@ -177,7 +177,9 @@ export class LockManager {
const exists = value !== null;
const ownedByMe =
exists &&
value!.startsWith(`${config.getRawConfig().gerbil.exit_node_name}:`);
value!.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
);
const owner = exists ? value!.split(":")[0] : undefined;
return {

View File

@@ -14,15 +14,15 @@
// Simple test file for the rate limit service with Redis
// Run with: npx ts-node rateLimitService.test.ts
import { RateLimitService } from './rateLimit';
import { RateLimitService } from "./rateLimit";
function generateClientId() {
return 'client-' + Math.random().toString(36).substring(2, 15);
return "client-" + Math.random().toString(36).substring(2, 15);
}
async function runTests() {
console.log('Starting Rate Limit Service Tests...\n');
console.log("Starting Rate Limit Service Tests...\n");
const rateLimitService = new RateLimitService();
let testsPassed = 0;
let testsTotal = 0;
@@ -47,36 +47,54 @@ async function runTests() {
}
// Test 1: Basic rate limiting
await test('Should allow requests under the limit', async () => {
await test("Should allow requests under the limit", async () => {
const clientId = generateClientId();
const maxRequests = 5;
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
assert(result.totalHits === i + 1, `Expected ${i + 1} hits, got ${result.totalHits}`);
assert(
result.totalHits === i + 1,
`Expected ${i + 1} hits, got ${result.totalHits}`
);
}
});
// Test 2: Rate limit blocking
await test('Should block requests over the limit', async () => {
await test("Should block requests over the limit", async () => {
const clientId = generateClientId();
const maxRequests = 30;
// Use up all allowed requests
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
}
// Next request should be blocked
const blockedResult = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(blockedResult.isLimited, 'Request should be blocked');
assert(blockedResult.reason === 'global', 'Should be blocked for global reason');
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(blockedResult.isLimited, "Request should be blocked");
assert(
blockedResult.reason === "global",
"Should be blocked for global reason"
);
});
// Test 3: Message type limits
await test('Should handle message type limits', async () => {
await test("Should handle message type limits", async () => {
const clientId = generateClientId();
const globalMax = 10;
const messageTypeMax = 2;
@@ -84,54 +102,64 @@ async function runTests() {
// Send messages of type 'ping' up to the limit
for (let i = 0; i < messageTypeMax - 1; i++) {
const result = await rateLimitService.checkRateLimit(
clientId,
'ping',
globalMax,
clientId,
"ping",
globalMax,
messageTypeMax
);
assert(!result.isLimited, `Ping message ${i + 1} should be allowed`);
assert(
!result.isLimited,
`Ping message ${i + 1} should be allowed`
);
}
// Next 'ping' should be blocked
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
'ping',
globalMax,
clientId,
"ping",
globalMax,
messageTypeMax
);
assert(blockedResult.isLimited, 'Ping message should be blocked');
assert(blockedResult.reason === 'message_type:ping', 'Should be blocked for message type');
assert(blockedResult.isLimited, "Ping message should be blocked");
assert(
blockedResult.reason === "message_type:ping",
"Should be blocked for message type"
);
// Other message types should still work
const otherResult = await rateLimitService.checkRateLimit(
clientId,
'pong',
globalMax,
clientId,
"pong",
globalMax,
messageTypeMax
);
assert(!otherResult.isLimited, 'Pong message should be allowed');
assert(!otherResult.isLimited, "Pong message should be allowed");
});
// Test 4: Reset functionality
await test('Should reset client correctly', async () => {
await test("Should reset client correctly", async () => {
const clientId = generateClientId();
const maxRequests = 3;
// Use up some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, 'test', maxRequests);
await rateLimitService.checkRateLimit(clientId, "test", maxRequests);
// Reset the client
await rateLimitService.resetKey(clientId);
// Should be able to make fresh requests
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(!result.isLimited, 'Request after reset should be allowed');
assert(result.totalHits === 1, 'Should have 1 hit after reset');
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, "Request after reset should be allowed");
assert(result.totalHits === 1, "Should have 1 hit after reset");
});
// Test 5: Different clients are independent
await test('Should handle different clients independently', async () => {
await test("Should handle different clients independently", async () => {
const client1 = generateClientId();
const client2 = generateClientId();
const maxRequests = 2;
@@ -139,43 +167,62 @@ async function runTests() {
// Client 1 uses up their limit
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
const client1Blocked = await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
assert(client1Blocked.isLimited, 'Client 1 should be blocked');
const client1Blocked = await rateLimitService.checkRateLimit(
client1,
undefined,
maxRequests
);
assert(client1Blocked.isLimited, "Client 1 should be blocked");
// Client 2 should still be able to make requests
const client2Result = await rateLimitService.checkRateLimit(client2, undefined, maxRequests);
assert(!client2Result.isLimited, 'Client 2 should not be blocked');
assert(client2Result.totalHits === 1, 'Client 2 should have 1 hit');
const client2Result = await rateLimitService.checkRateLimit(
client2,
undefined,
maxRequests
);
assert(!client2Result.isLimited, "Client 2 should not be blocked");
assert(client2Result.totalHits === 1, "Client 2 should have 1 hit");
});
// Test 6: Decrement functionality
await test('Should decrement correctly', async () => {
await test("Should decrement correctly", async () => {
const clientId = generateClientId();
const maxRequests = 5;
// Make some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
let result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits before decrement');
let result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(result.totalHits === 3, "Should have 3 hits before decrement");
// Decrement
await rateLimitService.decrementRateLimit(clientId);
// Next request should reflect the decrement
result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits after decrement + increment');
result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(
result.totalHits === 3,
"Should have 3 hits after decrement + increment"
);
});
// Wait a moment for any pending Redis operations
console.log('\nWaiting for Redis sync...');
await new Promise(resolve => setTimeout(resolve, 1000));
console.log("\nWaiting for Redis sync...");
await new Promise((resolve) => setTimeout(resolve, 1000));
// Force sync to test Redis integration
await test('Should sync to Redis', async () => {
await test("Should sync to Redis", async () => {
await rateLimitService.forceSyncAllPendingData();
// If this doesn't throw, Redis sync is working
assert(true, 'Redis sync completed');
assert(true, "Redis sync completed");
});
// Cleanup
@@ -185,18 +232,18 @@ async function runTests() {
console.log(`\n--- Test Results ---`);
console.log(`✅ Passed: ${testsPassed}/${testsTotal}`);
console.log(`❌ Failed: ${testsTotal - testsPassed}/${testsTotal}`);
if (testsPassed === testsTotal) {
console.log('\n🎉 All tests passed!');
console.log("\n🎉 All tests passed!");
process.exit(0);
} else {
console.log('\n💥 Some tests failed!');
console.log("\n💥 Some tests failed!");
process.exit(1);
}
}
// Run the tests
runTests().catch(error => {
console.error('Test runner error:', error);
runTests().catch((error) => {
console.error("Test runner error:", error);
process.exit(1);
});
});

View File

@@ -40,7 +40,8 @@ interface RateLimitResult {
export class RateLimitService {
private localRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> =
new Map();
private cleanupInterval: NodeJS.Timeout | null = null;
private forceSyncInterval: NodeJS.Timeout | null = null;
@@ -68,12 +69,18 @@ export class RateLimitService {
return `ratelimit:${clientId}`;
}
private getMessageTypeRateLimitKey(clientId: string, messageType: string): string {
private getMessageTypeRateLimitKey(
clientId: string,
messageType: string
): string {
return `ratelimit:${clientId}:${messageType}`;
}
// Helper function to clean up old timestamp fields from a Redis hash
private async cleanupOldTimestamps(key: string, windowStart: number): Promise<void> {
private async cleanupOldTimestamps(
key: string,
windowStart: number
): Promise<void> {
if (!redisManager.isRedisEnabled()) return;
try {
@@ -101,10 +108,15 @@ export class RateLimitService {
const batch = fieldsToDelete.slice(i, i + batchSize);
await client.hdel(key, ...batch);
}
logger.debug(`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`);
logger.debug(
`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`
);
}
} catch (error) {
logger.error(`Failed to cleanup old timestamps for key ${key}:`, error);
logger.error(
`Failed to cleanup old timestamps for key ${key}:`,
error
);
// Don't throw - cleanup failures shouldn't block rate limiting
}
}
@@ -114,7 +126,8 @@ export class RateLimitService {
clientId: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;
try {
const currentTime = Math.floor(Date.now() / 1000);
@@ -132,7 +145,11 @@ export class RateLimitService {
const newValue = (
parseInt(currentValue || "0") + tracker.pendingCount
).toString();
await redisManager.hset(globalKey, currentTime.toString(), newValue);
await redisManager.hset(
globalKey,
currentTime.toString(),
newValue
);
// Set TTL using the client directly - this prevents the key from persisting forever
if (redisManager.getClient()) {
@@ -145,7 +162,9 @@ export class RateLimitService {
tracker.lastSyncedCount = tracker.count;
tracker.pendingCount = 0;
logger.debug(`Synced global rate limit to Redis for client ${clientId}`);
logger.debug(
`Synced global rate limit to Redis for client ${clientId}`
);
} catch (error) {
logger.error("Failed to sync global rate limit to Redis:", error);
}
@@ -156,12 +175,16 @@ export class RateLimitService {
messageType: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;
try {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);
// Clean up old timestamp fields before writing
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
@@ -195,12 +218,17 @@ export class RateLimitService {
`Synced message type rate limit to Redis for client ${clientId}, type ${messageType}`
);
} catch (error) {
logger.error("Failed to sync message type rate limit to Redis:", error);
logger.error(
"Failed to sync message type rate limit to Redis:",
error
);
}
}
// Initialize local tracker from Redis data
private async initializeLocalTracker(clientId: string): Promise<RateLimitTracker> {
private async initializeLocalTracker(
clientId: string
): Promise<RateLimitTracker> {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
@@ -215,14 +243,16 @@ export class RateLimitService {
try {
const globalKey = this.getRateLimitKey(clientId);
// Clean up old timestamp fields before reading
await this.cleanupOldTimestamps(globalKey, windowStart);
const globalRateLimitData = await redisManager.hgetall(globalKey);
let count = 0;
for (const [timestamp, countStr] of Object.entries(globalRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
globalRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
@@ -236,7 +266,10 @@ export class RateLimitService {
lastSyncedCount: count
};
} catch (error) {
logger.error("Failed to initialize global tracker from Redis:", error);
logger.error(
"Failed to initialize global tracker from Redis:",
error
);
return {
count: 0,
windowStart: currentTime,
@@ -263,15 +296,21 @@ export class RateLimitService {
}
try {
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);
// Clean up old timestamp fields before reading
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
const messageTypeRateLimitData = await redisManager.hgetall(messageTypeKey);
const messageTypeRateLimitData =
await redisManager.hgetall(messageTypeKey);
let count = 0;
for (const [timestamp, countStr] of Object.entries(messageTypeRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
messageTypeRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
@@ -285,7 +324,10 @@ export class RateLimitService {
lastSyncedCount: count
};
} catch (error) {
logger.error("Failed to initialize message type tracker from Redis:", error);
logger.error(
"Failed to initialize message type tracker from Redis:",
error
);
return {
count: 0,
windowStart: currentTime,
@@ -327,7 +369,10 @@ export class RateLimitService {
isLimited: true,
reason: "global",
totalHits: globalTracker.count,
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(globalTracker.windowStart + Math.floor(windowMs / 1000)) *
1000
)
};
}
@@ -339,19 +384,32 @@ export class RateLimitService {
// Check message type specific rate limit if messageType is provided
if (messageType) {
const messageTypeKey = `${clientId}:${messageType}`;
let messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
let messageTypeTracker =
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
if (!messageTypeTracker || messageTypeTracker.windowStart < windowStart) {
if (
!messageTypeTracker ||
messageTypeTracker.windowStart < windowStart
) {
// New window or first request for this message type - initialize from Redis if available
messageTypeTracker = await this.initializeMessageTypeTracker(clientId, messageType);
messageTypeTracker = await this.initializeMessageTypeTracker(
clientId,
messageType
);
messageTypeTracker.windowStart = currentTime;
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
this.localMessageTypeRateLimitTracker.set(
messageTypeKey,
messageTypeTracker
);
}
// Increment message type counters
messageTypeTracker.count++;
messageTypeTracker.pendingCount++;
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
this.localMessageTypeRateLimitTracker.set(
messageTypeKey,
messageTypeTracker
);
// Check if message type limit would be exceeded
if (messageTypeTracker.count >= messageTypeLimit) {
@@ -359,25 +417,38 @@ export class RateLimitService {
isLimited: true,
reason: `message_type:${messageType}`,
totalHits: messageTypeTracker.count,
resetTime: new Date((messageTypeTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(messageTypeTracker.windowStart +
Math.floor(windowMs / 1000)) *
1000
)
};
}
// Sync to Redis if threshold reached
if (messageTypeTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
this.syncMessageTypeRateLimitToRedis(clientId, messageType, messageTypeTracker);
this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
messageTypeTracker
);
}
}
return {
isLimited: false,
totalHits: globalTracker.count,
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000
)
};
}
// Decrement function for skipSuccessfulRequests/skipFailedRequests functionality
async decrementRateLimit(clientId: string, messageType?: string): Promise<void> {
async decrementRateLimit(
clientId: string,
messageType?: string
): Promise<void> {
// Decrement global counter
const globalTracker = this.localRateLimitTracker.get(clientId);
if (globalTracker && globalTracker.count > 0) {
@@ -389,7 +460,8 @@ export class RateLimitService {
// Decrement message type counter if provided
if (messageType) {
const messageTypeKey = `${clientId}:${messageType}`;
const messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
const messageTypeTracker =
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
if (messageTypeTracker && messageTypeTracker.count > 0) {
messageTypeTracker.count--;
messageTypeTracker.pendingCount--;
@@ -401,7 +473,7 @@ export class RateLimitService {
async resetKey(clientId: string): Promise<void> {
// Remove from local tracking
this.localRateLimitTracker.delete(clientId);
// Remove all message type entries for this client
for (const [key] of this.localMessageTypeRateLimitTracker) {
if (key.startsWith(`${clientId}:`)) {
@@ -417,9 +489,13 @@ export class RateLimitService {
// Get all message type keys for this client and delete them
const client = redisManager.getClient();
if (client) {
const messageTypeKeys = await client.keys(`ratelimit:${clientId}:*`);
const messageTypeKeys = await client.keys(
`ratelimit:${clientId}:*`
);
if (messageTypeKeys.length > 0) {
await Promise.all(messageTypeKeys.map(key => redisManager.del(key)));
await Promise.all(
messageTypeKeys.map((key) => redisManager.del(key))
);
}
}
}
@@ -431,7 +507,10 @@ export class RateLimitService {
const windowStart = currentTime - RATE_LIMIT_WINDOW;
// Clean up global rate limit tracking and sync pending data
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
for (const [
clientId,
tracker
] of this.localRateLimitTracker.entries()) {
if (tracker.windowStart < windowStart) {
// Sync any pending data before cleanup
if (tracker.pendingCount > 0) {
@@ -442,12 +521,19 @@ export class RateLimitService {
}
// Clean up message type rate limit tracking and sync pending data
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
for (const [
key,
tracker
] of this.localMessageTypeRateLimitTracker.entries()) {
if (tracker.windowStart < windowStart) {
// Sync any pending data before cleanup
if (tracker.pendingCount > 0) {
const [clientId, messageType] = key.split(":", 2);
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
await this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
tracker
);
}
this.localMessageTypeRateLimitTracker.delete(key);
}
@@ -461,17 +547,27 @@ export class RateLimitService {
logger.debug("Force syncing all pending rate limit data to Redis...");
// Sync all pending global rate limits
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
for (const [
clientId,
tracker
] of this.localRateLimitTracker.entries()) {
if (tracker.pendingCount > 0) {
await this.syncRateLimitToRedis(clientId, tracker);
}
}
// Sync all pending message type rate limits
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
for (const [
key,
tracker
] of this.localMessageTypeRateLimitTracker.entries()) {
if (tracker.pendingCount > 0) {
const [clientId, messageType] = key.split(":", 2);
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
await this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
tracker
);
}
}
@@ -504,4 +600,4 @@ export class RateLimitService {
}
// Export singleton instance
export const rateLimitService = new RateLimitService();
export const rateLimitService = new RateLimitService();

View File

@@ -17,7 +17,10 @@ import { MemoryStore, Store } from "express-rate-limit";
import RedisStore from "#private/lib/redisStore";
export function createStore(): Store {
if (build != "oss" && privateConfig.getRawPrivateConfig().flags.enable_redis) {
if (
build != "oss" &&
privateConfig.getRawPrivateConfig().flags.enable_redis
) {
const rateLimitStore: Store = new RedisStore({
prefix: "api-rate-limit", // Optional: customize Redis key prefix
skipFailedRequests: true, // Don't count failed requests

View File

@@ -19,7 +19,7 @@ import { build } from "@server/build";
class RedisManager {
public client: Redis | null = null;
private writeClient: Redis | null = null; // Master for writes
private readClient: Redis | null = null; // Replica for reads
private readClient: Redis | null = null; // Replica for reads
private subscriber: Redis | null = null;
private publisher: Redis | null = null;
private isEnabled: boolean = false;
@@ -46,7 +46,8 @@ class RedisManager {
this.isEnabled = false;
return;
}
this.isEnabled = privateConfig.getRawPrivateConfig().flags.enable_redis || false;
this.isEnabled =
privateConfig.getRawPrivateConfig().flags.enable_redis || false;
if (this.isEnabled) {
this.initializeClients();
}
@@ -63,15 +64,19 @@ class RedisManager {
}
private async triggerReconnectionCallbacks(): Promise<void> {
logger.info(`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`);
const promises = Array.from(this.reconnectionCallbacks).map(async (callback) => {
try {
await callback();
} catch (error) {
logger.error("Error in reconnection callback:", error);
logger.info(
`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`
);
const promises = Array.from(this.reconnectionCallbacks).map(
async (callback) => {
try {
await callback();
} catch (error) {
logger.error("Error in reconnection callback:", error);
}
}
});
);
await Promise.allSettled(promises);
}
@@ -79,13 +84,17 @@ class RedisManager {
private async resubscribeToChannels(): Promise<void> {
if (!this.subscriber || this.subscribers.size === 0) return;
logger.info(`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`);
logger.info(
`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`
);
try {
const channels = Array.from(this.subscribers.keys());
if (channels.length > 0) {
await this.subscriber.subscribe(...channels);
logger.info(`Successfully re-subscribed to channels: ${channels.join(', ')}`);
logger.info(
`Successfully re-subscribed to channels: ${channels.join(", ")}`
);
}
} catch (error) {
logger.error("Failed to re-subscribe to channels:", error);
@@ -98,7 +107,7 @@ class RedisManager {
host: redisConfig.host!,
port: redisConfig.port!,
password: redisConfig.password,
db: redisConfig.db,
db: redisConfig.db
// tls: {
// rejectUnauthorized:
// redisConfig.tls?.reject_unauthorized || false
@@ -112,7 +121,7 @@ class RedisManager {
if (!redisConfig.replicas || redisConfig.replicas.length === 0) {
return null;
}
// Use the first replica for simplicity
// In production, you might want to implement load balancing across replicas
const replica = redisConfig.replicas[0];
@@ -120,7 +129,7 @@ class RedisManager {
host: replica.host!,
port: replica.port!,
password: replica.password,
db: replica.db || redisConfig.db,
db: replica.db || redisConfig.db
// tls: {
// rejectUnauthorized:
// replica.tls?.reject_unauthorized || false
@@ -133,7 +142,7 @@ class RedisManager {
private initializeClients(): void {
const masterConfig = this.getRedisConfig();
const replicaConfig = this.getReplicaRedisConfig();
this.hasReplicas = replicaConfig !== null;
try {
@@ -144,7 +153,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Initialize replica connection for reads (if available)
@@ -155,7 +164,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
} else {
// Fallback to master for reads if no replicas
@@ -172,7 +181,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Subscriber uses replica if available (reads)
@@ -182,7 +191,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Add reconnection handlers for write client
@@ -202,11 +211,14 @@ class RedisManager {
logger.info("Redis write client ready");
this.isWriteHealthy = true;
this.updateOverallHealth();
// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -233,11 +245,14 @@ class RedisManager {
logger.info("Redis read client ready");
this.isReadHealthy = true;
this.updateOverallHealth();
// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -298,8 +313,8 @@ class RedisManager {
}
);
const setupMessage = this.hasReplicas
? "Redis clients initialized successfully with replica support"
const setupMessage = this.hasReplicas
? "Redis clients initialized successfully with replica support"
: "Redis clients initialized successfully (single instance)";
logger.info(setupMessage);
@@ -313,7 +328,8 @@ class RedisManager {
private updateOverallHealth(): void {
// Overall health is true if write is healthy and (read is healthy OR we don't have replicas)
this.isHealthy = this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
this.isHealthy =
this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
}
private async executeWithRetry<T>(
@@ -322,49 +338,61 @@ class RedisManager {
fallbackOperation?: () => Promise<T>
): Promise<T> {
let lastError: Error | null = null;
for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
try {
return await operation();
} catch (error) {
lastError = error as Error;
// If this is the last attempt, try fallback if available
if (attempt === this.maxRetries && fallbackOperation) {
try {
logger.warn(`${operationName} primary operation failed, trying fallback`);
logger.warn(
`${operationName} primary operation failed, trying fallback`
);
return await fallbackOperation();
} catch (fallbackError) {
logger.error(`${operationName} fallback also failed:`, fallbackError);
logger.error(
`${operationName} fallback also failed:`,
fallbackError
);
throw lastError;
}
}
// Don't retry on the last attempt
if (attempt === this.maxRetries) {
break;
}
// Calculate delay with exponential backoff
const delay = Math.min(
this.baseRetryDelay * Math.pow(this.backoffMultiplier, attempt),
this.baseRetryDelay *
Math.pow(this.backoffMultiplier, attempt),
this.maxRetryDelay
);
logger.warn(`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`, error);
logger.warn(
`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`,
error
);
// Wait before retrying
await new Promise(resolve => setTimeout(resolve, delay));
await new Promise((resolve) => setTimeout(resolve, delay));
}
}
logger.error(`${operationName} failed after ${this.maxRetries + 1} attempts:`, lastError);
logger.error(
`${operationName} failed after ${this.maxRetries + 1} attempts:`,
lastError
);
throw lastError;
}
private startHealthMonitoring(): void {
if (!this.isEnabled) return;
// Check health every 30 seconds
setInterval(async () => {
try {
@@ -381,7 +409,7 @@ class RedisManager {
private async checkRedisHealth(): Promise<boolean> {
const now = Date.now();
// Only check health every 30 seconds
if (now - this.lastHealthCheck < this.healthCheckInterval) {
return this.isHealthy;
@@ -400,24 +428,45 @@ class RedisManager {
// Check write client (master) health
await Promise.race([
this.writeClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Write client health check timeout')), 2000)
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error("Write client health check timeout")
),
2000
)
)
]);
this.isWriteHealthy = true;
// Check read client health if it's different from write client
if (this.hasReplicas && this.readClient && this.readClient !== this.writeClient) {
if (
this.hasReplicas &&
this.readClient &&
this.readClient !== this.writeClient
) {
try {
await Promise.race([
this.readClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Read client health check timeout')), 2000)
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error(
"Read client health check timeout"
)
),
2000
)
)
]);
this.isReadHealthy = true;
} catch (error) {
logger.error("Redis read client health check failed:", error);
logger.error(
"Redis read client health check failed:",
error
);
this.isReadHealthy = false;
}
} else {
@@ -475,16 +524,13 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.writeClient) return false;
try {
await this.executeWithRetry(
async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
},
"Redis SET"
);
await this.executeWithRetry(async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
}, "Redis SET");
return true;
} catch (error) {
logger.error("Redis SET error:", error);
@@ -496,9 +542,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.get(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.get(key)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.get(key),
@@ -560,9 +607,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return [];
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.smembers(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.smembers(key)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.smembers(key),
@@ -598,9 +646,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.hget(key, field)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hget(key, field)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.hget(key, field),
@@ -632,9 +681,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return {};
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.hgetall(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hgetall(key)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.hgetall(key),
@@ -658,18 +708,18 @@ class RedisManager {
}
try {
await this.executeWithRetry(
async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
await this.executeWithRetry(async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(
() => reject(new Error("Redis publish timeout")),
3000
)
]);
},
"Redis PUBLISH"
);
)
]);
}, "Redis PUBLISH");
return true;
} catch (error) {
logger.error("Redis PUBLISH error:", error);
@@ -689,17 +739,20 @@ class RedisManager {
if (!this.subscribers.has(channel)) {
this.subscribers.set(channel, new Set());
// Only subscribe to the channel if it's the first subscriber
await this.executeWithRetry(
async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
await this.executeWithRetry(async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error("Redis subscribe timeout")
),
5000
)
]);
},
"Redis SUBSCRIBE"
);
)
]);
}, "Redis SUBSCRIBE");
}
this.subscribers.get(channel)!.add(callback);

View File

@@ -11,9 +11,9 @@
* This file is not licensed under the AGPLv3.
*/
import { Store, Options, IncrementResponse } from 'express-rate-limit';
import { rateLimitService } from './rateLimit';
import logger from '@server/logger';
import { Store, Options, IncrementResponse } from "express-rate-limit";
import { rateLimitService } from "./rateLimit";
import logger from "@server/logger";
/**
* A Redis-backed rate limiting store for express-rate-limit that optimizes
@@ -57,12 +57,14 @@ export default class RedisStore implements Store {
*
* @param options - Configuration options for the store.
*/
constructor(options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}) {
this.prefix = options.prefix || 'express-rate-limit';
constructor(
options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}
) {
this.prefix = options.prefix || "express-rate-limit";
this.skipFailedRequests = options.skipFailedRequests || false;
this.skipSuccessfulRequests = options.skipSuccessfulRequests || false;
}
@@ -101,7 +103,8 @@ export default class RedisStore implements Store {
return {
totalHits: result.totalHits || 1,
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore increment error for key ${key}:`, error);
@@ -158,7 +161,9 @@ export default class RedisStore implements Store {
*/
async resetAll(): Promise<void> {
try {
logger.warn('RedisStore resetAll called - this operation can be expensive');
logger.warn(
"RedisStore resetAll called - this operation can be expensive"
);
// Force sync all pending data first
await rateLimitService.forceSyncAllPendingData();
@@ -167,9 +172,9 @@ export default class RedisStore implements Store {
// scanning all Redis keys with our prefix, which could be expensive.
// In production, it's better to let entries expire naturally.
logger.info('RedisStore resetAll completed (pending data synced)');
logger.info("RedisStore resetAll completed (pending data synced)");
} catch (error) {
logger.error('RedisStore resetAll error:', error);
logger.error("RedisStore resetAll error:", error);
// Don't throw - this is an optional method
}
}
@@ -181,7 +186,9 @@ export default class RedisStore implements Store {
* @param key - The identifier for a client.
* @returns Current hit count and reset time, or null if no data exists.
*/
async getHits(key: string): Promise<{ totalHits: number; resetTime: Date } | null> {
async getHits(
key: string
): Promise<{ totalHits: number; resetTime: Date } | null> {
try {
const clientId = `${this.prefix}:${key}`;
@@ -200,7 +207,8 @@ export default class RedisStore implements Store {
return {
totalHits: Math.max(0, (result.totalHits || 0) - 1), // Adjust for the decrement
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore getHits error for key ${key}:`, error);
@@ -215,9 +223,9 @@ export default class RedisStore implements Store {
async shutdown(): Promise<void> {
try {
// The rateLimitService handles its own cleanup
logger.info('RedisStore shutdown completed');
logger.info("RedisStore shutdown completed");
} catch (error) {
logger.error('RedisStore shutdown error:', error);
logger.error("RedisStore shutdown error:", error);
}
}
}

View File

@@ -16,10 +16,10 @@ import privateConfig from "#private/lib/config";
import logger from "@server/logger";
export enum AudienceIds {
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
}
const resend = new Resend(
@@ -33,7 +33,9 @@ export async function moveEmailToAudience(
audienceId: AudienceIds
) {
if (process.env.ENVIRONMENT !== "prod") {
logger.debug(`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`);
logger.debug(
`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`
);
return;
}
const { error, data } = await retryWithBackoff(async () => {

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./getTraefikConfig";
export * from "./getTraefikConfig";

View File

@@ -19,10 +19,7 @@ import * as crypto from "crypto";
* @param publicKey - The public key used for verification (PEM format)
* @returns The decoded payload if validation succeeds, throws an error otherwise
*/
function validateJWT<Payload>(
token: string,
publicKey: string
): Payload {
function validateJWT<Payload>(token: string, publicKey: string): Payload {
// Split the JWT into its three parts
const parts = token.split(".");
if (parts.length !== 3) {

View File

@@ -41,7 +41,11 @@ async function getActionDays(orgId: string): Promise<number> {
}
// store the result in cache
cache.set(`org_${orgId}_actionDays`, org.settingsLogRetentionDaysAction, 300);
cache.set(
`org_${orgId}_actionDays`,
org.settingsLogRetentionDaysAction,
300
);
return org.settingsLogRetentionDaysAction;
}
@@ -141,4 +145,3 @@ export function logActionAudit(action: ActionsEnum) {
}
};
}

View File

@@ -28,7 +28,8 @@ export async function verifyCertificateAccess(
try {
// Assume user/org access is already verified
const orgId = req.params.orgId;
const certId = req.params.certId || req.body?.certId || req.query?.certId;
const certId =
req.params.certId || req.body?.certId || req.query?.certId;
let domainId =
req.params.domainId || req.body?.domainId || req.query?.domainId;
@@ -39,10 +40,12 @@ export async function verifyCertificateAccess(
}
if (!domainId) {
if (!certId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
createHttpError(
HttpCode.BAD_REQUEST,
"Must provide either certId or domainId"
)
);
}
@@ -75,7 +78,10 @@ export async function verifyCertificateAccess(
if (!domainId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
createHttpError(
HttpCode.BAD_REQUEST,
"Must provide either certId or domainId"
)
);
}

View File

@@ -24,8 +24,7 @@ export async function verifyIdpAccess(
) {
try {
const userId = req.user!.userId;
const idpId =
req.params.idpId || req.body.idpId || req.query.idpId;
const idpId = req.params.idpId || req.body.idpId || req.query.idpId;
const orgId = req.params.orgId;
if (!userId) {
@@ -50,9 +49,7 @@ export async function verifyIdpAccess(
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.where(
and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId))
)
.where(and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId)))
.limit(1);
if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {

View File

@@ -26,7 +26,8 @@ export const verifySessionRemoteExitNodeMiddleware = async (
// get the token from the auth header
const token = req.headers["authorization"]?.split(" ")[1] || "";
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (!session || !remoteExitNode) {
if (config.getRawConfig().app.log_failed_attempts) {

View File

@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
import {
queryAccessAuditLogsParams,
queryAccessAuditLogsQuery,
queryAccess
} from "./queryAccessAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -67,10 +71,13 @@ export async function exportAccessAuditLogs(
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +85,4 @@ export async function exportAccessAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
import {
queryActionAuditLogsParams,
queryActionAuditLogsQuery,
queryAction
} from "./queryActionAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -60,17 +64,20 @@ export async function exportActionAuditLogs(
);
}
const data = { ...parsedQuery.data, ...parsedParams.data };
const data = { ...parsedQuery.data, ...parsedParams.data };
const baseQuery = queryAction(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +85,4 @@ export async function exportActionAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -14,4 +14,4 @@
export * from "./queryActionAuditLog";
export * from "./exportActionAuditLog";
export * from "./queryAccessAuditLog";
export * from "./exportAccessAuditLog";
export * from "./exportAccessAuditLog";

View File

@@ -44,7 +44,8 @@ export const queryAccessAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z
.union([z.boolean(), z.string()])
@@ -181,9 +182,15 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter((row): row is { id: number; name: string | null } => row.id !== null),
locations: uniqueLocations.map(row => row.locations).filter((location): location is string => location !== null)
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
};
}

View File

@@ -44,7 +44,8 @@ export const queryActionAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z.string().optional(),
actorType: z.string().optional(),
@@ -68,8 +69,9 @@ export const queryActionAuditLogsParams = z.object({
orgId: z.string()
});
export const queryActionAuditLogsCombined =
queryActionAuditLogsQuery.merge(queryActionAuditLogsParams);
export const queryActionAuditLogsCombined = queryActionAuditLogsQuery.merge(
queryActionAuditLogsParams
);
type Q = z.infer<typeof queryActionAuditLogsCombined>;
function getWhere(data: Q) {
@@ -78,7 +80,9 @@ function getWhere(data: Q) {
lt(actionAuditLog.timestamp, data.timeEnd),
eq(actionAuditLog.orgId, data.orgId),
data.actor ? eq(actionAuditLog.actor, data.actor) : undefined,
data.actorType ? eq(actionAuditLog.actorType, data.actorType) : undefined,
data.actorType
? eq(actionAuditLog.actorType, data.actorType)
: undefined,
data.actorId ? eq(actionAuditLog.actorId, data.actorId) : undefined,
data.action ? eq(actionAuditLog.action, data.action) : undefined
);
@@ -135,8 +139,12 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
actions: uniqueActions.map(row => row.action).filter((action): action is string => action !== null),
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
actions: uniqueActions
.map((row) => row.action)
.filter((action): action is string => action !== null)
};
}

View File

@@ -13,4 +13,4 @@
export * from "./transferSession";
export * from "./getSessionTransferToken";
export * from "./quickStart";
export * from "./quickStart";

View File

@@ -395,7 +395,8 @@ export async function quickStart(
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
}).returning();
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
@@ -406,7 +407,12 @@ export async function quickStart(
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {

Some files were not shown because too many files have changed in this diff Show More