import { eq, sql, and } from "drizzle-orm"; import { v4 as uuidv4 } from "uuid"; import { PutObjectCommand } from "@aws-sdk/client-s3"; import { db, usage, customers, sites, newts, limits, Usage, Limit, Transaction } from "@server/db"; import { FeatureId, getFeatureMeterId } from "./features"; import logger from "@server/logger"; import { sendToClient } from "#dynamic/routers/ws"; import { build } from "@server/build"; import { s3Client } from "@server/lib/s3"; import cache from "@server/lib/cache"; interface StripeEvent { identifier?: string; timestamp: number; event_name: string; payload: { value: number; stripe_customer_id: string; }; } export function noop() { if ( build !== "saas" || !process.env.S3_BUCKET ) { return true; } return false; } export class UsageService { private bucketName: string | undefined; private events: StripeEvent[] = []; private lastUploadTime: number = Date.now(); private isUploading: boolean = false; constructor() { if (noop()) { return; } this.bucketName = process.env.S3_BUCKET || undefined; // Periodically check and upload events setInterval(() => { this.checkAndUploadEvents().catch((err) => { logger.error("Error in periodic event upload:", err); }); }, 30000); // every 30 seconds // Handle graceful shutdown on SIGTERM process.on("SIGTERM", async () => { logger.info("SIGTERM received, uploading events before shutdown..."); await this.forceUpload(); logger.info("Events uploaded, proceeding with shutdown"); }); // Handle SIGINT as well (Ctrl+C) process.on("SIGINT", async () => { logger.info("SIGINT received, uploading events before shutdown..."); await this.forceUpload(); logger.info("Events uploaded, proceeding with shutdown"); process.exit(0); }); } /** * Truncate a number to 11 decimal places to prevent precision issues */ private truncateValue(value: number): number { return Math.round(value * 100000000000) / 100000000000; // 11 decimal places } public async add( orgId: string, featureId: FeatureId, value: number, transaction: any = null ): Promise { if (noop()) { return null; } // Truncate value to 11 decimal places value = this.truncateValue(value); // Implement retry logic for deadlock handling const maxRetries = 3; let attempt = 0; while (attempt <= maxRetries) { try { // Get subscription data for this org (with caching) const customerId = await this.getCustomerId(orgId, featureId); if (!customerId) { logger.warn( `No subscription data found for org ${orgId} and feature ${featureId}` ); return null; } let usage; if (transaction) { usage = await this.internalAddUsage( orgId, featureId, value, transaction ); } else { await db.transaction(async (trx) => { usage = await this.internalAddUsage( orgId, featureId, value, trx ); }); } // Log event for Stripe await this.logStripeEvent(featureId, value, customerId); return usage || null; } catch (error: any) { // Check if this is a deadlock error const isDeadlock = error?.code === "40P01" || error?.cause?.code === "40P01" || (error?.message && error.message.includes("deadlock")); if (isDeadlock && attempt < maxRetries) { attempt++; // Exponential backoff with jitter: 50-150ms, 100-300ms, 200-600ms const baseDelay = Math.pow(2, attempt - 1) * 50; const jitter = Math.random() * baseDelay; const delay = baseDelay + jitter; logger.warn( `Deadlock detected for ${orgId}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms` ); await new Promise((resolve) => setTimeout(resolve, delay)); continue; } logger.error( `Failed to add usage for ${orgId}/${featureId} after ${attempt} attempts:`, error ); break; } } return null; } private async internalAddUsage( orgId: string, featureId: FeatureId, value: number, trx: Transaction ): Promise { // Truncate value to 11 decimal places value = this.truncateValue(value); const usageId = `${orgId}-${featureId}`; const meterId = getFeatureMeterId(featureId); // Use upsert: insert if not exists, otherwise increment const [returnUsage] = await trx .insert(usage) .values({ usageId, featureId, orgId, meterId, latestValue: value, updatedAt: Math.floor(Date.now() / 1000) }) .onConflictDoUpdate({ target: usage.usageId, set: { latestValue: sql`${usage.latestValue} + ${value}` } }) .returning(); return returnUsage; } // Helper function to get today's date as string (YYYY-MM-DD) getTodayDateString(): string { return new Date().toISOString().split("T")[0]; } // Helper function to get date string from Date object getDateString(date: number): string { return new Date(date * 1000).toISOString().split("T")[0]; } async updateDaily( orgId: string, featureId: FeatureId, value?: number, customerId?: string ): Promise { if (noop()) { return; } try { if (!customerId) { customerId = (await this.getCustomerId(orgId, featureId)) || undefined; if (!customerId) { logger.warn( `No subscription data found for org ${orgId} and feature ${featureId}` ); return; } } // Truncate value to 11 decimal places if provided if (value !== undefined && value !== null) { value = this.truncateValue(value); } const today = this.getTodayDateString(); let currentUsage: Usage | null = null; await db.transaction(async (trx) => { // Get existing meter record const usageId = `${orgId}-${featureId}`; // Get current usage record [currentUsage] = await trx .select() .from(usage) .where(eq(usage.usageId, usageId)) .limit(1); if (currentUsage) { const lastUpdateDate = this.getDateString( currentUsage.updatedAt ); const currentRunningTotal = currentUsage.latestValue; const lastDailyValue = currentUsage.instantaneousValue || 0; if (value == undefined || value === null) { value = currentUsage.instantaneousValue || 0; } if (lastUpdateDate === today) { // Same day update: replace the daily value // Remove old daily value from running total, add new value const newRunningTotal = this.truncateValue( currentRunningTotal - lastDailyValue + value ); await trx .update(usage) .set({ latestValue: newRunningTotal, instantaneousValue: value, updatedAt: Math.floor(Date.now() / 1000) }) .where(eq(usage.usageId, usageId)); } else { // New day: add to running total const newRunningTotal = this.truncateValue( currentRunningTotal + value ); await trx .update(usage) .set({ latestValue: newRunningTotal, instantaneousValue: value, updatedAt: Math.floor(Date.now() / 1000) }) .where(eq(usage.usageId, usageId)); } } else { // First record for this meter const meterId = getFeatureMeterId(featureId); const truncatedValue = this.truncateValue(value || 0); await trx.insert(usage).values({ usageId, featureId, orgId, meterId, instantaneousValue: truncatedValue, latestValue: truncatedValue, updatedAt: Math.floor(Date.now() / 1000) }); } }); await this.logStripeEvent(featureId, value || 0, customerId); } catch (error) { logger.error( `Failed to update daily usage for ${orgId}/${featureId}:`, error ); } } private async getCustomerId( orgId: string, featureId: FeatureId ): Promise { const cacheKey = `customer_${orgId}_${featureId}`; const cached = cache.get(cacheKey); if (cached) { return cached; } try { // Query subscription data const [customer] = await db .select({ customerId: customers.customerId }) .from(customers) .where(eq(customers.orgId, orgId)) .limit(1); if (!customer) { return null; } const customerId = customer.customerId; // Cache the result cache.set(cacheKey, customerId, 300); // 5 minute TTL return customerId; } catch (error) { logger.error( `Failed to get subscription data for ${orgId}/${featureId}:`, error ); return null; } } private async logStripeEvent( featureId: FeatureId, value: number, customerId: string ): Promise { // Truncate value to 11 decimal places before sending to Stripe const truncatedValue = this.truncateValue(value); const event: StripeEvent = { identifier: uuidv4(), timestamp: Math.floor(new Date().getTime() / 1000), event_name: featureId, payload: { value: truncatedValue, stripe_customer_id: customerId } }; this.addEventToMemory(event); await this.checkAndUploadEvents(); } private addEventToMemory(event: StripeEvent): void { if (!this.bucketName) { logger.warn( "S3 bucket name is not configured, skipping event storage." ); return; } this.events.push(event); } private async checkAndUploadEvents(): Promise { const now = Date.now(); const timeSinceLastUpload = now - this.lastUploadTime; // Check if at least 1 minute has passed since last upload if (timeSinceLastUpload >= 60000 && this.events.length > 0) { await this.uploadEventsToS3(); } } private async uploadEventsToS3(): Promise { if (!this.bucketName) { logger.warn( "S3 bucket name is not configured, skipping S3 upload." ); return; } if (this.events.length === 0) { return; } // Check if already uploading if (this.isUploading) { logger.debug("Already uploading events, skipping"); return; } this.isUploading = true; try { // Take a snapshot of current events and clear the array const eventsToUpload = [...this.events]; this.events = []; this.lastUploadTime = Date.now(); const fileName = this.generateEventFileName(); const fileContent = JSON.stringify(eventsToUpload, null, 2); // Upload to S3 const uploadCommand = new PutObjectCommand({ Bucket: this.bucketName, Key: fileName, Body: fileContent, ContentType: "application/json" }); await s3Client.send(uploadCommand); logger.info( `Uploaded ${fileName} to S3 with ${eventsToUpload.length} events` ); } catch (error) { logger.error("Failed to upload events to S3:", error); // Note: Events are lost if upload fails. In a production system, // you might want to add the events back to the array or implement retry logic } finally { this.isUploading = false; } } private generateEventFileName(): string { const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); const uuid = uuidv4().substring(0, 8); return `events-${timestamp}-${uuid}.json`; } public async getUsage( orgId: string, featureId: FeatureId, trx: Transaction | typeof db = db ): Promise { if (noop()) { return null; } const usageId = `${orgId}-${featureId}`; try { const [result] = await trx .select() .from(usage) .where(eq(usage.usageId, usageId)) .limit(1); if (!result) { // Lets create one if it doesn't exist using upsert to handle race conditions logger.info( `Creating new usage record for ${orgId}/${featureId}` ); const meterId = getFeatureMeterId(featureId); try { const [newUsage] = await trx .insert(usage) .values({ usageId, featureId, orgId, meterId, latestValue: 0, updatedAt: Math.floor(Date.now() / 1000) }) .onConflictDoNothing() .returning(); if (newUsage) { return newUsage; } else { // Record was created by another process, fetch it const [existingUsage] = await trx .select() .from(usage) .where(eq(usage.usageId, usageId)) .limit(1); return existingUsage || null; } } catch (insertError) { // Fallback: try to fetch existing record in case of any insert issues logger.warn( `Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`, insertError ); const [existingUsage] = await trx .select() .from(usage) .where(eq(usage.usageId, usageId)) .limit(1); return existingUsage || null; } } return result; } catch (error) { logger.error( `Failed to get usage for ${orgId}/${featureId}:`, error ); throw error; } } public async getUsageDaily( orgId: string, featureId: FeatureId ): Promise { if (noop()) { return null; } await this.updateDaily(orgId, featureId); // Ensure daily usage is updated return this.getUsage(orgId, featureId); } public async forceUpload(): Promise { if (this.events.length > 0) { // Force upload regardless of time this.lastUploadTime = 0; // Reset to force upload await this.uploadEventsToS3(); } } public async checkLimitSet( orgId: string, kickSites = false, featureId?: FeatureId, usage?: Usage, trx: Transaction | typeof db = db ): Promise { if (noop()) { return false; } // This method should check the current usage against the limits set for the organization // and kick out all of the sites on the org let hasExceededLimits = false; try { let orgLimits: Limit[] = []; if (featureId) { // Get all limits set for this organization orgLimits = await trx .select() .from(limits) .where( and( eq(limits.orgId, orgId), eq(limits.featureId, featureId) ) ); } else { // Get all limits set for this organization orgLimits = await trx .select() .from(limits) .where(eq(limits.orgId, orgId)); } if (orgLimits.length === 0) { logger.debug(`No limits set for org ${orgId}`); return false; } // Check each limit against current usage for (const limit of orgLimits) { let currentUsage: Usage | null; if (usage) { currentUsage = usage; } else { currentUsage = await this.getUsage( orgId, limit.featureId as FeatureId, trx ); } const usageValue = currentUsage?.instantaneousValue || currentUsage?.latestValue || 0; logger.debug( `Current usage for org ${orgId} on feature ${limit.featureId}: ${usageValue}` ); logger.debug( `Limit for org ${orgId} on feature ${limit.featureId}: ${limit.value}` ); if ( currentUsage && limit.value !== null && usageValue > limit.value ) { logger.debug( `Org ${orgId} has exceeded limit for ${limit.featureId}: ` + `${usageValue} > ${limit.value}` ); hasExceededLimits = true; break; // Exit early if any limit is exceeded } } // If any limits are exceeded, disconnect all sites for this organization if (hasExceededLimits && kickSites) { logger.warn( `Disconnecting all sites for org ${orgId} due to exceeded limits` ); // Get all sites for this organization const orgSites = await trx .select() .from(sites) .where(eq(sites.orgId, orgId)); // Mark all sites as offline and send termination messages const siteUpdates = orgSites.map((site) => site.siteId); if (siteUpdates.length > 0) { // Send termination messages to newt sites for (const site of orgSites) { if (site.type === "newt") { const [newt] = await trx .select() .from(newts) .where(eq(newts.siteId, site.siteId)) .limit(1); if (newt) { const payload = { type: `newt/wg/terminate`, data: { reason: "Usage limits exceeded" } }; // Don't await to prevent blocking await sendToClient(newt.newtId, payload).catch( (error: any) => { logger.error( `Failed to send termination message to newt ${newt.newtId}:`, error ); } ); } } } logger.info( `Disconnected ${orgSites.length} sites for org ${orgId} due to exceeded limits` ); } } } catch (error) { logger.error(`Error checking limits for org ${orgId}:`, error); } return hasExceededLimits; } } export const usageService = new UsageService();