From c5a73dc87e675a941b2d66070dab00df2e6e84db Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 15 Oct 2025 12:12:59 -0700 Subject: [PATCH] Try to handle the certs better --- server/private/lib/certificates.ts | 187 +++++++++++++----- .../private/lib/traefik/getTraefikConfig.ts | 174 ++++++++-------- 2 files changed, 218 insertions(+), 143 deletions(-) diff --git a/server/private/lib/certificates.ts b/server/private/lib/certificates.ts index 93eb5603..2f74a0f1 100644 --- a/server/private/lib/certificates.ts +++ b/server/private/lib/certificates.ts @@ -13,72 +13,163 @@ import config from "./config"; import { certificates, db } from "@server/db"; -import { and, eq, isNotNull } from "drizzle-orm"; +import { and, eq, isNotNull, or, inArray, sql } from "drizzle-orm"; import { decryptData } from "@server/lib/encryption"; import * as fs from "fs"; +import NodeCache from "node-cache"; + +const encryptionKeyPath = + config.getRawPrivateConfig().server.encryption_key_path; + +if (!fs.existsSync(encryptionKeyPath)) { + throw new Error( + "Encryption key file not found. Please generate one first." + ); +} + +const encryptionKeyHex = fs.readFileSync(encryptionKeyPath, "utf8").trim(); +const encryptionKey = Buffer.from(encryptionKeyHex, "hex"); + +// Define the return type for clarity and type safety +export type CertificateResult = { + id: number; + domain: string; + queriedDomain: string; // The domain that was originally requested (may differ for wildcards) + wildcard: boolean | null; + certFile: string | null; + keyFile: string | null; + expiresAt: number | null; + updatedAt?: number | null; +}; + +// --- In-Memory Cache Implementation --- +const certificateCache = new NodeCache({ stdTTL: 180 }); // Cache for 3 minutes (180 seconds) export async function getValidCertificatesForDomains( - domains: Set -): Promise< - Array<{ - id: number; - domain: string; - wildcard: boolean | null; - certFile: string | null; - keyFile: string | null; - expiresAt: number | null; - updatedAt?: number | null; - }> -> { - if (domains.size === 0) { - return []; + domains: Set, + useCache: boolean = true +): Promise> { + const finalResults: CertificateResult[] = []; + const domainsToQuery = new Set(); + + // 1. Check cache first if enabled + if (useCache) { + for (const domain of domains) { + const cachedCert = certificateCache.get(domain); + if (cachedCert) { + finalResults.push(cachedCert); // Valid cache hit + } else { + domainsToQuery.add(domain); // Cache miss or expired + } + } + } else { + // If caching is disabled, add all domains to the query set + domains.forEach((d) => domainsToQuery.add(d)); } - const domainArray = Array.from(domains); + // 2. If all domains were resolved from the cache, return early + if (domainsToQuery.size === 0) { + return decryptFinalResults(finalResults); + } - // TODO: add more foreign keys to make this query more efficient - we dont need to keep getting every certificate - const validCerts = await db - .select({ - id: certificates.certId, - domain: certificates.domain, - certFile: certificates.certFile, - keyFile: certificates.keyFile, - expiresAt: certificates.expiresAt, - updatedAt: certificates.updatedAt, - wildcard: certificates.wildcard - }) + // 3. Prepare domains for the database query + const domainsToQueryArray = Array.from(domainsToQuery); + const parentDomainsToQuery = new Set(); + + domainsToQueryArray.forEach((domain) => { + const parts = domain.split("."); + // A wildcard can only match a domain with at least two parts (e.g., example.com) + if (parts.length > 1) { + parentDomainsToQuery.add(parts.slice(1).join(".")); + } + }); + + const parentDomainsArray = Array.from(parentDomainsToQuery); + + // 4. Build and execute a single, efficient Drizzle query + // This query fetches all potential exact and wildcard matches in one database round-trip. + const potentialCerts = await db + .select() .from(certificates) .where( and( eq(certificates.status, "valid"), isNotNull(certificates.certFile), - isNotNull(certificates.keyFile) + isNotNull(certificates.keyFile), + or( + // Condition for exact matches on the requested domains + inArray(certificates.domain, domainsToQueryArray), + // Condition for wildcard matches on the parent domains + parentDomainsArray.length > 0 + ? and( + inArray(certificates.domain, parentDomainsArray), + eq(certificates.wildcard, true) + ) + : // If there are no possible parent domains, this condition is false + sql`false` + ) ) ); - // Filter certificates for the specified domains and if it is a wildcard then you can match on everything up to the first dot - const validCertsFiltered = validCerts.filter((cert) => { - return ( - domainArray.includes(cert.domain) || - (cert.wildcard && - domainArray.some((domain) => - domain.endsWith(`.${cert.domain}`) - )) - ); - }); + // 5. Process the database results, prioritizing exact matches over wildcards + const exactMatches = new Map(); + const wildcardMatches = new Map(); - const encryptionKeyPath = config.getRawPrivateConfig().server.encryption_key_path; - - if (!fs.existsSync(encryptionKeyPath)) { - throw new Error( - "Encryption key file not found. Please generate one first." - ); + for (const cert of potentialCerts) { + if (cert.wildcard) { + wildcardMatches.set(cert.domain, cert); + } else { + exactMatches.set(cert.domain, cert); + } } - const encryptionKeyHex = fs.readFileSync(encryptionKeyPath, "utf8").trim(); - const encryptionKey = Buffer.from(encryptionKeyHex, "hex"); + for (const domain of domainsToQuery) { + let foundCert: (typeof potentialCerts)[0] | undefined = undefined; - const validCertsDecrypted = validCertsFiltered.map((cert) => { + // Priority 1: Check for an exact match + if (exactMatches.has(domain)) { + foundCert = exactMatches.get(domain); + } + // Priority 2: Check for a wildcard match on the parent domain + else { + const parts = domain.split("."); + if (parts.length > 1) { + const parentDomain = parts.slice(1).join("."); + if (wildcardMatches.has(parentDomain)) { + foundCert = wildcardMatches.get(parentDomain); + } + } + } + + // If a certificate was found, format it, add to results, and cache it + if (foundCert) { + const resultCert: CertificateResult = { + id: foundCert.certId, + domain: foundCert.domain, // The actual domain of the cert record + queriedDomain: domain, // The domain that was originally requested + wildcard: foundCert.wildcard, + certFile: foundCert.certFile, + keyFile: foundCert.keyFile, + expiresAt: foundCert.expiresAt, + updatedAt: foundCert.updatedAt + }; + + finalResults.push(resultCert); + + // Add to cache for future requests, using the *requested domain* as the key + if (useCache) { + certificateCache.set(domain, resultCert); + } + } + } + + return decryptFinalResults(finalResults); +} + +function decryptFinalResults( + finalResults: CertificateResult[] +): CertificateResult[] { + const validCertsDecrypted = finalResults.map((cert) => { // Decrypt and save certificate file const decryptedCert = decryptData( cert.certFile!, // is not null from query @@ -97,4 +188,4 @@ export async function getValidCertificatesForDomains( }); return validCertsDecrypted; -} \ No newline at end of file +} diff --git a/server/private/lib/traefik/getTraefikConfig.ts b/server/private/lib/traefik/getTraefikConfig.ts index e09af0df..bac11f7d 100644 --- a/server/private/lib/traefik/getTraefikConfig.ts +++ b/server/private/lib/traefik/getTraefikConfig.ts @@ -26,6 +26,10 @@ import { orgs, resources, sites, Target, targets } from "@server/db"; import { sanitize, validatePathRewriteConfig } from "@server/lib/traefik/utils"; import privateConfig from "#private/lib/config"; import createPathRewriteMiddleware from "@server/lib/traefik/middleware"; +import { + CertificateResult, + getValidCertificatesForDomains +} from "#private/lib/certificates"; const redirectHttpsMiddlewareName = "redirect-to-https"; const redirectToRootMiddlewareName = "redirect-to-root"; @@ -89,25 +93,11 @@ export async function getTraefikConfig( subnet: sites.subnet, exitNodeId: sites.exitNodeId, // Namespace - domainNamespaceId: domainNamespaces.domainNamespaceId, - // Certificate fields - we'll get all valid certs and filter in application logic - certificateId: certificates.certId, - certificateDomain: certificates.domain, - certificateWildcard: certificates.wildcard, - certificateStatus: certificates.status + domainNamespaceId: domainNamespaces.domainNamespaceId }) .from(sites) .innerJoin(targets, eq(targets.siteId, sites.siteId)) .innerJoin(resources, eq(resources.resourceId, targets.resourceId)) - .leftJoin( - certificates, - and( - eq(certificates.domainId, resources.domainId), - eq(certificates.status, "valid"), - isNotNull(certificates.certFile), - isNotNull(certificates.keyFile) - ) - ) .leftJoin( targetHealthCheck, eq(targetHealthCheck.targetId, targets.targetId) @@ -138,14 +128,6 @@ export async function getTraefikConfig( // Group by resource and include targets with their unique site data const resourcesMap = new Map(); - - // Track certificates per resource to determine the correct certificate status - const resourceCertificates = new Map>(); resourcesWithTargetsAndSites.forEach((row) => { const resourceId = row.resourceId; @@ -170,25 +152,7 @@ export async function getTraefikConfig( .filter(Boolean) .join("-"); const mapKey = [resourceId, pathKey].filter(Boolean).join("-"); - const key = sanitize(mapKey) || ""; - - // Track certificates for this resource - if (row.certificateId && row.certificateDomain && row.certificateStatus) { - if (!resourceCertificates.has(key)) { - resourceCertificates.set(key, []); - } - - const certList = resourceCertificates.get(key)!; - // Only add if not already present (avoid duplicates from multiple targets) - if (!certList.some(cert => cert.id === row.certificateId)) { - certList.push({ - id: row.certificateId, - domain: row.certificateDomain, - wildcard: row.certificateWildcard, - status: row.certificateStatus - }); - } - } + const key = sanitize(mapKey); if (!resourcesMap.has(key)) { const validation = validatePathRewriteConfig( @@ -205,26 +169,6 @@ export async function getTraefikConfig( return; } - // Determine the correct certificate status for this resource - let certificateStatus: string | null = null; - const resourceCerts = resourceCertificates.get(key) || []; - - if (row.fullDomain && resourceCerts.length > 0) { - // Find the best matching certificate - // Priority: exact domain match > wildcard match - const exactMatch = resourceCerts.find(cert => - cert.domain === row.fullDomain - ); - - const wildcardMatch = resourceCerts.find(cert => - cert.wildcard && cert.domain && - row.fullDomain!.endsWith(`.${cert.domain}`) - ); - - const matchingCert = exactMatch || wildcardMatch; - certificateStatus = matchingCert?.status || null; - } - resourcesMap.set(key, { resourceId: row.resourceId, name: resourceName, @@ -240,7 +184,6 @@ export async function getTraefikConfig( tlsServerName: row.tlsServerName, setHostHeader: row.setHostHeader, enableProxy: row.enableProxy, - certificateStatus: certificateStatus, targets: [], headers: row.headers, path: row.path, // the targets will all have the same path @@ -270,6 +213,19 @@ export async function getTraefikConfig( }); }); + let validCerts: CertificateResult[] = []; + if (privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) { + // create a list of all domains to get certs for + const domains = new Set(); + for (const resource of resourcesMap.values()) { + if (resource.enabled && resource.ssl && resource.fullDomain) { + domains.add(resource.fullDomain); + } + } + // get the valid certs for these domains + validCerts = await getValidCertificatesForDomains(domains, true); // we are caching here because this is called often + } + const config_output: any = { http: { middlewares: { @@ -312,14 +268,6 @@ export async function getTraefikConfig( continue; } - // TODO: for now dont filter it out because if you have multiple domain ids and one is failed it causes all of them to fail - if (resource.certificateStatus !== "valid" && privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) { - logger.debug( - `Resource ${resource.resourceId} has certificate status ${resource.certificateStatus}` - ); - continue; - } - // add routers and services empty objects if they don't exist if (!config_output.http.routers) { config_output.http.routers = {}; @@ -329,22 +277,22 @@ export async function getTraefikConfig( config_output.http.services = {}; } - const domainParts = fullDomain.split("."); - let wildCard; - if (domainParts.length <= 2) { - wildCard = `*.${domainParts.join(".")}`; - } else { - wildCard = `*.${domainParts.slice(1).join(".")}`; - } - - if (!resource.subdomain) { - wildCard = resource.fullDomain; - } - - const configDomain = config.getDomain(resource.domainId); - let tls = {}; if (!privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) { + const domainParts = fullDomain.split("."); + let wildCard; + if (domainParts.length <= 2) { + wildCard = `*.${domainParts.join(".")}`; + } else { + wildCard = `*.${domainParts.slice(1).join(".")}`; + } + + if (!resource.subdomain) { + wildCard = resource.fullDomain; + } + + const configDomain = config.getDomain(resource.domainId); + let certResolver: string, preferWildcardCert: boolean; if (!configDomain) { certResolver = config.getRawConfig().traefik.cert_resolver; @@ -367,6 +315,17 @@ export async function getTraefikConfig( } : {}) }; + } else { + // find a cert that matches the full domain, if not continue + const matchingCert = validCerts.find( + (cert) => cert.queriedDomain === resource.fullDomain + ); + if (!matchingCert) { + logger.warn( + `No matching certificate found for domain: ${resource.fullDomain}` + ); + continue; + } } const additionalMiddlewares = @@ -733,20 +692,31 @@ export async function getTraefikConfig( loginPageId: loginPage.loginPageId, fullDomain: loginPage.fullDomain, exitNodeId: exitNodes.exitNodeId, - domainId: loginPage.domainId, - certificateStatus: certificates.status + domainId: loginPage.domainId }) .from(loginPage) .innerJoin( exitNodes, eq(exitNodes.exitNodeId, loginPage.exitNodeId) ) - .leftJoin( - certificates, - eq(certificates.domainId, loginPage.domainId) - ) .where(eq(exitNodes.exitNodeId, exitNodeId)); + let validCertsLoginPages: CertificateResult[] = []; + if (privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) { + // create a list of all domains to get certs for + const domains = new Set(); + for (const lp of exitNodeLoginPages) { + if (lp.fullDomain) { + domains.add(lp.fullDomain); + } + } + // get the valid certs for these domains + validCertsLoginPages = await getValidCertificatesForDomains( + domains, + true + ); // we are caching here because this is called often + } + if (exitNodeLoginPages.length > 0) { if (!config_output.http.services) { config_output.http.services = {}; @@ -776,8 +746,22 @@ export async function getTraefikConfig( continue; } - if (lp.certificateStatus !== "valid") { - continue; + let tls = {}; + if ( + !privateConfig.getRawPrivateConfig().flags.use_pangolin_dns + ) { + // TODO: we need to add the wildcard logic here too + } else { + // find a cert that matches the full domain, if not continue + const matchingCert = validCertsLoginPages.find( + (cert) => cert.queriedDomain === lp.fullDomain + ); + if (!matchingCert) { + logger.warn( + `No matching certificate found for login page domain: ${lp.fullDomain}` + ); + continue; + } } // auth-allowed: @@ -800,7 +784,7 @@ export async function getTraefikConfig( service: "landing-service", rule: `Host(\`${fullDomain}\`) && (PathRegexp(\`^/auth/resource/[^/]+$\`) || PathRegexp(\`^/auth/idp/[0-9]+/oidc/callback\`) || PathPrefix(\`/_next\`) || Path(\`/auth/org\`) || PathRegexp(\`^/__nextjs*\`))`, priority: 203, - tls: {} + tls: tls }; // auth-catchall: @@ -819,7 +803,7 @@ export async function getTraefikConfig( service: "landing-service", rule: `Host(\`${fullDomain}\`)`, priority: 202, - tls: {} + tls: tls }; // we need to add a redirect from http to https too