From 3f5c788d481c4908095e58a4ecde9bdcea63bc9e Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 10 Feb 2026 16:11:02 -0800 Subject: [PATCH] Disable features when downgrading --- server/lib/billing/tierMatrix.ts | 20 +- .../routers/billing/featureLifecycle.ts | 297 ++++++++++++++++++ .../hooks/handleSubscriptionCreated.ts | 7 + .../hooks/handleSubscriptionDeleted.ts | 7 + .../hooks/handleSubscriptionUpdated.ts | 24 +- 5 files changed, 344 insertions(+), 11 deletions(-) create mode 100644 server/private/routers/billing/featureLifecycle.ts diff --git a/server/lib/billing/tierMatrix.ts b/server/lib/billing/tierMatrix.ts index e878f8fe..d1fe362a 100644 --- a/server/lib/billing/tierMatrix.ts +++ b/server/lib/billing/tierMatrix.ts @@ -2,19 +2,19 @@ import { Tier } from "@server/types/Tiers"; export enum TierFeature { OrgOidc = "orgOidc", - LoginPageDomain = "loginPageDomain", - DeviceApprovals = "deviceApprovals", - LoginPageBranding = "loginPageBranding", + LoginPageDomain = "loginPageDomain", // handle downgrade by removing custom domain + DeviceApprovals = "deviceApprovals", // handle downgrade by disabling device approvals + LoginPageBranding = "loginPageBranding", // handle downgrade by setting to default branding LogExport = "logExport", - AccessLogs = "accessLogs", - ActionLogs = "actionLogs", + AccessLogs = "accessLogs", // set the retention period to none on downgrade + ActionLogs = "actionLogs", // set the retention period to none on downgrade RotateCredentials = "rotateCredentials", - MaintencePage = "maintencePage", + MaintencePage = "maintencePage", // handle downgrade DevicePosture = "devicePosture", - TwoFactorEnforcement = "twoFactorEnforcement", - SessionDurationPolicies = "sessionDurationPolicies", - PasswordExpirationPolicies = "passwordExpirationPolicies", - AutoProvisioning = "autoProvisioning" + TwoFactorEnforcement = "twoFactorEnforcement", // handle downgrade by setting to optional + SessionDurationPolicies = "sessionDurationPolicies", // handle downgrade by setting to default duration + PasswordExpirationPolicies = "passwordExpirationPolicies", // handle downgrade by setting to default duration + AutoProvisioning = "autoProvisioning" // handle downgrade by disabling auto provisioning } export const tierMatrix: Record = { diff --git a/server/private/routers/billing/featureLifecycle.ts b/server/private/routers/billing/featureLifecycle.ts new file mode 100644 index 00000000..46337fed --- /dev/null +++ b/server/private/routers/billing/featureLifecycle.ts @@ -0,0 +1,297 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { SubscriptionType } from "./hooks/getSubType"; +import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; +import { Tier } from "@server/types/Tiers"; +import logger from "@server/logger"; +import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db"; +import { eq } from "drizzle-orm"; + +export async function handleTierChange( + orgId: string, + newTier: SubscriptionType | null, + previousTier?: SubscriptionType | null +): Promise { + logger.info( + `Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}` + ); + + // License subscriptions are handled separately and don't use the tier matrix + if (newTier === "license") { + logger.debug( + `New tier is license for org ${orgId}, no feature lifecycle handling needed` + ); + return; + } + + // If newTier is null, treat as free tier - disable all features + if (newTier === null) { + logger.info( + `Org ${orgId} is reverting to free tier, disabling all paid features` + ); + // Disable all features in the tier matrix + for (const [featureKey] of Object.entries(tierMatrix)) { + const feature = featureKey as TierFeature; + logger.info( + `Feature ${feature} is not available in free tier for org ${orgId}. Disabling...` + ); + await disableFeature(orgId, feature); + } + logger.info( + `Completed free tier feature lifecycle handling for org ${orgId}` + ); + return; + } + + // Get the tier (cast as Tier since we've ruled out "license" and null) + const tier = newTier as Tier; + + // Check each feature in the tier matrix + for (const [featureKey, allowedTiers] of Object.entries(tierMatrix)) { + const feature = featureKey as TierFeature; + const isFeatureAvailable = allowedTiers.includes(tier); + + if (!isFeatureAvailable) { + logger.info( + `Feature ${feature} is not available in tier ${tier} for org ${orgId}. Disabling...` + ); + await disableFeature(orgId, feature); + } else { + logger.debug( + `Feature ${feature} is available in tier ${tier} for org ${orgId}` + ); + } + } + + logger.info( + `Completed tier change feature lifecycle handling for org ${orgId}` + ); +} + +async function disableFeature( + orgId: string, + feature: TierFeature +): Promise { + try { + switch (feature) { + case TierFeature.OrgOidc: + await disableOrgOidc(orgId); + break; + + case TierFeature.LoginPageDomain: + await disableLoginPageDomain(orgId); + break; + + case TierFeature.DeviceApprovals: + await disableDeviceApprovals(orgId); + break; + + case TierFeature.LoginPageBranding: + await disableLoginPageBranding(orgId); + break; + + case TierFeature.LogExport: + await disableLogExport(orgId); + break; + + case TierFeature.AccessLogs: + await disableAccessLogs(orgId); + break; + + case TierFeature.ActionLogs: + await disableActionLogs(orgId); + break; + + case TierFeature.RotateCredentials: + await disableRotateCredentials(orgId); + break; + + case TierFeature.MaintencePage: + await disableMaintencePage(orgId); + break; + + case TierFeature.DevicePosture: + await disableDevicePosture(orgId); + break; + + case TierFeature.TwoFactorEnforcement: + await disableTwoFactorEnforcement(orgId); + break; + + case TierFeature.SessionDurationPolicies: + await disableSessionDurationPolicies(orgId); + break; + + case TierFeature.PasswordExpirationPolicies: + await disablePasswordExpirationPolicies(orgId); + break; + + case TierFeature.AutoProvisioning: + await disableAutoProvisioning(orgId); + break; + + default: + logger.warn( + `Unknown feature ${feature} for org ${orgId}, skipping` + ); + } + + logger.info( + `Successfully disabled feature ${feature} for org ${orgId}` + ); + } catch (error) { + logger.error( + `Error disabling feature ${feature} for org ${orgId}:`, + error + ); + throw error; + } +} + +async function disableOrgOidc(orgId: string): Promise {} + +async function disableDeviceApprovals(orgId: string): Promise { + await db + .update(roles) + .set({ requireDeviceApproval: false }) + .where(eq(roles.orgId, orgId)); + + logger.info(`Disabled device approvals on all roles for org ${orgId}`); +} + +async function disableLoginPageBranding(orgId: string): Promise { + const [existingBranding] = await db + .select() + .from(loginPageBrandingOrg) + .where(eq(loginPageBrandingOrg.orgId, orgId)); + + if (existingBranding) { + await db + .delete(loginPageBranding) + .where( + eq( + loginPageBranding.loginPageBrandingId, + existingBranding.loginPageBrandingId + ) + ); + + logger.info(`Disabled login page branding for org ${orgId}`); + } +} + +async function disableLoginPageDomain(orgId: string): Promise { + const [existingLoginPage] = await db + .select() + .from(loginPageOrg) + .where(eq(loginPageOrg.orgId, orgId)) + .innerJoin( + loginPage, + eq(loginPage.loginPageId, loginPageOrg.loginPageId) + ); + + if (existingLoginPage) { + await db + .delete(loginPageOrg) + .where(eq(loginPageOrg.orgId, orgId)); + + await db + .delete(loginPage) + .where( + eq( + loginPage.loginPageId, + existingLoginPage.loginPageOrg.loginPageId + ) + ); + + logger.info(`Disabled login page domain for org ${orgId}`); + } +} + +async function disableLogExport(orgId: string): Promise {} + +async function disableAccessLogs(orgId: string): Promise { + await db + .update(orgs) + .set({ settingsLogRetentionDaysAccess: 0 }) + .where(eq(orgs.orgId, orgId)); + + logger.info(`Disabled access logs for org ${orgId}`); +} + +async function disableActionLogs(orgId: string): Promise { + await db + .update(orgs) + .set({ settingsLogRetentionDaysAction: 0 }) + .where(eq(orgs.orgId, orgId)); + + logger.info(`Disabled action logs for org ${orgId}`); +} + +async function disableRotateCredentials(orgId: string): Promise {} + +async function disableMaintencePage(orgId: string): Promise { + await db + .update(resources) + .set({ + maintenanceModeEnabled: false + }) + .where(eq(resources.orgId, orgId)); + + logger.info(`Disabled maintenance page on all resources for org ${orgId}`); +} + +async function disableDevicePosture(orgId: string): Promise {} + +async function disableTwoFactorEnforcement(orgId: string): Promise { + await db + .update(orgs) + .set({ requireTwoFactor: false }) + .where(eq(orgs.orgId, orgId)); + + logger.info(`Disabled two-factor enforcement for org ${orgId}`); +} + +async function disableSessionDurationPolicies(orgId: string): Promise { + await db + .update(orgs) + .set({ maxSessionLengthHours: null }) + .where(eq(orgs.orgId, orgId)); + + logger.info(`Disabled session duration policies for org ${orgId}`); +} + +async function disablePasswordExpirationPolicies(orgId: string): Promise { + await db + .update(orgs) + .set({ passwordExpiryDays: null }) + .where(eq(orgs.orgId, orgId)); + + logger.info(`Disabled password expiration policies for org ${orgId}`); +} + +async function disableAutoProvisioning(orgId: string): Promise { + // Get all IDP IDs for this org through the idpOrg join table + const orgIdps = await db + .select({ idpId: idpOrg.idpId }) + .from(idpOrg) + .where(eq(idpOrg.orgId, orgId)); + + // Update autoProvision to false for all IDPs in this org + for (const { idpId } of orgIdps) { + await db + .update(idp) + .set({ autoProvision: false }) + .where(eq(idp.idpId, idpId)); + } +} diff --git a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts index 773ffbae..1152f223 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts @@ -32,6 +32,7 @@ import { sendEmail } from "@server/emails"; import EnterpriseEditionKeyGenerated from "@server/emails/templates/EnterpriseEditionKeyGenerated"; import config from "@server/lib/config"; import { getFeatureIdByPriceId } from "@server/lib/billing/features"; +import { handleTierChange } from "../featureLifecycle"; export async function handleSubscriptionCreated( subscription: Stripe.Subscription @@ -150,6 +151,12 @@ export async function handleSubscriptionCreated( type ); + // Handle initial tier setup - disable features not available in this tier + logger.info( + `Setting up initial tier features for org ${customer.orgId} with type ${type}` + ); + await handleTierChange(customer.orgId, type); + const [orgUserRes] = await db .select() .from(userOrgs) diff --git a/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts b/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts index 45c70eed..d92741be 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts @@ -27,6 +27,7 @@ import { AudienceIds, moveEmailToAudience } from "#private/lib/resend"; import { getSubType } from "./getSubType"; import stripe from "#private/lib/stripe"; import privateConfig from "#private/lib/config"; +import { handleTierChange } from "../featureLifecycle"; export async function handleSubscriptionDeleted( subscription: Stripe.Subscription @@ -87,6 +88,12 @@ export async function handleSubscriptionDeleted( type ); + // Handle feature lifecycle for cancellation - disable all tier-specific features + logger.info( + `Disabling tier-specific features for org ${customer.orgId} due to subscription deletion` + ); + await handleTierChange(customer.orgId, null, type); + const [orgUserRes] = await db .select() .from(userOrgs) diff --git a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts index 9b36b55e..c431f386 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts @@ -26,8 +26,9 @@ import logger from "@server/logger"; import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features"; import stripe from "#private/lib/stripe"; import { handleSubscriptionLifesycle } from "../subscriptionLifecycle"; -import { getSubType } from "./getSubType"; +import { getSubType, SubscriptionType } from "./getSubType"; import privateConfig from "#private/lib/config"; +import { handleTierChange } from "../featureLifecycle"; export async function handleSubscriptionUpdated( subscription: Stripe.Subscription, @@ -65,6 +66,7 @@ export async function handleSubscriptionUpdated( .limit(1); const type = getSubType(fullSubscription); + const previousType = existingSubscription.type as SubscriptionType | null; await db .update(subscriptions) @@ -79,6 +81,14 @@ export async function handleSubscriptionUpdated( }) .where(eq(subscriptions.subscriptionId, subscription.id)); + // Handle tier change if the subscription type changed + if (type && type !== previousType) { + logger.info( + `Tier change detected for org ${customer.orgId}: ${previousType} -> ${type}` + ); + await handleTierChange(customer.orgId, type, previousType ?? undefined); + } + // Upsert subscription items if (Array.isArray(fullSubscription.items?.data)) { // First, get existing items to preserve featureId when there's no match @@ -268,6 +278,18 @@ export async function handleSubscriptionUpdated( subscription.status, type ); + + // Handle feature lifecycle when subscription is canceled or becomes unpaid + if ( + subscription.status === "canceled" || + subscription.status === "unpaid" || + subscription.status === "incomplete_expired" + ) { + logger.info( + `Subscription ${subscription.id} for org ${customer.orgId} is ${subscription.status}, disabling paid features` + ); + await handleTierChange(customer.orgId, null, previousType ?? undefined); + } } else if (type === "license") { if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") { try {