Handle new usage tracking with multi org

This commit is contained in:
Owen
2026-02-17 17:09:48 -08:00
parent 79cf7c84dc
commit 4d6240c987
17 changed files with 432 additions and 584 deletions

View File

@@ -4,6 +4,7 @@ export enum FeatureId {
EGRESS_DATA_MB = "egressDataMb",
DOMAINS = "domains",
REMOTE_EXIT_NODES = "remoteExitNodes",
ORGINIZATIONS = "organizations",
TIER1 = "tier1"
}
@@ -19,6 +20,8 @@ export async function getFeatureDisplayName(featureId: FeatureId): Promise<strin
return "Domains";
case FeatureId.REMOTE_EXIT_NODES:
return "Remote Exit Nodes";
case FeatureId.ORGINIZATIONS:
return "Organizations";
case FeatureId.TIER1:
return "Home Lab";
default:

View File

@@ -7,18 +7,12 @@ export type LimitSet = Partial<{
};
}>;
export const sandboxLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.SITES]: { value: 1, description: "Sandbox limit" },
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" },
};
export const freeLimitSet: LimitSet = {
[FeatureId.SITES]: { value: 5, description: "Basic limit" },
[FeatureId.USERS]: { value: 5, description: "Basic limit" },
[FeatureId.DOMAINS]: { value: 5, description: "Basic limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Basic limit" },
};
export const tier1LimitSet: LimitSet = {
@@ -26,6 +20,7 @@ export const tier1LimitSet: LimitSet = {
[FeatureId.SITES]: { value: 10, description: "Home limit" },
[FeatureId.DOMAINS]: { value: 10, description: "Home limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Home limit" },
};
export const tier2LimitSet: LimitSet = {
@@ -45,6 +40,10 @@ export const tier2LimitSet: LimitSet = {
value: 3,
description: "Team limit"
},
[FeatureId.ORGINIZATIONS]: {
value: 1,
description: "Team limit"
}
};
export const tier3LimitSet: LimitSet = {
@@ -64,4 +63,8 @@ export const tier3LimitSet: LimitSet = {
value: 20,
description: "Business limit"
},
[FeatureId.ORGINIZATIONS]: {
value: 20,
description: "Business limit"
},
};

View File

@@ -1,12 +1,8 @@
import { eq, sql, and } from "drizzle-orm";
import { v4 as uuidv4 } from "uuid";
import { PutObjectCommand } from "@aws-sdk/client-s3";
import {
db,
usage,
customers,
sites,
newts,
limits,
Usage,
Limit,
@@ -15,21 +11,9 @@ import {
} from "@server/db";
import { FeatureId, getFeatureMeterId } from "./features";
import logger from "@server/logger";
import { sendToClient } from "#dynamic/routers/ws";
import { build } from "@server/build";
import { s3Client } from "@server/lib/s3";
import cache from "@server/lib/cache";
interface StripeEvent {
identifier?: string;
timestamp: number;
event_name: string;
payload: {
value: number;
stripe_customer_id: string;
};
}
export function noop() {
if (build !== "saas") {
return true;
@@ -38,41 +22,11 @@ export function noop() {
}
export class UsageService {
// private bucketName: string | undefined;
// private events: StripeEvent[] = [];
// private lastUploadTime: number = Date.now();
// private isUploading: boolean = false;
constructor() {
if (noop()) {
return;
}
// this.bucketName = process.env.S3_BUCKET || undefined;
// // Periodically check and upload events
// setInterval(() => {
// this.checkAndUploadEvents().catch((err) => {
// logger.error("Error in periodic event upload:", err);
// });
// }, 30000); // every 30 seconds
// // Handle graceful shutdown on SIGTERM
// process.on("SIGTERM", async () => {
// logger.info(
// "SIGTERM received, uploading events before shutdown..."
// );
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// });
// // Handle SIGINT as well (Ctrl+C)
// process.on("SIGINT", async () => {
// logger.info("SIGINT received, uploading events before shutdown...");
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// process.exit(0);
// });
}
/**
@@ -103,16 +57,6 @@ export class UsageService {
while (attempt <= maxRetries) {
try {
// // Get subscription data for this org (with caching)
// const customerId = await this.getCustomerId(orgIdToUse, featureId);
// if (!customerId) {
// logger.warn(
// `No subscription data found for org ${orgIdToUse} and feature ${featureId}`
// );
// return null;
// }
let usage;
if (transaction) {
usage = await this.internalAddUsage(
@@ -132,11 +76,6 @@ export class UsageService {
});
}
// Log event for Stripe
// if (privateConfig.getRawPrivateConfig().flags.usage_reporting) {
// await this.logStripeEvent(featureId, value, customerId);
// }
return usage || null;
} catch (error: any) {
// Check if this is a deadlock error
@@ -191,13 +130,14 @@ export class UsageService {
featureId,
orgId,
meterId,
instantaneousValue: value,
latestValue: value,
updatedAt: Math.floor(Date.now() / 1000)
})
.onConflictDoUpdate({
target: usage.usageId,
set: {
latestValue: sql`${usage.latestValue} + ${value}`
instantaneousValue: sql`${usage.instantaneousValue} + ${value}`
}
})
.returning();
@@ -228,17 +168,6 @@ export class UsageService {
let orgIdToUse = await this.getBillingOrg(orgId);
try {
// if (!customerId) {
// customerId =
// (await this.getCustomerId(orgIdToUse, featureId)) || undefined;
// if (!customerId) {
// logger.warn(
// `No subscription data found for org ${orgIdToUse} and feature ${featureId}`
// );
// return;
// }
// }
// Truncate value to 11 decimal places if provided
if (value !== undefined && value !== null) {
value = this.truncateValue(value);
@@ -523,114 +452,6 @@ export class UsageService {
return hasExceededLimits;
}
// private async logStripeEvent(
// featureId: FeatureId,
// value: number,
// customerId: string
// ): Promise<void> {
// // Truncate value to 11 decimal places before sending to Stripe
// const truncatedValue = this.truncateValue(value);
// const event: StripeEvent = {
// identifier: uuidv4(),
// timestamp: Math.floor(new Date().getTime() / 1000),
// event_name: featureId,
// payload: {
// value: truncatedValue,
// stripe_customer_id: customerId
// }
// };
// this.addEventToMemory(event);
// await this.checkAndUploadEvents();
// }
// private addEventToMemory(event: StripeEvent): void {
// if (!this.bucketName) {
// logger.warn(
// "S3 bucket name is not configured, skipping event storage."
// );
// return;
// }
// this.events.push(event);
// }
// private async checkAndUploadEvents(): Promise<void> {
// const now = Date.now();
// const timeSinceLastUpload = now - this.lastUploadTime;
// // Check if at least 1 minute has passed since last upload
// if (timeSinceLastUpload >= 60000 && this.events.length > 0) {
// await this.uploadEventsToS3();
// }
// }
// private async uploadEventsToS3(): Promise<void> {
// if (!this.bucketName) {
// logger.warn(
// "S3 bucket name is not configured, skipping S3 upload."
// );
// return;
// }
// if (this.events.length === 0) {
// return;
// }
// // Check if already uploading
// if (this.isUploading) {
// logger.debug("Already uploading events, skipping");
// return;
// }
// this.isUploading = true;
// try {
// // Take a snapshot of current events and clear the array
// const eventsToUpload = [...this.events];
// this.events = [];
// this.lastUploadTime = Date.now();
// const fileName = this.generateEventFileName();
// const fileContent = JSON.stringify(eventsToUpload, null, 2);
// // Upload to S3
// const uploadCommand = new PutObjectCommand({
// Bucket: this.bucketName,
// Key: fileName,
// Body: fileContent,
// ContentType: "application/json"
// });
// await s3Client.send(uploadCommand);
// logger.info(
// `Uploaded ${fileName} to S3 with ${eventsToUpload.length} events`
// );
// } catch (error) {
// logger.error("Failed to upload events to S3:", error);
// // Note: Events are lost if upload fails. In a production system,
// // you might want to add the events back to the array or implement retry logic
// } finally {
// this.isUploading = false;
// }
// }
// private generateEventFileName(): string {
// const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
// const uuid = uuidv4().substring(0, 8);
// return `events-${timestamp}-${uuid}.json`;
// }
// public async forceUpload(): Promise<void> {
// if (this.events.length > 0) {
// // Force upload regardless of time
// this.lastUploadTime = 0; // Reset to force upload
// await this.uploadEventsToS3();
// }
// }
}
export const usageService = new UsageService();

View File

@@ -1,206 +0,0 @@
import { isValidCIDR } from "@server/lib/validators";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import {
actions,
apiKeyOrg,
apiKeys,
db,
domains,
Org,
orgDomains,
orgs,
roleActions,
roles,
userOrgs
} from "@server/db";
import { eq } from "drizzle-orm";
import { defaultRoleAllowedActions } from "@server/routers/role";
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import config from "@server/lib/config";
import { generateCA } from "@server/private/lib/sshCA";
import { encrypt } from "@server/lib/crypto";
export async function createUserAccountOrg(
userId: string,
userEmail: string
): Promise<{
success: boolean;
org?: {
orgId: string;
name: string;
subnet: string;
};
error?: string;
}> {
// const subnet = await getNextAvailableOrgSubnet();
const orgId = "org_" + userId;
const name = `${userEmail}'s Organization`;
// if (!isValidCIDR(subnet)) {
// return {
// success: false,
// error: "Invalid subnet format. Please provide a valid CIDR notation."
// };
// }
// // make sure the subnet is unique
// const subnetExists = await db
// .select()
// .from(orgs)
// .where(eq(orgs.subnet, subnet))
// .limit(1);
// if (subnetExists.length > 0) {
// return { success: false, error: `Subnet ${subnet} already exists` };
// }
// make sure the orgId is unique
const orgExists = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (orgExists.length > 0) {
return {
success: false,
error: `Organization with ID ${orgId} already exists`
};
}
let error = "";
let org: Org | null = null;
await db.transaction(async (trx) => {
const allDomains = await trx
.select()
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
// Generate SSH CA keys for the org
const ca = generateCA(`${orgId}-ca`);
const encryptionKey = config.getRawConfig().server.secret!;
const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey);
const newOrg = await trx
.insert(orgs)
.values({
orgId,
name,
// subnet
subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
utilitySubnet: utilitySubnet,
createdAt: new Date().toISOString(),
sshCaPrivateKey: encryptedCaPrivateKey,
sshCaPublicKey: ca.publicKeyOpenSSH
})
.returning();
if (newOrg.length === 0) {
error = "Failed to create organization";
trx.rollback();
return;
}
org = newOrg[0];
// Create admin role within the same transaction
const [insertedRole] = await trx
.insert(roles)
.values({
orgId: newOrg[0].orgId,
isAdmin: true,
name: "Admin",
description: "Admin role with the most permissions"
})
.returning({ roleId: roles.roleId });
if (!insertedRole || !insertedRole.roleId) {
error = "Failed to create Admin role";
trx.rollback();
return;
}
const roleId = insertedRole.roleId;
// Get all actions and create role actions
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length > 0) {
await trx.insert(roleActions).values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
}
if (allDomains.length) {
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
}
await trx.insert(userOrgs).values({
userId,
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
const memberRole = await trx
.insert(roles)
.values({
name: "Member",
description: "Members can only view resources",
orgId
})
.returning();
await trx.insert(roleActions).values(
defaultRoleAllowedActions.map((action) => ({
roleId: memberRole[0].roleId,
actionId: action,
orgId
}))
);
});
await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet);
if (!org) {
return { success: false, error: "Failed to create org" };
}
if (error) {
return {
success: false,
error: `Failed to create org: ${error}`
};
}
// make sure we have the stripe customer
const customerId = await createCustomer(orgId, userEmail);
if (customerId) {
await usageService.updateCount(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org
}
return {
org: {
orgId,
name,
// subnet
subnet: "100.90.128.0/24"
},
success: true
};
}

View File

@@ -19,6 +19,8 @@ import { sendToClient } from "#dynamic/routers/ws";
import { deletePeer } from "@server/routers/gerbil/peers";
import { OlmErrorCodes } from "@server/routers/olm/error";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { usageService } from "./billing/usageService";
import { FeatureId } from "./billing";
export type DeleteOrgByIdResult = {
deletedNewtIds: string[];
@@ -74,9 +76,7 @@ export async function deleteOrgById(
deletedNewtIds.push(deletedNewt.newtId);
await trx
.delete(newtSessions)
.where(
eq(newtSessions.newtId, deletedNewt.newtId)
);
.where(eq(newtSessions.newtId, deletedNewt.newtId));
}
}
}
@@ -137,6 +137,9 @@ export async function deleteOrgById(
.where(inArray(domains.domainId, domainIdsToDelete));
}
await trx.delete(resources).where(eq(resources.orgId, orgId));
await usageService.add(orgId, FeatureId.ORGINIZATIONS, -1, trx); // here we are decreasing the org count BEFORE deleting the org because we need to still be able to get the org to get the billing org inside of here
await trx.delete(orgs).where(eq(orgs.orgId, orgId));
});
@@ -155,15 +158,13 @@ export function sendTerminationMessages(result: DeleteOrgByIdResult): void {
);
}
for (const olmId of result.olmsToTerminate) {
sendTerminateClient(
0,
OlmErrorCodes.TERMINATED_REKEYED,
olmId
).catch((error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
});
sendTerminateClient(0, OlmErrorCodes.TERMINATED_REKEYED, olmId).catch(
(error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
}
);
}
}

View File

@@ -85,10 +85,14 @@ export async function getOrgUsage(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
const organizations = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
FeatureId.ORGINIZATIONS
);
// const egressData = await usageService.getUsage(
// orgId,
// FeatureId.EGRESS_DATA_MB
// );
if (sites) {
usageData.push(sites);
@@ -96,15 +100,18 @@ export async function getOrgUsage(
if (users) {
usageData.push(users);
}
if (egressData) {
usageData.push(egressData);
}
// if (egressData) {
// usageData.push(egressData);
// }
if (domains) {
usageData.push(domains);
}
if (remoteExitNodes) {
usageData.push(remoteExitNodes);
}
if (organizations) {
usageData.push(organizations);
}
const orgLimits = await db
.select()

View File

@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
orgs
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -25,7 +32,7 @@ import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNo
import { fromError } from "zod-validation-error";
import { hashPassword, verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray, ne } from "drizzle-orm";
import { getNextAvailableSubnet } from "@server/lib/exitNodes";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
@@ -169,7 +176,17 @@ export async function createRemoteExitNode(
);
}
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => {
if (!existingExitNode) {
@@ -217,19 +234,43 @@ export async function createRemoteExitNode(
});
}
numExitNodeOrgs = await trx
.select()
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId));
});
// calculate if the node is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, orgId)
)
);
if (numExitNodeOrgs) {
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length
);
}
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select()
.from(exitNodeOrgs)
.where(
and(
eq(
exitNodeOrgs.exitNodeId,
existingExitNode.exitNodeId
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.add(
orgId,
FeatureId.REMOTE_EXIT_NODES,
1,
trx
);
}
}
});
const token = generateSessionToken();
await createRemoteExitNodeSession(token, remoteExitNodeId);

View File

@@ -13,9 +13,9 @@
import { NextFunction, Request, Response } from "express";
import { z } from "zod";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes, orgs } from "@server/db";
import { remoteExitNodes } from "@server/db";
import { and, count, eq } from "drizzle-orm";
import { and, count, eq, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -50,7 +50,8 @@ export async function deleteRemoteExitNode(
const [remoteExitNode] = await db
.select()
.from(remoteExitNodes)
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId))
.limit(1);
if (!remoteExitNode) {
return next(
@@ -70,7 +71,17 @@ export async function deleteRemoteExitNode(
);
}
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Org with ID ${orgId} not found`
)
);
}
await db.transaction(async (trx) => {
await trx
.delete(exitNodeOrgs)
@@ -81,38 +92,39 @@ export async function deleteRemoteExitNode(
)
);
const [remainingExitNodeOrgs] = await trx
.select({ count: count() })
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!));
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (remainingExitNodeOrgs.count === 0) {
await trx
.delete(remoteExitNodes)
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select()
.from(exitNodeOrgs)
.where(
eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)
and(
eq(
exitNodeOrgs.exitNodeId,
remoteExitNode.exitNodeId!
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
await trx
.delete(exitNodes)
.where(
eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!)
if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.add(
orgId,
FeatureId.REMOTE_EXIT_NODES,
-1,
trx
);
}
}
numExitNodeOrgs = await trx
.select()
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId));
});
if (numExitNodeOrgs) {
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length
);
}
return response(res, {
data: null,
success: true,

View File

@@ -148,7 +148,6 @@ export async function createOrgDomain(
}
}
let numOrgDomains: OrgDomains[] | undefined;
let aRecords: CreateDomainResponse["aRecords"];
let cnameRecords: CreateDomainResponse["cnameRecords"];
let txtRecords: CreateDomainResponse["txtRecords"];
@@ -347,20 +346,9 @@ export async function createOrgDomain(
await trx.insert(dnsRecords).values(recordsToInsert);
}
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
await usageService.add(orgId, FeatureId.DOMAINS, 1, trx);
});
if (numOrgDomains) {
await usageService.updateCount(
orgId,
FeatureId.DOMAINS,
numOrgDomains.length
);
}
if (!returned) {
return next(
createHttpError(

View File

@@ -36,8 +36,6 @@ export async function deleteAccountDomain(
}
const { domainId, orgId } = parsed.data;
let numOrgDomains: OrgDomains[] | undefined;
await db.transaction(async (trx) => {
const [existing] = await trx
.select()
@@ -79,20 +77,9 @@ export async function deleteAccountDomain(
await trx.delete(domains).where(eq(domains.domainId, domainId));
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
await usageService.add(orgId, FeatureId.DOMAINS, -1, trx);
});
if (numOrgDomains) {
await usageService.updateCount(
orgId,
FeatureId.DOMAINS,
numOrgDomains.length
);
}
return response<DeleteAccountDomainResponse>(res, {
data: { success: true },
success: true,

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { and, eq } from "drizzle-orm";
import { and, count, eq } from "drizzle-orm";
import {
domains,
Org,
@@ -24,11 +24,7 @@ import { OpenAPITags, registry } from "@server/openApi";
import { isValidCIDR } from "@server/lib/validators";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import {
FeatureId,
limitsService,
sandboxLimitSet
} from "@server/lib/billing";
import { FeatureId, limitsService, freeLimitSet } from "@server/lib/billing";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { doCidrsOverlap } from "@server/lib/ip";
@@ -114,6 +110,7 @@ export async function createOrg(
// )
// );
// }
//
// make sure the orgId is unique
const orgExists = await db
@@ -174,8 +171,46 @@ export async function createOrg(
}
}
if (build == "saas") {
if (!billingOrgIdForNewOrg) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Billing org not found for user. Cannot create new organization."
)
);
}
const usage = await usageService.getUsage(billingOrgIdForNewOrg, FeatureId.ORGINIZATIONS);
if (!usage) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No usage data found for this organization"
)
);
}
const rejectOrgs = await usageService.checkLimitSet(
billingOrgIdForNewOrg,
FeatureId.ORGINIZATIONS,
{
...usage,
instantaneousValue: (usage.instantaneousValue || 0) + 1
} // We need to add one to know if we are violating the limit
);
if (rejectOrgs) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization limit exceeded. Please upgrade your plan."
)
);
}
}
let error = "";
let org: Org | null = null;
let numOrgs: number | null = null;
await db.transaction(async (trx) => {
const allDomains = await trx
@@ -184,14 +219,14 @@ export async function createOrg(
.where(eq(domains.configManaged, true));
// Generate SSH CA keys for the org
const ca = generateCA(`${orgId}-ca`);
const encryptionKey = config.getRawConfig().server.secret!;
const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey);
// const ca = generateCA(`${orgId}-ca`);
// const encryptionKey = config.getRawConfig().server.secret!;
// const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey);
const saasBillingFields =
build === "saas" && req.user && isFirstOrg !== null
? isFirstOrg
? { isBillingOrg: true as const, billingOrgId: null }
? { isBillingOrg: true as const, billingOrgId: orgId } // if this is the first org, it becomes the billing org for itself
: {
isBillingOrg: false as const,
billingOrgId: billingOrgIdForNewOrg
@@ -206,8 +241,8 @@ export async function createOrg(
subnet,
utilitySubnet,
createdAt: new Date().toISOString(),
sshCaPrivateKey: encryptedCaPrivateKey,
sshCaPublicKey: ca.publicKeyOpenSSH,
// sshCaPrivateKey: encryptedCaPrivateKey,
// sshCaPublicKey: ca.publicKeyOpenSSH,
...saasBillingFields
})
.returning();
@@ -310,6 +345,17 @@ export async function createOrg(
);
await calculateUserClientsForOrgs(ownerUserId, trx);
if (billingOrgIdForNewOrg) {
const [numOrgsResult] = await trx
.select({ count: count() })
.from(orgs)
.where(eq(orgs.billingOrgId, billingOrgIdForNewOrg)); // all the billable orgs including the primary org that is the billing org itself
numOrgs = numOrgsResult.count;
} else {
numOrgs = 1; // we only have one org if there is no billing org found out
}
});
if (!org) {
@@ -326,7 +372,7 @@ export async function createOrg(
}
if (build === "saas" && isFirstOrg === true) {
await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet);
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
const customerId = await createCustomer(orgId, req.user?.email);
if (customerId) {
await usageService.updateCount(
@@ -338,6 +384,14 @@ export async function createOrg(
}
}
if (numOrgs) {
usageService.updateCount(
billingOrgIdForNewOrg || orgId,
FeatureId.ORGINIZATIONS,
numOrgs
);
}
return response(res, {
data: org,
success: true,

View File

@@ -6,7 +6,7 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { eq, and, count } from "drizzle-orm";
import { getUniqueSiteName } from "../../db/names";
import { addPeer } from "../gerbil/peers";
import { fromError } from "zod-validation-error";
@@ -288,7 +288,6 @@ export async function createSite(
const niceId = await getUniqueSiteName(orgId);
let newSite: Site | undefined;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => {
if (type == "newt") {
[newSite] = await trx
@@ -443,20 +442,9 @@ export async function createSite(
});
}
numSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, orgId));
await usageService.add(orgId, FeatureId.SITES, 1, trx);
});
if (numSites) {
await usageService.updateCount(
orgId,
FeatureId.SITES,
numSites.length
);
}
if (!newSite) {
return next(
createHttpError(

View File

@@ -64,7 +64,6 @@ export async function deleteSite(
}
let deletedNewtId: string | null = null;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => {
if (site.type == "wireguard") {
@@ -101,21 +100,9 @@ export async function deleteSite(
}
}
await trx.delete(sites).where(eq(sites.siteId, siteId));
numSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
await usageService.add(site.orgId, FeatureId.SITES, -1, trx);
});
if (numSites) {
await usageService.updateCount(
site.orgId,
FeatureId.SITES,
numSites.length
);
}
// Send termination message outside of transaction to prevent blocking
if (deletedNewtId) {
const payload = {

View File

@@ -1,8 +1,8 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, UserOrg } from "@server/db";
import { db, orgs, UserOrg } from "@server/db";
import { roles, userInvites, userOrgs, users } from "@server/db";
import { eq } from "drizzle-orm";
import { eq, and, inArray, ne } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -125,8 +125,22 @@ export async function acceptInvite(
}
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, existingInvite.orgId))
.limit(1);
if (!org) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization does not exist. Please contact an admin."
)
);
}
let roleId: number;
let totalUsers: UserOrg[] | undefined;
// get the role to make sure it exists
const existingRole = await db
.select()
@@ -160,25 +174,45 @@ export async function acceptInvite(
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
// Get the total number of users in the org now
totalUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, existingInvite.orgId));
// calculate if the user is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, existingInvite.orgId)
)
);
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, existingUser[0].userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(
existingInvite.orgId,
FeatureId.USERS,
1,
trx
);
}
}
logger.debug(
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}`
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
);
});
if (totalUsers) {
await usageService.updateCount(
existingInvite.orgId,
FeatureId.USERS,
totalUsers.length
);
}
return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite.orgId },
success: true,

View File

@@ -6,8 +6,8 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { db, UserOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import { db, orgs, UserOrg } from "@server/db";
import { and, eq, inArray, ne } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app";
import { usageService } from "@server/lib/billing/usageService";
@@ -151,6 +151,21 @@ export async function createOrgUser(
);
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Organization not found"
)
);
}
const [idpRes] = await db
.select()
.from(idp)
@@ -172,8 +187,6 @@ export async function createOrgUser(
);
}
let orgUsers: UserOrg[] | undefined;
await db.transaction(async (trx) => {
const [existingUser] = await trx
.select()
@@ -244,22 +257,37 @@ export async function createOrgUser(
.returning();
}
// List all of the users in the org
orgUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
await calculateUserClientsForOrgs(userId, trx);
});
if (orgUsers) {
await usageService.updateCount(
orgId,
FeatureId.USERS,
orgUsers.length
);
}
// calculate if the user is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, orgId)
)
);
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(orgId, FeatureId.USERS, 1, trx);
}
}
});
} else {
return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -1,8 +1,16 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resources, sites, UserOrg } from "@server/db";
import {
db,
orgs,
resources,
siteResources,
sites,
UserOrg,
userSiteResources
} from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, count, eq, exists } from "drizzle-orm";
import { and, count, eq, exists, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -50,16 +58,16 @@ export async function removeUserOrg(
const { userId, orgId } = parsedParams.data;
// get the user first
const user = await db
const [user] = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)));
if (!user || user.length === 0) {
if (!user) {
return next(createHttpError(HttpCode.NOT_FOUND, "User not found"));
}
if (user[0].isOwner) {
if (user.isOwner) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
@@ -68,7 +76,17 @@ export async function removeUserOrg(
);
}
let userCount: UserOrg[] | undefined;
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => {
await trx
@@ -97,6 +115,26 @@ export async function removeUserOrg(
)
);
await db.delete(userSiteResources).where(
and(
eq(userSiteResources.userId, userId),
exists(
db
.select()
.from(siteResources)
.where(
and(
eq(
siteResources.siteResourceId,
userSiteResources.siteResourceId
),
eq(siteResources.orgId, orgId)
)
)
)
)
);
await db.delete(userSites).where(
and(
eq(userSites.userId, userId),
@@ -114,11 +152,6 @@ export async function removeUserOrg(
)
);
userCount = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
// if (build === "saas") {
// const [rootUser] = await trx
// .select()
@@ -137,15 +170,31 @@ export async function removeUserOrg(
// }
await calculateUserClientsForOrgs(userId, trx);
});
if (userCount) {
await usageService.updateCount(
orgId,
FeatureId.USERS,
userCount.length
);
}
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
const billingOrgIds = billingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(orgId, FeatureId.USERS, -1, trx);
}
}
});
return response(res, {
data: null,

View File

@@ -110,37 +110,42 @@ const planOptions: PlanOption[] = [
// Tier limits mapping derived from limit sets
const tierLimits: Record<
Tier | "basic",
{ users: number; sites: number; domains: number; remoteNodes: number }
{ users: number; sites: number; domains: number; remoteNodes: number; organizations: number }
> = {
basic: {
users: freeLimitSet[FeatureId.USERS]?.value ?? 0,
sites: freeLimitSet[FeatureId.SITES]?.value ?? 0,
domains: freeLimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: freeLimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
tier1: {
users: tier1LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier1LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier1LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
tier2: {
users: tier2LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier2LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier2LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier2LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
tier3: {
users: tier3LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier3LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier3LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier3LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
enterprise: {
users: 0, // Custom for enterprise
sites: 0, // Custom for enterprise
domains: 0, // Custom for enterprise
remoteNodes: 0 // Custom for enterprise
remoteNodes: 0, // Custom for enterprise
organizations: 0 // Custom for enterprise
}
};
@@ -179,6 +184,7 @@ export default function BillingPage() {
const SITES = "sites";
const DOMAINS = "domains";
const REMOTE_EXIT_NODES = "remoteExitNodes";
const ORGINIZATIONS = "organizations";
// Confirmation dialog state
const [showConfirmDialog, setShowConfirmDialog] = useState(false);
@@ -619,6 +625,16 @@ export default function BillingPage() {
});
}
// Check organizations
const organizationsUsage = getUsageValue(ORGINIZATIONS);
if (limits.organizations > 0 && organizationsUsage > limits.organizations) {
violations.push({
feature: "Organizations",
currentUsage: organizationsUsage,
newLimit: limits.organizations
});
}
return violations;
};
@@ -855,6 +871,41 @@ export default function BillingPage() {
)}
</InfoSectionContent>
</InfoSection>
<InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingOrganizations") ||
"Organizations"}
</InfoSectionTitle>
<InfoSectionContent className="text-sm">
{isOverLimit(ORGINIZATIONS) ? (
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "organizations"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(ORGINIZATIONS), limit: getLimitValue(ORGINIZATIONS) ?? 0 }) || `Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}</p>
</TooltipContent>
</Tooltip>
) : (
<>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "organizations"}
</>
)}
</InfoSectionContent>
</InfoSection>
<InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingRemoteNodes") ||