From 8e1905a695add77d18ca8a2e16cd4c40437bbca9 Mon Sep 17 00:00:00 2001 From: Mustafa <104644957+Blacks-Army@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:19:32 +0200 Subject: [PATCH 1/8] Exclude local/private/CGNAT IPs from COUNTRY=ALL and ASN=ALL/AS0 geo-blocking rules --- server/routers/badger/verifySession.ts | 57 ++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/server/routers/badger/verifySession.ts b/server/routers/badger/verifySession.ts index e2e5f6766..d3c110728 100644 --- a/server/routers/badger/verifySession.ts +++ b/server/routers/badger/verifySession.ts @@ -1003,7 +1003,11 @@ async function checkRules( isIpInCidr(clientIp, rule.value) ) { return rule.action as any; - } else if (clientIp && rule.match == "IP" && clientIp == rule.value) { + } else if ( + clientIp && + rule.match == "IP" && + clientIp == rule.value + ) { return rule.action as any; } else if ( path && @@ -1013,16 +1017,35 @@ async function checkRules( return rule.action as any; } else if ( clientIp && - rule.match == "COUNTRY" && - (await isIpInGeoIP(ipCC, rule.value)) + rule.match == "COUNTRY" ) { - return rule.action as any; + // COUNTRY=ALL should not affect local/private/CGNAT addresses. + if ( + rule.value.toUpperCase() === "ALL" && + isLocalOrCarrierGradeNatIp(clientIp) + ) { + continue; + } + + if (await isIpInGeoIP(ipCC, rule.value)) { + return rule.action as any; + } } else if ( clientIp && - rule.match == "ASN" && - (await isIpInAsn(ipAsn, rule.value)) + rule.match == "ASN" ) { - return rule.action as any; + // ASN=ALL/AS0 should not affect local/private/CGNAT addresses. + if ( + (rule.value.toUpperCase() === "ALL" || + rule.value.toUpperCase() === "AS0") && + isLocalOrCarrierGradeNatIp(clientIp) + ) { + continue; + } + + if (await isIpInAsn(ipAsn, rule.value)) { + return rule.action as any; + } } else if ( clientIp && rule.match == "REGION" && @@ -1184,6 +1207,26 @@ async function isIpInGeoIP( return ipCountryCode?.toUpperCase() === checkCountryCode.toUpperCase(); } +function isLocalOrCarrierGradeNatIp(ip: string): boolean { + const localAndCgnatCidrs = [ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10" + ]; + + try { + return localAndCgnatCidrs.some((cidr) => isIpInCidr(ip, cidr)); + } catch { + return false; + } +} + async function isIpInAsn( ipAsn: number | undefined, checkAsn: string From bcd164219f3a3cd226261ea15c15b89462dac039 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 May 2026 11:48:30 -0700 Subject: [PATCH 2/8] Try to speed up --- server/routers/idp/validateOidcCallback.ts | 59 ++++++++++------------ 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/server/routers/idp/validateOidcCallback.ts b/server/routers/idp/validateOidcCallback.ts index d26a8fbe3..fc8e9b3da 100644 --- a/server/routers/idp/validateOidcCallback.ts +++ b/server/routers/idp/validateOidcCallback.ts @@ -38,10 +38,7 @@ import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsFor import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; -import { - assignUserToOrg, - removeUserFromOrg -} from "@server/lib/userOrg"; +import { assignUserToOrg, removeUserFromOrg } from "@server/lib/userOrg"; import { unwrapRoleMapping } from "@app/lib/idpRoleMapping"; const ensureTrailingSlash = (url: string): string => { @@ -336,23 +333,23 @@ export async function validateOidcCallback( .innerJoin(orgs, eq(orgs.orgId, idpOrg.orgId)); allOrgs = idpOrgs.map((o) => o.orgs); - for (const org of allOrgs) { - const subscribed = await isSubscribed( - org.orgId, - tierMatrix.autoProvisioning - ); - if (!subscribed) { - // filter out the org - allOrgs = allOrgs.filter((o) => o.orgId !== org.orgId); + // for (const org of allOrgs) { + // const subscribed = await isSubscribed( + // org.orgId, + // tierMatrix.autoProvisioning + // ); + // if (!subscribed) { + // // filter out the org + // allOrgs = allOrgs.filter((o) => o.orgId !== org.orgId); - // return next( - // createHttpError( - // HttpCode.FORBIDDEN, - // "This organization's current plan does not support this feature." - // ) - // ); - } - } + // // return next( + // // createHttpError( + // // HttpCode.FORBIDDEN, + // // "This organization's current plan does not support this feature." + // // ) + // // ); + // } + // } } else { allOrgs = await db.select().from(orgs); } @@ -396,16 +393,14 @@ export async function validateOidcCallback( idpOrgRes?.roleMapping || defaultRoleMapping; if (roleMapping) { logger.debug("Role Mapping", { roleMapping }); - const roleMappingJmes = unwrapRoleMapping( - roleMapping - ).evaluationExpression; + const roleMappingJmes = + unwrapRoleMapping(roleMapping).evaluationExpression; const roleMappingResult = jmespath.search( claims, roleMappingJmes ); - const roleNames = normalizeRoleMappingResult( - roleMappingResult - ); + const roleNames = + normalizeRoleMappingResult(roleMappingResult); const supportsMultiRole = await isLicensedOrSubscribed( org.orgId, @@ -515,7 +510,7 @@ export async function validateOidcCallback( } } - const orgUserCounts: { orgId: string; userCount: number }[] = []; + const orgUserCounts: { orgId: string; userCount: number }[] = []; // sync the user with the orgs and roles await db.transaction(async (trx) => { @@ -628,7 +623,7 @@ export async function validateOidcCallback( { orgId: org.orgId, userId: userId!, - autoProvisioned: true, + autoProvisioned: true }, org.roleIds, trx @@ -758,9 +753,7 @@ function hydrateOrgMapping( return orgMapping.split("{{orgId}}").join(orgId); } -function normalizeRoleMappingResult( - result: unknown -): string[] { +function normalizeRoleMappingResult(result: unknown): string[] { if (typeof result === "string") { const role = result.trim(); return role ? [role] : []; @@ -770,7 +763,9 @@ function normalizeRoleMappingResult( return [ ...new Set( result - .filter((value): value is string => typeof value === "string") + .filter( + (value): value is string => typeof value === "string" + ) .map((value) => value.trim()) .filter(Boolean) ) From 81b8a8a9e3504c464ff91a0344e9962d5c06614d Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 May 2026 12:29:34 -0700 Subject: [PATCH 3/8] Fix ns cert generation --- .../routers/certificates/createCertificate.ts | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/server/private/routers/certificates/createCertificate.ts b/server/private/routers/certificates/createCertificate.ts index 048b92352..2f2e50fdc 100644 --- a/server/private/routers/certificates/createCertificate.ts +++ b/server/private/routers/certificates/createCertificate.ts @@ -90,14 +90,13 @@ export async function createCertificate( domainToWrite = `*.${domainToWrite}`; } } else if (domainRecord.type == "ns") { - // first if we have a * in the domain for this case we dont want to include it because it will mess with the cert generator so remove it - if (domain.startsWith("*.")) { - domain = domain.slice(2); - } - - const parts = domain.split("."); - if (parts.length > 2) { - domainToWrite = parts.slice(1).join("."); + if (domain == domainRecord.baseDomain) { + domainToWrite = domainRecord.baseDomain; + } else { + const parts = domain.split("."); + if (parts.length > 2) { + domainToWrite = parts.slice(1).join("."); + } } } From eb515a8f7fd3ca348b73343c1b460744a5f335a7 Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Sun, 3 May 2026 14:16:29 -0700 Subject: [PATCH 4/8] consolidate orgidps in import list --- src/components/OrgIdpTable.tsx | 104 ++++++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 15 deletions(-) diff --git a/src/components/OrgIdpTable.tsx b/src/components/OrgIdpTable.tsx index bdbaafa27..c0199c6d3 100644 --- a/src/components/OrgIdpTable.tsx +++ b/src/components/OrgIdpTable.tsx @@ -25,7 +25,6 @@ import { import { ArrowRight, ArrowUpDown, - KeyRound, MoreHorizontal } from "lucide-react"; import { useMemo, useState } from "react"; @@ -50,6 +49,7 @@ import { useQuery } from "@tanstack/react-query"; import { useDebounce } from "use-debounce"; import type { ListUserAdminOrgIdpsResponse } from "@server/routers/orgIdp/types"; import { cn } from "@app/lib/cn"; +import { Badge } from "@app/components/ui/badge"; import { usePaidStatus } from "@app/hooks/usePaidStatus"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { isIdpGlobalModeBannerVisible } from "@app/components/IdpGlobalModeBanner"; @@ -63,6 +63,61 @@ export type IdpRow = { type AdminIdpRow = ListUserAdminOrgIdpsResponse["idps"][number]; +type ImportSourceOrg = { orgId: string; orgName: string }; + +type GroupedImportableIdp = { + idpId: number; + name: string; + type: string; + variant: string; + tags: string | null; + sources: ImportSourceOrg[]; +}; + +function adminRowForImport( + group: GroupedImportableIdp, + source: ImportSourceOrg +): AdminIdpRow { + return { + idpId: group.idpId, + orgId: source.orgId, + orgName: source.orgName, + name: group.name, + type: group.type, + variant: group.variant, + tags: group.tags + }; +} + +function groupImportableIdps(rows: AdminIdpRow[]): GroupedImportableIdp[] { + const map = new Map(); + for (const row of rows) { + let g = map.get(row.idpId); + if (!g) { + g = { + idpId: row.idpId, + name: row.name, + type: row.type, + variant: row.variant, + tags: row.tags, + sources: [] + }; + map.set(row.idpId, g); + } + if (!g.sources.some((s) => s.orgId === row.orgId)) { + g.sources.push({ orgId: row.orgId, orgName: row.orgName }); + } + } + return Array.from(map.values()) + .map((item) => ({ + ...item, + sources: [...item.sources].sort((a, b) => + a.orgName.localeCompare(b.orgName) + ) + })) + .sort((a, b) => b.name.localeCompare(a.name)); +} + function IdpImportRowIcon({ type, variant @@ -114,16 +169,22 @@ export default function IdpTable({ idps, orgId }: Props) { ); }, [adminIdpsRaw, orgId, idps]); - const shownImportIdps = useMemo(() => { + const importableGrouped = useMemo( + () => groupImportableIdps(importableIdps), + [importableIdps] + ); + + const shownImportGrouped = useMemo(() => { const q = debouncedImportSearch.trim().toLowerCase(); if (!q) { - return importableIdps; + return importableGrouped; } - return importableIdps.filter((row) => { - const hay = `${row.orgName} ${row.name}`.toLowerCase(); + return importableGrouped.filter((group) => { + const hay = + `${group.name} ${group.sources.map((s) => s.orgName).join(" ")}`.toLowerCase(); return hay.includes(q); }); - }, [importableIdps, debouncedImportSearch]); + }, [importableGrouped, debouncedImportSearch]); const deleteIdp = async (idpId: number) => { try { @@ -364,31 +425,44 @@ export default function IdpTable({ idps, orgId }: Props) { {t("idpImportEmpty")} - {shownImportIdps.map((row) => ( + {shownImportGrouped.map((group) => ( s.orgName).join(" ")}`} disabled={!canImportOrgOidcIdp} onSelect={() => { if (!canImportOrgOidcIdp) { return; } - void importIdp(row); + void importIdp( + adminRowForImport( + group, + group.sources[0] + ) + ); }} >
- {row.orgName} + {group.name}
-
- {row.name} +
+ {group.sources.map((src) => ( + + {src.orgName} + + ))}
From 1a926a7127c95d89b343f46b8cd609caeefeeadb Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 May 2026 14:30:34 -0700 Subject: [PATCH 5/8] Handle trial limit lifecycle --- .../billing/hooks/handleCustomerCreated.ts | 8 ++++ .../routers/billing/subscriptionLifecycle.ts | 2 +- .../routers/org/sendTrialNotification.ts | 42 ++++++++++++++----- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/server/private/routers/billing/hooks/handleCustomerCreated.ts b/server/private/routers/billing/hooks/handleCustomerCreated.ts index 66ad3a4fa..79dbcea35 100644 --- a/server/private/routers/billing/hooks/handleCustomerCreated.ts +++ b/server/private/routers/billing/hooks/handleCustomerCreated.ts @@ -16,6 +16,7 @@ import { customers, db, subscriptions } from "@server/db"; import { eq } from "drizzle-orm"; import logger from "@server/logger"; import { generateId } from "@server/auth/sessions/app"; +import { handleSubscriptionLifesycle } from "../subscriptionLifecycle"; export async function handleCustomerCreated( customer: Stripe.Customer @@ -62,6 +63,13 @@ export async function handleCustomerCreated( expiresAt: trialExpiresAt, trial: true }); + + // update to the business limits for the trial + await handleSubscriptionLifesycle( + customer.metadata.orgId, + "active", + "tier3" + ); }); logger.info(`Customer with ID ${customer.id} created successfully.`); diff --git a/server/private/routers/billing/subscriptionLifecycle.ts b/server/private/routers/billing/subscriptionLifecycle.ts index 76fb6ec8e..b993a4e1a 100644 --- a/server/private/routers/billing/subscriptionLifecycle.ts +++ b/server/private/routers/billing/subscriptionLifecycle.ts @@ -44,7 +44,7 @@ function getLimitSetForSubscriptionType( export async function handleSubscriptionLifesycle( orgId: string, status: string, - subType: SubscriptionType | null + subType: SubscriptionType | null = null ) { switch (status) { case "active": diff --git a/server/private/routers/org/sendTrialNotification.ts b/server/private/routers/org/sendTrialNotification.ts index c3b7f6518..233010064 100644 --- a/server/private/routers/org/sendTrialNotification.ts +++ b/server/private/routers/org/sendTrialNotification.ts @@ -24,13 +24,18 @@ import { fromError } from "zod-validation-error"; import { sendEmail } from "@server/emails"; import NotifyTrialExpiring from "@server/emails/templates/NotifyTrialExpiring"; import config from "@server/lib/config"; +import { handleSubscriptionLifesycle } from "../billing/subscriptionLifecycle"; const sendTrialNotificationParamsSchema = z.object({ orgId: z.string() }); const sendTrialNotificationBodySchema = z.object({ - notificationType: z.enum(["trial_ending_5d", "trial_ending_24h", "trial_ended"]), + notificationType: z.enum([ + "trial_ending_5d", + "trial_ending_24h", + "trial_ended" + ]), orgName: z.string(), trialEndsAt: z.number(), billingLink: z.string().optional() @@ -69,9 +74,7 @@ async function getOrgAdmins(orgId: string) { ) ); - const byUserId = new Map( - admins.map((a) => [a.userId, a]) - ); + const byUserId = new Map(admins.map((a) => [a.userId, a])); const orgAdmins = Array.from(byUserId.values()).filter( (admin) => admin.email && admin.email.length > 0 ); @@ -108,8 +111,12 @@ export async function sendTrialNotification( } const { orgId } = parsedParams.data; - const { notificationType, orgName, trialEndsAt, billingLink: bodyBillingLink } = - parsedBody.data; + const { + notificationType, + orgName, + trialEndsAt, + billingLink: bodyBillingLink + } = parsedBody.data; // Verify organization exists const org = await db @@ -146,13 +153,17 @@ export async function sendTrialNotification( bodyBillingLink ?? `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing`; - const trialEndsAtFormatted = new Date(trialEndsAt * 1000).toLocaleDateString( - "en-US", - { year: "numeric", month: "long", day: "numeric" } - ); + const trialEndsAtFormatted = new Date( + trialEndsAt * 1000 + ).toLocaleDateString("en-US", { + year: "numeric", + month: "long", + day: "numeric" + }); let daysRemaining: number | null; let subject: string; + let resetLimits = false; if (notificationType === "trial_ending_5d") { daysRemaining = 5; @@ -163,6 +174,7 @@ export async function sendTrialNotification( } else { daysRemaining = null; subject = "Your trial has ended"; + resetLimits = true; } let emailsSent = 0; @@ -201,6 +213,14 @@ export async function sendTrialNotification( } } + if (resetLimits) { + // this will only fire if they have not upgraded yet because when upgrading we delete the trial + await handleSubscriptionLifesycle(orgId, "cancled"); + logger.debug( + `Trial ended for org ${orgId}, limits reset to free tier` + ); + } + return response(res, { data: { success: true, @@ -221,4 +241,4 @@ export async function sendTrialNotification( ) ); } -} \ No newline at end of file +} From c33e295ce7b6d18bf8067f6c73d44fec445c8077 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 May 2026 14:42:43 -0700 Subject: [PATCH 6/8] Add a banner showing that you are on a trial --- messages/en-US.json | 3 ++ .../settings/(private)/billing/page.tsx | 15 ++++++++ src/components/DismissableBanner.tsx | 22 ++++++----- src/components/TrialBillingBanner.tsx | 38 +++++++++++++++++++ 4 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 src/components/TrialBillingBanner.tsx diff --git a/messages/en-US.json b/messages/en-US.json index a7b045480..09a2dc180 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -25,6 +25,9 @@ "subscriptionViolationMessage": "You're beyond your limits for your current plan. Correct the problem by removing sites, users, or other resources to stay within your plan.", "trialBannerMessage": "Your trial expires in {countdown}. Upgrade to keep access.", "trialBannerExpired": "Your trial has expired. Upgrade now to restore access.", + "billingTrialBannerTitle": "Free Trial Active", + "billingTrialBannerDescription": "You're currently on a free trial on the business tier. When the trial ends, your account will automatically revert to the Basic tier features and limits. Upgrade anytime to keep access to your current plan's features.", + "billingTrialBannerUpgrade": "Upgrade Now", "trialActive": "Free Trial Active", "trialExpired": "Trial Expired", "trialHasEnded": "Your trial has ended.", diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index 778062e8e..068f6ed62 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -55,6 +55,7 @@ import { tier3LimitSet } from "@server/lib/billing/limitSet"; import { FeatureId } from "@server/lib/billing/features"; +import TrialBillingBanner from "@app/components/TrialBillingBanner"; // Plan tier definitions matching the mockup type PlanId = "basic" | "home" | "team" | "business" | "enterprise"; @@ -805,6 +806,20 @@ export default function BillingPage() { return ( + {/* Trial Banner */} + {isTrial && ( + { + const currentPlan = planOptions.find( + (p) => p.id === currentPlanId + ); + if (currentPlan?.tierType) { + handleStartSubscription(currentPlan.tierType); + } + }} + /> + )} + {/* Subscription Status Alert */} {isProblematicState && statusMessage && ( diff --git a/src/components/DismissableBanner.tsx b/src/components/DismissableBanner.tsx index 289c4ec25..5527f1037 100644 --- a/src/components/DismissableBanner.tsx +++ b/src/components/DismissableBanner.tsx @@ -13,6 +13,7 @@ type DismissableBannerProps = { titleIcon: ReactNode; description: string; children?: ReactNode; + dismissable?: boolean; }; export const DismissableBanner = ({ @@ -21,7 +22,8 @@ export const DismissableBanner = ({ title, titleIcon, description, - children + children, + dismissable = true }: DismissableBannerProps) => { const [isDismissed, setIsDismissed] = useState(true); const t = useTranslations(); @@ -66,19 +68,21 @@ export const DismissableBanner = ({ ); }; - if (isDismissed) { + if (dismissable && isDismissed) { return null; } return ( - + {dismissable && ( + + )}
diff --git a/src/components/TrialBillingBanner.tsx b/src/components/TrialBillingBanner.tsx new file mode 100644 index 000000000..52fcb4873 --- /dev/null +++ b/src/components/TrialBillingBanner.tsx @@ -0,0 +1,38 @@ +"use client"; + +import React from "react"; +import { Button } from "@app/components/ui/button"; +import { ClockIcon, ArrowRight } from "lucide-react"; +import { useTranslations } from "next-intl"; +import DismissableBanner from "./DismissableBanner"; + +type TrialBillingBannerProps = { + onUpgrade: () => void; +}; + +export const TrialBillingBanner = ({ onUpgrade }: TrialBillingBannerProps) => { + const t = useTranslations(); + + return ( + } + description={t("billingTrialBannerDescription")} + dismissable={false} + > + + + ); +}; + +export default TrialBillingBanner; From 584be4dbd2642a93d615112be2a2f7e8ca44af42 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 May 2026 14:45:42 -0700 Subject: [PATCH 7/8] Add badge --- messages/en-US.json | 1 + .../[orgId]/settings/(private)/billing/page.tsx | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/messages/en-US.json b/messages/en-US.json index 09a2dc180..ee4ef143d 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -28,6 +28,7 @@ "billingTrialBannerTitle": "Free Trial Active", "billingTrialBannerDescription": "You're currently on a free trial on the business tier. When the trial ends, your account will automatically revert to the Basic tier features and limits. Upgrade anytime to keep access to your current plan's features.", "billingTrialBannerUpgrade": "Upgrade Now", + "billingTrialBadge": "Free Trial", "trialActive": "Free Trial Active", "trialExpired": "Trial Expired", "trialHasEnded": "Your trial has ended.", diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index 068f6ed62..f9f9bd77f 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -35,6 +35,7 @@ import { } from "@app/components/Credenza"; import { cn } from "@app/lib/cn"; import { CreditCard, ExternalLink, Check, AlertTriangle } from "lucide-react"; +import { Badge } from "@app/components/ui/badge"; import { Alert, AlertTitle, AlertDescription } from "@app/components/ui/alert"; import { Tooltip, @@ -874,8 +875,19 @@ export default function BillingPage() { )} >
-
- {plan.name} +
+ + {plan.name} + + {isCurrentPlan && isTrial && ( + + {t("billingTrialBadge") || + "Free Trial"} + + )}
From 1cc0e9b689c3a719e8b20ffef86c2d67c7ad4654 Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Sun, 3 May 2026 14:46:38 -0700 Subject: [PATCH 8/8] consolidate org idps in login form --- src/components/SmartLoginForm.tsx | 4 +- src/components/SmartLoginOrgSelector.tsx | 297 +++++++++++++++++++++++ 2 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 src/components/SmartLoginOrgSelector.tsx diff --git a/src/components/SmartLoginForm.tsx b/src/components/SmartLoginForm.tsx index 164311b7b..7d695127f 100644 --- a/src/components/SmartLoginForm.tsx +++ b/src/components/SmartLoginForm.tsx @@ -22,7 +22,7 @@ import { useEnvContext } from "@app/hooks/useEnvContext"; import { LookupUserResponse } from "@server/routers/auth/lookupUser"; import { useTranslations } from "next-intl"; import LoginPasswordForm from "@app/components/LoginPasswordForm"; -import LoginOrgSelector from "@app/components/LoginOrgSelector"; +import SmartLoginOrgSelector from "@app/components/SmartLoginOrgSelector"; import UserProfileCard from "@app/components/UserProfileCard"; import SecurityKeyAuthButton from "@app/components/SecurityKeyAuthButton"; import { Separator } from "@app/components/ui/separator"; @@ -206,7 +206,7 @@ export default function SmartLoginForm({ if (viewState.type === "orgSelector") { return (
- void; +}; + +type OrgBucket = { + orgId: string; + orgName: string; + idps: Array<{ + idpId: number; + name: string; + variant: string | null; + }>; + hasInternalAuth: boolean; +}; + +type GroupedLoginIdp = { + idpId: number; + name: string; + variant: string | null; + orgs: { orgId: string; orgName: string }[]; +}; + +function buildOrgMap(lookupResult: LookupUserResponse) { + const orgMap = new Map(); + + for (const account of lookupResult.accounts) { + for (const org of account.orgs) { + if (!orgMap.has(org.orgId)) { + orgMap.set(org.orgId, { + orgId: org.orgId, + orgName: org.orgName, + idps: org.idps, + hasInternalAuth: org.hasInternalAuth + }); + } else { + const existing = orgMap.get(org.orgId)!; + const existingIdpIds = new Set( + existing.idps.map((i) => i.idpId) + ); + for (const idp of org.idps) { + if (!existingIdpIds.has(idp.idpId)) { + existing.idps.push(idp); + } + } + if (org.hasInternalAuth) { + existing.hasInternalAuth = true; + } + } + } + } + + return Array.from(orgMap.values()); +} + +function groupIdpsAcrossOrgs(orgs: OrgBucket[]): GroupedLoginIdp[] { + const map = new Map(); + + for (const org of orgs) { + for (const idp of org.idps) { + let g = map.get(idp.idpId); + if (!g) { + g = { + idpId: idp.idpId, + name: idp.name, + variant: idp.variant, + orgs: [] + }; + map.set(idp.idpId, g); + } + if (!g.orgs.some((o) => o.orgId === org.orgId)) { + g.orgs.push({ orgId: org.orgId, orgName: org.orgName }); + } + } + } + + return Array.from(map.values()) + .map((g) => ({ + ...g, + orgs: [...g.orgs].sort((a, b) => a.orgName.localeCompare(b.orgName)) + })) + .sort((a, b) => b.name.localeCompare(a.name)); +} + +export default function SmartLoginOrgSelector({ + identifier, + lookupResult, + redirect, + forceLogin, + onUseDifferentAccount +}: SmartLoginOrgSelectorProps) { + const t = useTranslations(); + const [showPasswordForm, setShowPasswordForm] = useState(false); + const [error, setError] = useState(null); + const [pendingIdpId, setPendingIdpId] = useState(null); + const params = useSearchParams(); + const router = useRouter(); + + const orgs = buildOrgMap(lookupResult); + const groupedIdps = groupIdpsAcrossOrgs(orgs); + + const hasInternalAccount = lookupResult.accounts.some( + (acc) => acc.hasInternalAuth + ); + + function goToApp() { + const url = window.location.href.split("?")[0]; + router.push(url); + } + + useEffect(() => { + if (params.get("gotoapp")) { + goToApp(); + } + }, []); + + async function loginWithIdp(idpId: number, orgId: string) { + setPendingIdpId(idpId); + setError(null); + + let redirectToUrl: string | undefined; + try { + const safeRedirect = cleanRedirect(redirect || "/"); + const response = await generateOidcUrlProxy( + idpId, + safeRedirect, + orgId, + forceLogin + ); + + if (response.error) { + setError(response.message); + setPendingIdpId(null); + return; + } + + const data = response.data; + if (data?.redirectUrl) { + redirectToUrl = data.redirectUrl; + } + } catch { + setError( + t("loginError", { + defaultValue: + "An unexpected error occurred. Please try again." + }) + ); + } + + if (redirectToUrl) { + redirectTo(redirectToUrl); + } else { + setPendingIdpId(null); + } + } + + if (showPasswordForm) { + return ( +
+ + +
+ ); + } + + return ( +
+ + + {hasInternalAccount && ( +
+ +
+ )} + + {groupedIdps.length > 0 ? ( +
+ {error && ( + + {error} + + )} + +
+
+ +
+
+ + {t("idpContinue")} + +
+
+ +
+ {params.get("gotoapp") ? ( + + ) : ( + groupedIdps.map((group) => { + const effectiveType = + group.variant || group.name.toLowerCase(); + const sourceOrgId = group.orgs[0].orgId; + + return ( + + ); + }) + )} +
+
+ ) : null} +
+ ); +}