Merge branch 'alerting-rules' into dev

This commit is contained in:
Owen
2026-04-21 15:05:12 -07:00
230 changed files with 16351 additions and 3320 deletions

View File

@@ -1,8 +1,8 @@
import { logsDb, primaryLogsDb, requestAuditLog, resources, db, primaryDb } from "@server/db";
import { logsDb, primaryLogsDb, requestAuditLog, resources, siteResources, db, primaryDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
import { eq, gt, lt, and, count, desc, inArray, isNull, or } from "drizzle-orm";
import { OpenAPITags } from "@server/openApi";
import { z } from "zod";
import createHttpError from "http-errors";
@@ -92,7 +92,10 @@ function getWhere(data: Q) {
lt(requestAuditLog.timestamp, data.timeEnd),
eq(requestAuditLog.orgId, data.orgId),
data.resourceId
? eq(requestAuditLog.resourceId, data.resourceId)
? or(
eq(requestAuditLog.resourceId, data.resourceId),
eq(requestAuditLog.siteResourceId, data.resourceId)
)
: undefined,
data.actor ? eq(requestAuditLog.actor, data.actor) : undefined,
data.method ? eq(requestAuditLog.method, data.method) : undefined,
@@ -110,15 +113,16 @@ export function queryRequest(data: Q) {
return primaryLogsDb
.select({
id: requestAuditLog.id,
timestamp: requestAuditLog.timestamp,
orgId: requestAuditLog.orgId,
action: requestAuditLog.action,
reason: requestAuditLog.reason,
actorType: requestAuditLog.actorType,
actor: requestAuditLog.actor,
actorId: requestAuditLog.actorId,
resourceId: requestAuditLog.resourceId,
ip: requestAuditLog.ip,
timestamp: requestAuditLog.timestamp,
orgId: requestAuditLog.orgId,
action: requestAuditLog.action,
reason: requestAuditLog.reason,
actorType: requestAuditLog.actorType,
actor: requestAuditLog.actor,
actorId: requestAuditLog.actorId,
resourceId: requestAuditLog.resourceId,
siteResourceId: requestAuditLog.siteResourceId,
ip: requestAuditLog.ip,
location: requestAuditLog.location,
userAgent: requestAuditLog.userAgent,
metadata: requestAuditLog.metadata,
@@ -137,37 +141,73 @@ export function queryRequest(data: Q) {
}
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryRequest>>) {
// If logs database is the same as main database, we can do a join
// Otherwise, we need to fetch resource details separately
const resourceIds = logs
.map(log => log.resourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0) {
const siteResourceIds = logs
.filter(log => log.resourceId == null && log.siteResourceId != null)
.map(log => log.siteResourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0 && siteResourceIds.length === 0) {
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
}
// Fetch resource details from main database
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
const resourceMap = new Map<number, { name: string | null; niceId: string | null }>();
// Create a map for quick lookup
const resourceMap = new Map(
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
);
if (resourceIds.length > 0) {
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
for (const r of resourceDetails) {
resourceMap.set(r.resourceId, { name: r.name, niceId: r.niceId });
}
}
const siteResourceMap = new Map<number, { name: string | null; niceId: string | null }>();
if (siteResourceIds.length > 0) {
const siteResourceDetails = await primaryDb
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name,
niceId: siteResources.niceId
})
.from(siteResources)
.where(inArray(siteResources.siteResourceId, siteResourceIds));
for (const r of siteResourceDetails) {
siteResourceMap.set(r.siteResourceId, { name: r.name, niceId: r.niceId });
}
}
// Enrich logs with resource details
return logs.map(log => ({
...log,
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
}));
return logs.map(log => {
if (log.resourceId != null) {
const details = resourceMap.get(log.resourceId);
return {
...log,
resourceName: details?.name ?? null,
resourceNiceId: details?.niceId ?? null
};
} else if (log.siteResourceId != null) {
const details = siteResourceMap.get(log.siteResourceId);
return {
...log,
resourceId: log.siteResourceId,
resourceName: details?.name ?? null,
resourceNiceId: details?.niceId ?? null
};
}
return { ...log, resourceName: null, resourceNiceId: null };
});
}
export function countRequestQuery(data: Q) {
@@ -211,7 +251,8 @@ async function queryUniqueFilterAttributes(
uniqueLocations,
uniqueHosts,
uniquePaths,
uniqueResources
uniqueResources,
uniqueSiteResources
] = await Promise.all([
primaryLogsDb
.selectDistinct({ actor: requestAuditLog.actor })
@@ -239,6 +280,13 @@ async function queryUniqueFilterAttributes(
})
.from(requestAuditLog)
.where(baseConditions)
.limit(DISTINCT_LIMIT + 1),
primaryLogsDb
.selectDistinct({
id: requestAuditLog.siteResourceId
})
.from(requestAuditLog)
.where(and(baseConditions, isNull(requestAuditLog.resourceId)))
.limit(DISTINCT_LIMIT + 1)
]);
@@ -259,6 +307,10 @@ async function queryUniqueFilterAttributes(
.map(row => row.id)
.filter((id): id is number => id !== null);
const siteResourceIds = uniqueSiteResources
.map(row => row.id)
.filter((id): id is number => id !== null);
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
if (resourceIds.length > 0) {
@@ -270,10 +322,31 @@ async function queryUniqueFilterAttributes(
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
resourcesWithNames = resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}));
resourcesWithNames = [
...resourcesWithNames,
...resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}))
];
}
if (siteResourceIds.length > 0) {
const siteResourceDetails = await primaryDb
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name
})
.from(siteResources)
.where(inArray(siteResources.siteResourceId, siteResourceIds));
resourcesWithNames = [
...resourcesWithNames,
...siteResourceDetails.map(r => ({
id: r.siteResourceId,
name: r.name
}))
];
}
return {

View File

@@ -28,6 +28,7 @@ export type QueryRequestAuditLogResponse = {
actor: string | null;
actorId: string | null;
resourceId: number | null;
siteResourceId: number | null;
resourceNiceId: string | null;
resourceName: string | null;
ip: string | null;

View File

@@ -18,6 +18,7 @@ Reasons:
105 - Valid Password
106 - Valid email
107 - Valid SSO
108 - Connected Client
201 - Resource Not Found
202 - Resource Blocked
@@ -38,6 +39,7 @@ const auditLogBuffer: Array<{
metadata: any;
action: boolean;
resourceId?: number;
siteResourceId?: number;
reason: number;
location?: string;
originalRequestURL: string;
@@ -186,6 +188,7 @@ export async function logRequestAudit(
action: boolean;
reason: number;
resourceId?: number;
siteResourceId?: number;
orgId?: string;
location?: string;
user?: { username: string; userId: string };
@@ -262,6 +265,7 @@ export async function logRequestAudit(
metadata: sanitizeString(metadata),
action: data.action,
resourceId: data.resourceId,
siteResourceId: data.siteResourceId,
reason: data.reason,
location: sanitizeString(data.location),
originalRequestURL: sanitizeString(body.originalRequestURL) ?? "",

View File

@@ -285,6 +285,13 @@ authenticated.get(
site.listContainers
);
authenticated.get(
"/site/:siteId/status-history",
verifySiteAccess,
verifyUserHasAction(ActionsEnum.getSite),
site.getSiteStatusHistory
);
// Site Resource endpoints
authenticated.put(
"/org/:orgId/site-resource",
@@ -420,6 +427,13 @@ authenticated.get(
resource.listResources
);
authenticated.get(
"/resource/:resourceId/status-history",
verifyResourceAccess,
verifyUserHasAction(ActionsEnum.getResource),
resource.getResourceStatusHistory
);
authenticated.get(
"/org/:orgId/resources",
verifyOrgAccess,

View File

@@ -88,11 +88,11 @@ async function dbQueryRows<T extends Record<string, unknown>>(
): Promise<T[]> {
const anyDb = db as any;
if (typeof anyDb.execute === "function") {
// PostgreSQL (node-postgres via Drizzle) returns { rows: [...] } or an array
// PostgreSQL (node-postgres via Drizzle) - returns { rows: [...] } or an array
const result = await anyDb.execute(query);
return (Array.isArray(result) ? result : (result.rows ?? [])) as T[];
}
// SQLite (better-sqlite3 via Drizzle) returns an array directly
// SQLite (better-sqlite3 via Drizzle) - returns an array directly
return (await anyDb.all(query)) as T[];
}
@@ -106,7 +106,7 @@ function isSQLite(): boolean {
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Sites are updated in chunks via a single
* batch UPDATE per chunk. Failed chunks are discarded exact per-flush
* batch UPDATE per chunk. Failed chunks are discarded - exact per-flush
* accuracy is not critical and re-queuing is not worth the added complexity.
*
* This function is exported so that the application's graceful-shutdown
@@ -125,7 +125,7 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers deadlock-prevention strategy.
// writers - deadlock-prevention strategy.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
);
@@ -150,7 +150,7 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
try {
rows = await withDeadlockRetry(async () => {
if (isSQLite()) {
// SQLite: one UPDATE per row no need for batch efficiency here.
// SQLite: one UPDATE per row - no need for batch efficiency here.
const results: { orgId: string; pubKey: string }[] = [];
for (const [publicKey, { bytesIn, bytesOut }] of chunk) {
const result = await dbQueryRows<{
@@ -170,7 +170,7 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
return results;
}
// PostgreSQL: batch UPDATE … FROM (VALUES …) single round-trip per chunk.
// PostgreSQL: batch UPDATE … FROM (VALUES …) - single round-trip per chunk.
const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) =>
sql`(${publicKey}::text, ${bytesIn}::real, ${bytesOut}::real)`
);
@@ -191,7 +191,7 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
`Failed to flush bandwidth chunk [${i}${chunkEnd}], discarding ${chunk.length} site(s):`,
error
);
// Discard the chunk exact per-flush accuracy is not critical.
// Discard the chunk - exact per-flush accuracy is not critical.
continue;
}
@@ -232,7 +232,7 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
totalBandwidth
);
if (bandwidthUsage) {
// Fire-and-forget don't block the flush on limit checking.
// Fire-and-forget - don't block the flush on limit checking.
usageService
.checkLimitSet(
orgId,
@@ -298,7 +298,7 @@ export async function updateSiteBandwidth(
exitNodeId?: number
): Promise<void> {
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
// Skip peers that haven't transferred any data writing zeros to the
// Skip peers that haven't transferred any data - writing zeros to the
// database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;

View File

@@ -0,0 +1,34 @@
export type ListHealthChecksResponse = {
healthChecks: {
targetHealthCheckId: number;
name: string;
siteId: number | null;
siteName: string | null;
siteNiceId: string | null;
hcEnabled: boolean;
hcHealth: "unknown" | "healthy" | "unhealthy";
hcMode: string | null;
hcHostname: string | null;
hcPort: number | null;
hcPath: string | null;
hcScheme: string | null;
hcMethod: string | null;
hcInterval: number | null;
hcUnhealthyInterval: number | null;
hcTimeout: number | null;
hcHeaders: string | null;
hcFollowRedirects: boolean | null;
hcStatus: number | null;
hcTlsServerName: string | null;
hcHealthyThreshold: number | null;
hcUnhealthyThreshold: number | null;
resourceId: number | null;
resourceName: string | null;
resourceNiceId: string | null;
}[];
pagination: {
total: number;
limit: number;
offset: number;
};
};

View File

@@ -4,8 +4,10 @@ import {
clientSitesAssociationsCache,
db,
ExitNode,
networks,
resources,
Site,
siteNetworks,
siteResources,
targetHealthCheck,
targets
@@ -84,7 +86,8 @@ export async function buildClientConfigurationForNewtClient(
// )
// );
if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
if (!client.clientSitesAssociationsCache.isJitMode) {
// if we are adding sites through jit then dont add the site to the olm
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
@@ -137,11 +140,14 @@ export async function buildClientConfigurationForNewtClient(
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled site resources for this site
// Get all enabled site resources for this site by joining through siteNetworks and networks
const allSiteResources = await db
.select()
.from(siteResources)
.where(eq(siteResources.siteId, siteId));
.innerJoin(networks, eq(siteResources.networkId, networks.networkId))
.innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
.where(eq(siteNetworks.siteId, siteId))
.then((rows) => rows.map((r) => r.siteResources));
const targetsToSend: SubnetProxyTargetV2[] = [];
@@ -168,7 +174,7 @@ export async function buildClientConfigurationForNewtClient(
)
);
const resourceTargets = generateSubnetProxyTargetV2(
const resourceTargets = await generateSubnetProxyTargetV2(
resource,
resourceClients
);
@@ -184,7 +190,10 @@ export async function buildClientConfigurationForNewtClient(
};
}
export async function buildTargetConfigurationForNewtClient(siteId: number) {
export async function buildTargetConfigurationForNewtClient(
siteId: number,
version?: string | null
) {
// Get all enabled targets with their resource protocol information
const allTargets = await db
.select({
@@ -195,7 +204,15 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled,
protocol: resources.protocol,
protocol: resources.protocol
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
const allHealthChecks = await db
.select({
targetHealthCheckId: targetHealthCheck.targetHealthCheckId,
hcEnabled: targetHealthCheck.hcEnabled,
hcPath: targetHealthCheck.hcPath,
hcScheme: targetHealthCheck.hcScheme,
@@ -206,17 +223,15 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
hcTimeout: targetHealthCheck.hcTimeout,
hcHeaders: targetHealthCheck.hcHeaders,
hcFollowRedirects: targetHealthCheck.hcFollowRedirects,
hcMethod: targetHealthCheck.hcMethod,
hcTlsServerName: targetHealthCheck.hcTlsServerName,
hcStatus: targetHealthCheck.hcStatus
hcStatus: targetHealthCheck.hcStatus,
hcHealthyThreshold: targetHealthCheck.hcHealthyThreshold,
hcUnhealthyThreshold: targetHealthCheck.hcUnhealthyThreshold
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
.leftJoin(
targetHealthCheck,
eq(targets.targetId, targetHealthCheck.targetId)
)
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
.from(targetHealthCheck)
.where(eq(targetHealthCheck.siteId, siteId));
const { tcpTargets, udpTargets } = allTargets.reduce(
(acc, target) => {
@@ -240,19 +255,14 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
const healthCheckTargets = allTargets.map((target) => {
const healthCheckTargets = allHealthChecks.map((target) => {
// make sure the stuff is defined
if (
!target.hcPath ||
!target.hcHostname ||
!target.hcPort ||
!target.hcInterval ||
!target.hcMethod
) {
// logger.debug(
// `Skipping adding target health check ${target.targetId} due to missing health check fields`
// );
return null; // Skip targets with missing health check fields
const isTCP = target.hcMode?.toLowerCase() === "tcp";
if (!target.hcHostname || !target.hcPort || !target.hcInterval) {
return null;
}
if (!isTCP && (!target.hcPath || !target.hcMethod)) {
return null;
}
// parse headers
@@ -269,7 +279,7 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
}
return {
id: target.targetId,
id: target.targetHealthCheckId,
hcEnabled: target.hcEnabled,
hcPath: target.hcPath,
hcScheme: target.hcScheme,
@@ -280,9 +290,12 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
hcTimeout: target.hcTimeout, // in seconds
hcHeaders: hcHeadersSend,
hcFollowRedirects: target.hcFollowRedirects,
hcMethod: target.hcMethod,
hcTlsServerName: target.hcTlsServerName,
hcStatus: target.hcStatus
hcStatus: target.hcStatus,
hcHealthyThreshold: target.hcHealthyThreshold,
hcUnhealthyThreshold: target.hcUnhealthyThreshold
};
});

View File

@@ -10,7 +10,7 @@ import { convertTargetsIfNessicary } from "../client/targets";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";
export const handleGetConfigMessage: MessageHandler = async (context) => {
export const handleNewtGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
@@ -56,7 +56,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
logger.warn(
`Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
`Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the site reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
);
return;
}
@@ -113,7 +113,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
exitNode
);
const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets);
const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets); // for backward compatibility with old newt versions that don't support the new target format
return {
message: {

View File

@@ -1,180 +1,12 @@
import { db, newts, sites, targetHealthCheck, targets } from "@server/db";
import {
hasActiveConnections,
getClientConfigVersion
} from "#dynamic/routers/ws";
import { db, sites } from "@server/db";
import { getClientConfigVersion } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { Newt } from "@server/db";
import { eq, lt, isNull, and, or, ne, not } from "drizzle-orm";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
import { sendNewtSyncMessage } from "./sync";
import { recordPing } from "./pingAccumulator";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
const OFFLINE_THRESHOLD_BANDWIDTH_MS = 8 * 60 * 1000; // 8 minutes
/**
* Starts the background interval that checks for newt sites that haven't
* pinged recently and marks them as offline. For backward compatibility,
* a site is only marked offline when there is no active WebSocket connection
* either — so older newt versions that don't send pings but remain connected
* continue to be treated as online.
*/
export const startNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// Find all online newt-type sites that haven't pinged recently
// (or have never pinged at all). Join newts to obtain the newtId
// needed for the WebSocket connection check.
const staleSites = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
lastPing: sites.lastPing
})
.from(sites)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
eq(sites.online, true),
eq(sites.type, "newt"),
or(
lt(sites.lastPing, twoMinutesAgo),
isNull(sites.lastPing)
)
)
);
for (const staleSite of staleSites) {
// Backward-compatibility check: if the newt still has an
// active WebSocket connection (older clients that don't send
// pings), keep the site online.
const isConnected = await hasActiveConnections(
staleSite.newtId
);
if (isConnected) {
logger.debug(
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
);
continue;
}
logger.info(
`Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
);
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, staleSite.siteId));
const healthChecksOnSite = await db
.select()
.from(targetHealthCheck)
.innerJoin(
targets,
eq(targets.targetId, targetHealthCheck.targetId)
)
.innerJoin(sites, eq(sites.siteId, targets.siteId))
.where(eq(sites.siteId, staleSite.siteId));
for (const healthCheck of healthChecksOnSite) {
logger.info(
`Marking health check ${healthCheck.targetHealthCheck.targetHealthCheckId} offline due to site ${staleSite.siteId} being marked offline`
);
await db
.update(targetHealthCheck)
.set({ hcHealth: "unknown" })
.where(
eq(
targetHealthCheck.targetHealthCheckId,
healthCheck.targetHealthCheck
.targetHealthCheckId
)
);
}
}
// this part only effects self hosted. Its not efficient but we dont expect people to have very many wireguard sites
// select all of the wireguard sites to evaluate if they need to be offline due to the last bandwidth update
const allWireguardSites = await db
.select({
siteId: sites.siteId,
online: sites.online,
lastBandwidthUpdate: sites.lastBandwidthUpdate
})
.from(sites)
.where(
and(
eq(sites.type, "wireguard"),
not(isNull(sites.lastBandwidthUpdate))
)
);
const wireguardOfflineThreshold = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_BANDWIDTH_MS) / 1000
);
// loop over each one. If its offline and there is a new update then mark it online. If its online and there is no update then mark it offline
for (const site of allWireguardSites) {
const lastBandwidthUpdate =
new Date(site.lastBandwidthUpdate!).getTime() / 1000;
if (
lastBandwidthUpdate < wireguardOfflineThreshold &&
site.online
) {
logger.info(
`Marking wireguard site ${site.siteId} offline: no bandwidth update in over ${OFFLINE_THRESHOLD_BANDWIDTH_MS / 60000} minutes`
);
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, site.siteId));
} else if (
lastBandwidthUpdate >= wireguardOfflineThreshold &&
!site.online
) {
logger.info(
`Marking wireguard site ${site.siteId} online: recent bandwidth update`
);
await db
.update(sites)
.set({ online: true })
.where(eq(sites.siteId, site.siteId));
}
}
} catch (error) {
logger.error("Error in newt offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.debug("Started newt offline checker interval");
};
/**
* Stops the background interval that checks for offline newt sites.
*/
export const stopNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped newt offline checker interval");
}
};
/**
* Handles ping messages from newt clients.
*

View File

@@ -192,7 +192,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
}
const { tcpTargets, udpTargets, validHealthCheckTargets } =
await buildTargetConfigurationForNewtClient(siteId);
await buildTargetConfigurationForNewtClient(siteId, newtVersion);
logger.debug(
`Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}`

View File

@@ -88,7 +88,7 @@ export async function flushBandwidthToDb(): Promise<void> {
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers this is the same deadlock-prevention strategy used in the
// writers - this is the same deadlock-prevention strategy used in the
// original per-message implementation.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
@@ -143,7 +143,7 @@ const flushTimer = setInterval(async () => {
}, FLUSH_INTERVAL_MS);
// Calling unref() means this timer will not keep the Node.js event loop alive
// on its own the process can still exit normally when there is no other work
// on its own - the process can still exit normally when there is no other work
// left. The graceful-shutdown path (see server/cleanup.ts) will call
// flushBandwidthToDb() explicitly before process.exit(), so no data is lost.
flushTimer.unref();
@@ -167,7 +167,7 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
// Accumulate the incoming data in memory; the periodic timer (and the
// shutdown hook) will take care of writing it to the database.
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
// Skip peers that haven't transferred any data writing zeros to the
// Skip peers that haven't transferred any data - writing zeros to the
// database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;

View File

@@ -0,0 +1,9 @@
import { MessageHandler } from "@server/routers/ws";
export async function flushRequestLogToDb(): Promise<void> {
return;
}
export const handleRequestLogMessage: MessageHandler = async (context) => {
return;
};

View File

@@ -2,11 +2,13 @@ export * from "./createNewt";
export * from "./getNewtToken";
export * from "./handleNewtRegisterMessage";
export * from "./handleReceiveBandwidthMessage";
export * from "./handleGetConfigMessage";
export * from "./handleNewtGetConfigMessage";
export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage";
export * from "./handleApplyBlueprintMessage";
export * from "./handleNewtPingMessage";
export * from "./handleNewtDisconnectingMessage";
export * from "./handleConnectionLogMessage";
export * from "./handleRequestLogMessage";
export * from "./registerNewt";
export * from "./offlineChecker";

View File

@@ -0,0 +1,208 @@
import { db, newts, sites, targetHealthCheck, targets, statusHistory } from "@server/db";
import {
hasActiveConnections,
} from "#dynamic/routers/ws";
import { eq, lt, isNull, and, or, ne, not } from "drizzle-orm";
import logger from "@server/logger";
import { fireSiteOfflineAlert, fireSiteOnlineAlert } from "#dynamic/lib/alerts";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
const OFFLINE_THRESHOLD_BANDWIDTH_MS = 8 * 60 * 1000; // 8 minutes
/**
* Starts the background interval that checks for newt sites that haven't
* pinged recently and marks them as offline. For backward compatibility,
* a site is only marked offline when there is no active WebSocket connection
* either - so older newt versions that don't send pings but remain connected
* continue to be treated as online.
*/
export const startNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// Find all online newt-type sites that haven't pinged recently
// (or have never pinged at all). Join newts to obtain the newtId
// needed for the WebSocket connection check.
const staleSites = await db
.select({
siteId: sites.siteId,
orgId: sites.orgId,
name: sites.name,
newtId: newts.newtId,
lastPing: sites.lastPing
})
.from(sites)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
eq(sites.online, true),
eq(sites.type, "newt"),
or(
lt(sites.lastPing, twoMinutesAgo),
isNull(sites.lastPing)
)
)
);
for (const staleSite of staleSites) {
// Backward-compatibility check: if the newt still has an
// active WebSocket connection (older clients that don't send
// pings), keep the site online.
const isConnected = await hasActiveConnections(
staleSite.newtId
);
if (isConnected) {
logger.debug(
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket - keeping site ${staleSite.siteId} online`
);
continue;
}
logger.info(
`Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
);
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, staleSite.siteId));
await db.insert(statusHistory).values({
entityType: "site",
entityId: staleSite.siteId,
orgId: staleSite.orgId,
status: "offline",
timestamp: Math.floor(Date.now() / 1000),
}).execute();
const healthChecksOnSite = await db
.select()
.from(targetHealthCheck)
.innerJoin(
targets,
eq(targets.targetId, targetHealthCheck.targetId)
)
.innerJoin(sites, eq(sites.siteId, targets.siteId))
.where(eq(sites.siteId, staleSite.siteId));
for (const healthCheck of healthChecksOnSite) {
logger.info(
`Marking health check ${healthCheck.targetHealthCheck.targetHealthCheckId} offline due to site ${staleSite.siteId} being marked offline`
);
await db
.update(targetHealthCheck)
.set({ hcHealth: "unknown" })
.where(
eq(
targetHealthCheck.targetHealthCheckId,
healthCheck.targetHealthCheck
.targetHealthCheckId
)
);
// TODO: should we be firing an alert here when the health check goes to unknown?
}
await fireSiteOfflineAlert(staleSite.orgId, staleSite.siteId, staleSite.name);
}
// this part only effects self hosted. Its not efficient but we dont expect people to have very many wireguard sites
// select all of the wireguard sites to evaluate if they need to be offline due to the last bandwidth update
const allWireguardSites = await db
.select({
siteId: sites.siteId,
orgId: sites.orgId,
name: sites.name,
online: sites.online,
lastBandwidthUpdate: sites.lastBandwidthUpdate
})
.from(sites)
.where(
and(
eq(sites.type, "wireguard"),
not(isNull(sites.lastBandwidthUpdate))
)
);
const wireguardOfflineThreshold = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_BANDWIDTH_MS) / 1000
);
// loop over each one. If its offline and there is a new update then mark it online. If its online and there is no update then mark it offline
for (const site of allWireguardSites) {
const lastBandwidthUpdate =
new Date(site.lastBandwidthUpdate!).getTime() / 1000;
if (
lastBandwidthUpdate < wireguardOfflineThreshold &&
site.online
) {
logger.info(
`Marking wireguard site ${site.siteId} offline: no bandwidth update in over ${OFFLINE_THRESHOLD_BANDWIDTH_MS / 60000} minutes`
);
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, site.siteId));
await db.insert(statusHistory).values({
entityType: "site",
entityId: site.siteId,
orgId: site.orgId,
status: "offline",
timestamp: Math.floor(Date.now() / 1000),
}).execute();
await fireSiteOfflineAlert(site.orgId, site.siteId, site.name);
} else if (
lastBandwidthUpdate >= wireguardOfflineThreshold &&
!site.online
) {
logger.info(
`Marking wireguard site ${site.siteId} online: recent bandwidth update`
);
await db
.update(sites)
.set({ online: true })
.where(eq(sites.siteId, site.siteId));
await db.insert(statusHistory).values({
entityType: "site",
entityId: site.siteId,
orgId: site.orgId,
status: "online",
timestamp: Math.floor(Date.now() / 1000),
}).execute();
await fireSiteOnlineAlert(site.orgId, site.siteId, site.name);
}
}
} catch (error) {
logger.error("Error in newt offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.debug("Started newt offline checker interval");
};
/**
* Stops the background interval that checks for offline newt sites.
*/
export const stopNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped newt offline checker interval");
}
};

View File

@@ -1,7 +1,8 @@
import { db } from "@server/db";
import { sites, clients, olms } from "@server/db";
import { inArray } from "drizzle-orm";
import { sites, clients, olms, statusHistory } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import logger from "@server/logger";
import { fireSiteOnlineAlert } from "#dynamic/lib/alerts";
/**
* Ping Accumulator
@@ -110,15 +111,51 @@ async function flushSitePingsToDb(): Promise<void> {
const siteIds = batch.map(([id]) => id);
try {
await withRetry(async () => {
await db
const newlyOnlineSites = await withRetry(async () => {
// Only update sites that were offline - these are the
// offline→online transitions. .returning() gives us exactly
// the site IDs that changed state.
const transitioned = await db
.update(sites)
.set({
online: true,
lastPing: maxTimestamp
})
.where(inArray(sites.siteId, siteIds));
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.online, false)
)
)
.returning({ siteId: sites.siteId, orgId: sites.orgId, name: sites.name });
// Update lastPing for sites that were already online.
// After the update above, the newly-online sites now have
// online = true, so this catches all remaining sites in the
// batch and keeps lastPing current for them too.
await db
.update(sites)
.set({ lastPing: maxTimestamp })
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.online, true)
)
);
return transitioned;
}, "flushSitePingsToDb");
for (const site of newlyOnlineSites) {
await db.insert(statusHistory).values({
entityType: "site",
entityId: site.siteId,
orgId: site.orgId,
status: "online",
timestamp: Math.floor(Date.now() / 1000),
}).execute();
await fireSiteOnlineAlert(site.orgId, site.siteId, site.name);
}
} catch (error) {
logger.error(
`Failed to flush site ping batch (${batch.length} sites), re-queuing for next cycle`,
@@ -219,7 +256,7 @@ async function flushClientPingsToDb(): Promise<void> {
}
/**
* Flush everything called by the interval timer and during shutdown.
* Flush everything - called by the interval timer and during shutdown.
*/
export async function flushPingsToDb(): Promise<void> {
await flushSitePingsToDb();
@@ -284,7 +321,7 @@ function isTransientError(error: any): boolean {
return true;
}
// PostgreSQL deadlock detected always safe to retry (one winner guaranteed)
// PostgreSQL deadlock detected - always safe to retry (one winner guaranteed)
if (code === "40P01" || message.includes("deadlock")) {
return true;
}

View File

@@ -249,7 +249,7 @@ export async function registerNewt(
dateCreated: moment().toISOString()
});
// Consume the provisioning key cascade removes siteProvisioningKeyOrg
// Consume the provisioning key - cascade removes siteProvisioningKeyOrg
await trx
.update(siteProvisioningKeys)
.set({

View File

@@ -1,7 +1,6 @@
import { Target, TargetHealthCheck, db, targetHealthCheck } from "@server/db";
import { Target, TargetHealthCheck } from "@server/db";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function addTargets(
@@ -18,17 +17,23 @@ export async function addTargets(
}:${target.port}`;
});
await sendToClient(newtId, {
type: `newt/${protocol}/add`,
data: {
targets: payloadTargets
}
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
await sendToClient(
newtId,
{
type: `newt/${protocol}/add`,
data: {
targets: payloadTargets
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
// Create a map for quick lookup
const healthCheckMap = new Map<number, TargetHealthCheck>();
healthCheckData.forEach((hc) => {
healthCheckMap.set(hc.targetId, hc);
if (hc.targetId !== null) {
healthCheckMap.set(hc.targetId, hc);
}
});
const healthCheckTargets = targets.map((target) => {
@@ -43,17 +48,18 @@ export async function addTargets(
}
// Ensure all necessary fields are present
if (
!hc.hcPath ||
!hc.hcHostname ||
!hc.hcPort ||
!hc.hcInterval ||
!hc.hcMethod
) {
const isTCP = hc.hcMode?.toLowerCase() === "tcp";
if (!hc.hcHostname || !hc.hcPort || !hc.hcInterval) {
logger.debug(
`Skipping target ${target.targetId} due to missing health check fields`
);
return null; // Skip targets with missing health check fields
return null;
}
if (!isTCP && (!hc.hcPath || !hc.hcMethod)) {
logger.debug(
`Skipping target ${target.targetId} due to missing HTTP health check fields`
);
return null;
}
const hcHeadersParse = hc.hcHeaders ? JSON.parse(hc.hcHeaders) : null;
@@ -77,7 +83,7 @@ export async function addTargets(
}
return {
id: target.targetId,
id: hc.targetHealthCheckId,
hcEnabled: hc.hcEnabled,
hcPath: hc.hcPath,
hcScheme: hc.hcScheme,
@@ -88,9 +94,12 @@ export async function addTargets(
hcUnhealthyInterval: hc.hcUnhealthyInterval, // in seconds
hcTimeout: hc.hcTimeout, // in seconds
hcHeaders: hcHeadersSend,
hcFollowRedirects: hc.hcFollowRedirects,
hcMethod: hc.hcMethod,
hcStatus: hcStatus,
hcTlsServerName: hc.hcTlsServerName
hcTlsServerName: hc.hcTlsServerName,
hcHealthyThreshold: hc.hcHealthyThreshold,
hcUnhealthyThreshold: hc.hcUnhealthyThreshold
};
});
@@ -99,12 +108,106 @@ export async function addTargets(
(target) => target !== null
);
await sendToClient(newtId, {
type: `newt/healthcheck/add`,
data: {
targets: validHealthCheckTargets
await sendToClient(
newtId,
{
type: `newt/healthcheck/add`,
data: {
targets: validHealthCheckTargets
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
}
export async function addStandaloneHealthCheck(
newtId: string,
healthCheck: TargetHealthCheck,
version?: string | null
) {
const isTCP = healthCheck.hcMode?.toLowerCase() === "tcp";
if (
!healthCheck.hcHostname ||
!healthCheck.hcPort ||
!healthCheck.hcInterval
) {
logger.debug(
`Skipping standalone health check ${healthCheck.targetHealthCheckId} due to missing fields`
);
return;
}
if (!isTCP && (!healthCheck.hcPath || !healthCheck.hcMethod)) {
logger.debug(
`Skipping standalone health check ${healthCheck.targetHealthCheckId} due to missing HTTP health check fields`
);
return;
}
const hcHeadersParse = healthCheck.hcHeaders
? JSON.parse(healthCheck.hcHeaders)
: null;
const hcHeadersSend: { [key: string]: string } = {};
if (hcHeadersParse) {
hcHeadersParse.forEach((header: { name: string; value: string }) => {
hcHeadersSend[header.name] = header.value;
});
}
let hcStatus: number | undefined = undefined;
if (healthCheck.hcStatus) {
const parsedStatus = parseInt(healthCheck.hcStatus.toString());
if (!isNaN(parsedStatus)) {
hcStatus = parsedStatus;
}
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
}
await sendToClient(
newtId,
{
type: `newt/healthcheck/add`,
data: {
targets: [
{
id: healthCheck.targetHealthCheckId,
hcEnabled: healthCheck.hcEnabled,
hcPath: healthCheck.hcPath,
hcScheme: healthCheck.hcScheme,
hcMode: healthCheck.hcMode,
hcHostname: healthCheck.hcHostname,
hcPort: healthCheck.hcPort,
hcInterval: healthCheck.hcInterval,
hcUnhealthyInterval: healthCheck.hcUnhealthyInterval,
hcTimeout: healthCheck.hcTimeout,
hcHeaders: hcHeadersSend,
hcFollowRedirects: healthCheck.hcFollowRedirects,
hcMethod: healthCheck.hcMethod,
hcStatus: hcStatus,
hcTlsServerName: healthCheck.hcTlsServerName,
hcHealthyThreshold: healthCheck.hcHealthyThreshold,
hcUnhealthyThreshold: healthCheck.hcUnhealthyThreshold
}
]
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
}
export async function removeStandaloneHealthCheck(
newtId: string,
healthCheckId: number,
version?: string | null
) {
await sendToClient(
newtId,
{
type: `newt/healthcheck/remove`,
data: {
ids: [healthCheckId]
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
}
export async function removeTargets(
@@ -120,21 +223,29 @@ export async function removeTargets(
}:${target.port}`;
});
await sendToClient(newtId, {
type: `newt/${protocol}/remove`,
data: {
targets: payloadTargets
}
}, { incrementConfigVersion: true });
await sendToClient(
newtId,
{
type: `newt/${protocol}/remove`,
data: {
targets: payloadTargets
}
},
{ incrementConfigVersion: true }
);
const healthCheckTargets = targets.map((target) => {
return target.targetId;
});
await sendToClient(newtId, {
type: `newt/healthcheck/remove`,
data: {
ids: healthCheckTargets
}
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
await sendToClient(
newtId,
{
type: `newt/healthcheck/remove`,
data: {
ids: healthCheckTargets
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
}

View File

@@ -4,6 +4,8 @@ import {
clientSitesAssociationsCache,
db,
exitNodes,
networks,
siteNetworks,
siteResources,
sites
} from "@server/db";
@@ -59,9 +61,17 @@ export async function buildSiteConfigurationForOlmClient(
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
eq(networks.networkId, siteNetworks.networkId)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(siteNetworks.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
@@ -69,6 +79,7 @@ export async function buildSiteConfigurationForOlmClient(
)
);
if (jitMode) {
// Add site configuration to the array
siteConfigurations.push({

View File

@@ -1,104 +1,17 @@
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { getClientConfigVersion } from "#dynamic/routers/ws";
import { db } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, olms, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import { clients, Olm } from "@server/db";
import { eq } from "drizzle-orm";
import { recordClientPing } from "@server/routers/newt/pingAccumulator";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { sendOlmSyncMessage } from "./sync";
import { OlmErrorCodes } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
*/
export const startOlmOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const offlineClients = await db
.update(clients)
.set({ online: false })
.where(
and(
eq(clients.online, true),
or(
lt(clients.lastPing, twoMinutesAgo),
isNull(clients.lastPing)
)
)
)
.returning();
for (const offlineClient of offlineClients) {
logger.info(
`Kicking offline olm client ${offlineClient.clientId} due to inactivity`
);
if (!offlineClient.olmId) {
logger.warn(
`Offline client ${offlineClient.clientId} has no olmId, cannot disconnect`
);
continue;
}
// Send a disconnect message to the client if connected
try {
await sendTerminateClient(
offlineClient.clientId,
OlmErrorCodes.TERMINATED_INACTIVITY,
offlineClient.olmId
); // terminate first
// wait a moment to ensure the message is sent
await new Promise((resolve) => setTimeout(resolve, 1000));
await disconnectClient(offlineClient.olmId);
} catch (error) {
logger.error(
`Error sending disconnect to offline olm ${offlineClient.clientId}`,
{ error }
);
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.debug("Started offline checker interval");
};
/**
* Stops the background interval that checks for offline clients
*/
export const stopOlmOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped offline checker interval");
}
};
/**
* Handles ping messages from clients and responds with pong
*/

View File

@@ -17,7 +17,6 @@ import { getUserDeviceName } from "@server/db/names";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { OlmErrorCodes, sendOlmError } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";

View File

@@ -4,10 +4,12 @@ import {
db,
exitNodes,
Site,
siteResources
siteNetworks,
siteResources,
sites
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db";
import { clients, Olm } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers";
@@ -44,20 +46,31 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
const { siteId, resourceId, chainId } = message.data;
let site: Site | null = null;
const sendCancel = async () => {
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: { chainId }
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
};
let sitesToProcess: Site[] = [];
if (siteId) {
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (siteRes) {
site = siteRes;
sitesToProcess = [siteRes];
}
}
if (resourceId && !site) {
} else if (resourceId) {
const resources = await db
.select()
.from(siteResources)
@@ -72,27 +85,17 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
);
if (!resources || resources.length === 0) {
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.error(
`handleOlmServerInitAddPeerHandshake: Resource not found`
);
await sendCancel();
return;
}
if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
logger.error(
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
`handleOlmServerInitAddPeerHandshake: Multiple resources found matching the criteria`
);
return;
}
@@ -117,125 +120,120 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
if (currentResourceAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
`handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
await sendCancel();
return;
}
const siteIdFromResource = resource.siteId;
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteIdFromResource));
if (!siteRes) {
if (!resource.networkId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site} not found`
`handleOlmServerInitAddPeerHandshake: Resource ${resource.siteResourceId} has no network`
);
await sendCancel();
return;
}
site = siteRes;
// Get all sites associated with this resource's network via siteNetworks
const siteRows = await db
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, resource.networkId));
if (!siteRows || siteRows.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites found for resource ${resource.siteResourceId}`
);
await sendCancel();
return;
}
// Fetch full site objects for all network members
const foundSites = await Promise.all(
siteRows.map(async ({ siteId: sid }) => {
const [s] = await db
.select()
.from(sites)
.where(eq(sites.siteId, sid))
.limit(1);
return s ?? null;
})
);
sitesToProcess = foundSites.filter((s): s is Site => s !== null);
}
if (!site) {
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
if (sitesToProcess.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites to process`
);
await sendCancel();
return;
}
// check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
let handshakeInitiated = false;
if (currentSiteAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
for (const site of sitesToProcess) {
// Check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
if (currentSiteAssociationCaches.length === 0) {
logger.warn(
`handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to site ${site.siteId}, skipping`
);
continue;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Site ${site.siteId} has no exit node, skipping`
);
continue;
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Exit node not found for site ${site.siteId}, skipping`
);
continue;
}
// Trigger the peer add handshake - if the peer was already added this will be a no-op
await initPeerAddHandshake(
client.clientId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
// get the exit node from the side
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
chainId
);
return;
handshakeInitiated = true;
}
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId,
chainId
);
if (!handshakeInitiated) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No accessible sites with valid exit nodes found, cancelling chain`
);
await sendCancel();
}
return;
};

View File

@@ -1,43 +1,25 @@
import {
Client,
clientSiteResourcesAssociationsCache,
db,
ExitNode,
Org,
orgs,
roleClients,
roles,
networks,
siteNetworks,
siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import {
clients,
clientSitesAssociationsCache,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import {
generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers";
export const handleOlmServerPeerAddMessage: MessageHandler = async (
@@ -153,13 +135,21 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
eq(networks.networkId, siteNetworks.networkId),
eq(siteNetworks.siteId, site.siteId)
)
)
.where(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
);

View File

@@ -12,3 +12,4 @@ export * from "./handleOlmUnRelayMessage";
export * from "./recoverOlmWithFingerprint";
export * from "./handleOlmDisconnectingMessage";
export * from "./handleOlmServerInitAddPeerHandshake";
export * from "./offlineChecker";

View File

@@ -0,0 +1,92 @@
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { sendTerminateClient } from "../client/terminate";
import { OlmErrorCodes } from "./error";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
*/
export const startOlmOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const offlineClients = await db
.update(clients)
.set({ online: false })
.where(
and(
eq(clients.online, true),
or(
lt(clients.lastPing, twoMinutesAgo),
isNull(clients.lastPing)
)
)
)
.returning();
for (const offlineClient of offlineClients) {
logger.info(
`Kicking offline olm client ${offlineClient.clientId} due to inactivity`
);
if (!offlineClient.olmId) {
logger.warn(
`Offline client ${offlineClient.clientId} has no olmId, cannot disconnect`
);
continue;
}
// Send a disconnect message to the client if connected
try {
await sendTerminateClient(
offlineClient.clientId,
OlmErrorCodes.TERMINATED_INACTIVITY,
offlineClient.olmId
); // terminate first
// wait a moment to ensure the message is sent
await new Promise((resolve) => setTimeout(resolve, 1000));
await disconnectClient(offlineClient.olmId);
} catch (error) {
logger.error(
`Error sending disconnect to offline olm ${offlineClient.clientId}`,
{ error }
);
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.debug("Started offline checker interval");
};
/**
* Stops the background interval that checks for offline clients
*/
export const stopOlmOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped offline checker interval");
}
};

View File

@@ -0,0 +1,93 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, statusHistory } from "@server/db";
import { and, eq, gte, asc } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import {
computeBuckets,
statusHistoryQuerySchema,
StatusHistoryResponse
} from "@server/lib/statusHistory";
const resourceParamsSchema = z.object({
resourceId: z.string().transform((v) => parseInt(v, 10))
});
export async function getResourceStatusHistory(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = resourceParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedQuery = statusHistoryQuerySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const entityType = "resource";
const entityId = parsedParams.data.resourceId;
const { days } = parsedQuery.data;
const nowSec = Math.floor(Date.now() / 1000);
const startSec = nowSec - days * 86400;
const events = await db
.select()
.from(statusHistory)
.where(
and(
eq(statusHistory.entityType, entityType),
eq(statusHistory.entityId, entityId),
gte(statusHistory.timestamp, startSec)
)
)
.orderBy(asc(statusHistory.timestamp));
const { buckets, totalDowntime } = computeBuckets(events, days);
const totalWindow = days * 86400;
const overallUptime =
totalWindow > 0
? Math.max(
0,
((totalWindow - totalDowntime) / totalWindow) * 100
)
: 100;
return response<StatusHistoryResponse>(res, {
data: {
entityType,
entityId,
days: buckets,
overallUptimePercent: Math.round(overallUptime * 100) / 100,
totalDowntimeSeconds: totalDowntime
},
success: true,
error: false,
message: "Status history retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -145,7 +145,7 @@ export async function getUserResources(
niceId: string;
destination: string;
mode: string;
protocol: string | null;
scheme: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
@@ -158,7 +158,7 @@ export async function getUserResources(
niceId: siteResources.niceId,
destination: siteResources.destination,
mode: siteResources.mode,
protocol: siteResources.protocol,
scheme: siteResources.scheme,
enabled: siteResources.enabled,
alias: siteResources.alias,
aliasAddress: siteResources.aliasAddress
@@ -242,7 +242,7 @@ export async function getUserResources(
name: siteResource.name,
destination: siteResource.destination,
mode: siteResource.mode,
protocol: siteResource.protocol,
protocol: siteResource.scheme,
enabled: siteResource.enabled,
alias: siteResource.alias,
aliasAddress: siteResource.aliasAddress,
@@ -291,7 +291,7 @@ export type GetUserResourcesResponse = {
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
type: 'site';
type: "site";
}>;
};
};

View File

@@ -32,3 +32,4 @@ export * from "./addUserToResource";
export * from "./removeUserFromResource";
export * from "./listAllResourceNames";
export * from "./removeEmailFromResourceWhitelist";
export * from "./getStatusHistory";

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, Site, siteResources } from "@server/db";
import { db, Site, siteNetworks, siteResources } from "@server/db";
import { newts, newtSessions, sites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
@@ -71,18 +71,23 @@ export async function deleteSite(
await deletePeer(site.exitNodeId!, site.pubKey);
}
} else if (site.type == "newt") {
// delete all of the site resources on this site
const siteResourcesOnSite = trx
.delete(siteResources)
.where(eq(siteResources.siteId, siteId))
.returning();
const networks = await trx
.select({ networkId: siteNetworks.networkId })
.from(siteNetworks)
.where(eq(siteNetworks.siteId, siteId));
// loop through them
for (const removedSiteResource of await siteResourcesOnSite) {
await rebuildClientAssociationsFromSiteResource(
removedSiteResource,
trx
);
for (const network of await networks) {
const [siteResource] = await trx
.select()
.from(siteResources)
.where(eq(siteResources.networkId, network.networkId));
if (siteResource) {
await rebuildClientAssociationsFromSiteResource(
siteResource,
trx
);
}
}
// get the newt on the site by querying the newt table for siteId

View File

@@ -0,0 +1,93 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, statusHistory } from "@server/db";
import { and, eq, gte, asc } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import {
computeBuckets,
statusHistoryQuerySchema,
StatusHistoryResponse
} from "@server/lib/statusHistory";
const siteParamsSchema = z.object({
siteId: z.string().transform((v) => parseInt(v, 10))
});
export async function getSiteStatusHistory(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = siteParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedQuery = statusHistoryQuerySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const entityType = "site";
const entityId = parsedParams.data.siteId;
const { days } = parsedQuery.data;
const nowSec = Math.floor(Date.now() / 1000);
const startSec = nowSec - days * 86400;
const events = await db
.select()
.from(statusHistory)
.where(
and(
eq(statusHistory.entityType, entityType),
eq(statusHistory.entityId, entityId),
gte(statusHistory.timestamp, startSec)
)
)
.orderBy(asc(statusHistory.timestamp));
const { buckets, totalDowntime } = computeBuckets(events, days);
const totalWindow = days * 86400;
const overallUptime =
totalWindow > 0
? Math.max(
0,
((totalWindow - totalDowntime) / totalWindow) * 100
)
: 100;
return response<StatusHistoryResponse>(res, {
data: {
entityType,
entityId,
days: buckets,
overallUptimePercent: Math.round(overallUptime * 100) / 100,
totalDowntimeSeconds: totalDowntime
},
success: true,
error: false,
message: "Status history retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -1,4 +1,5 @@
export * from "./getSite";
export * from "./getStatusHistory";
export * from "./createSite";
export * from "./deleteSite";
export * from "./updateSite";

View File

@@ -5,6 +5,8 @@ import {
orgs,
roles,
roleSiteResources,
siteNetworks,
networks,
SiteResource,
siteResources,
sites,
@@ -17,17 +19,18 @@ import {
portRangeStringSchema
} from "@server/lib/ip";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
const createSiteResourceParamsSchema = z.strictObject({
orgId: z.string()
@@ -36,12 +39,14 @@ const createSiteResourceParamsSchema = z.strictObject({
const createSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]),
siteId: z.int(),
niceId: z.string().optional(),
// protocol: z.enum(["tcp", "udp"]).optional(),
mode: z.enum(["host", "cidr", "http"]),
ssl: z.boolean().optional(), // only used for http mode
scheme: z.enum(["http", "https"]).optional(),
siteIds: z.array(z.int()),
// proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(),
destinationPort: z.int().positive().optional(),
destination: z.string().min(1),
enabled: z.boolean().default(true),
alias: z
@@ -58,20 +63,24 @@ const createSiteResourceSchema = z
udpPortRangeString: portRangeStringSchema,
disableIcmp: z.boolean().optional(),
authDaemonPort: z.int().positive().optional(),
authDaemonMode: z.enum(["site", "remote"]).optional()
authDaemonMode: z.enum(["site", "remote"]).optional(),
domainId: z.string().optional(), // only used for http mode, we need this to verify the alias is unique within the org
subdomain: z.string().optional() // only used for http mode, we need this to verify the alias is unique within the org
})
.strict()
.refine(
(data) => {
if (data.mode === "host") {
// Check if it's a valid IP address using zod (v4 or v6)
const isValidIP = z
// .union([z.ipv4(), z.ipv6()])
.union([z.ipv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere
.safeParse(data.destination).success;
if (data.mode == "host") {
// Check if it's a valid IP address using zod (v4 or v6)
const isValidIP = z
// .union([z.ipv4(), z.ipv6()])
.union([z.ipv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere
.safeParse(data.destination).success;
if (isValidIP) {
return true;
if (isValidIP) {
return true;
}
}
// Check if it's a valid domain (hostname pattern, TLD not required)
@@ -106,6 +115,21 @@ const createSiteResourceSchema = z
{
message: "Destination must be a valid CIDR notation for cidr mode"
}
)
.refine(
(data) => {
if (data.mode !== "http") return true;
return (
data.scheme !== undefined &&
data.destinationPort !== undefined &&
data.destinationPort >= 1 &&
data.destinationPort <= 65535
);
},
{
message:
"HTTP mode requires scheme (http or https) and a valid destination port"
}
);
export type CreateSiteResourceBody = z.infer<typeof createSiteResourceSchema>;
@@ -160,14 +184,15 @@ export async function createSiteResource(
const { orgId } = parsedParams.data;
const {
name,
siteId,
niceId,
siteIds,
mode,
// protocol,
scheme,
// proxyPort,
// destinationPort,
destinationPort,
destination,
enabled,
ssl,
alias,
userIds,
roleIds,
@@ -176,18 +201,36 @@ export async function createSiteResource(
udpPortRangeString,
disableIcmp,
authDaemonPort,
authDaemonMode
authDaemonMode,
domainId,
subdomain
} = parsedBody.data;
if (mode == "http") {
const hasHttpFeature = await isLicensedOrSubscribed(
orgId,
tierMatrix[TierFeature.HTTPPrivateResources]
);
if (!hasHttpFeature) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"HTTP private resources are not included in your current plan. Please upgrade."
)
);
}
}
// Verify the site exists and belongs to the org
const [site] = await db
const sitesToAssign = await db
.select()
.from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1);
.where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId)));
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
const [org] = await db
@@ -228,29 +271,50 @@ export async function createSiteResource(
);
}
// // check if resource with same protocol and proxy port already exists (only for port mode)
// if (mode === "port" && protocol && proxyPort) {
// const [existingResource] = await db
// .select()
// .from(siteResources)
// .where(
// and(
// eq(siteResources.siteId, siteId),
// eq(siteResources.orgId, orgId),
// eq(siteResources.protocol, protocol),
// eq(siteResources.proxyPort, proxyPort)
// )
// )
// .limit(1);
// if (existingResource && existingResource.siteResourceId) {
// return next(
// createHttpError(
// HttpCode.CONFLICT,
// "A resource with the same protocol and proxy port already exists"
// )
// );
// }
// }
if (domainId && alias) {
// throw an error because we can only have one or the other
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Alias and domain cannot both be set. Please choose one or the other."
)
);
}
let fullDomain: string | null = null;
let finalSubdomain: string | null = null;
if (domainId) {
// Validate domain and construct full domain
const domainResult = await validateAndConstructDomain(
domainId,
orgId,
subdomain
);
if (!domainResult.success) {
return next(
createHttpError(HttpCode.BAD_REQUEST, domainResult.error)
);
}
fullDomain = domainResult.fullDomain;
finalSubdomain = domainResult.subdomain;
// make sure the full domain is unique
const existingResource = await db
.select()
.from(siteResources)
.where(eq(siteResources.fullDomain, fullDomain));
if (existingResource.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource with that domain already exists"
)
);
}
}
// make sure the alias is unique within the org if provided
if (alias) {
@@ -286,27 +350,49 @@ export async function createSiteResource(
}
let aliasAddress: string | null = null;
if (mode == "host") {
// we can only have an alias on a host
if (mode === "host" || mode === "http") {
aliasAddress = await getNextAvailableAliasAddress(orgId);
}
let newSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => {
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
if (!network) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to create network`
)
);
}
// Create the site resource
const insertValues: typeof siteResources.$inferInsert = {
siteId,
niceId: updatedNiceId!,
orgId,
name,
mode: mode as "host" | "cidr",
mode,
ssl,
networkId: network.networkId,
destination,
scheme,
destinationPort,
enabled,
alias,
alias: alias ? alias.trim() : null,
aliasAddress,
tcpPortRangeString,
udpPortRangeString,
disableIcmp
disableIcmp,
domainId,
subdomain: finalSubdomain,
fullDomain
};
if (isLicensedSshPam) {
if (authDaemonPort !== undefined)
@@ -323,6 +409,13 @@ export async function createSiteResource(
//////////////////// update the associations ////////////////////
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx
.select()
.from(roles)
@@ -365,16 +458,21 @@ export async function createSiteResource(
);
}
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
for (const siteToAssign of sitesToAssign) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, siteToAssign.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
if (!newt) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Newt not found for site ${siteToAssign.siteId}`
)
);
}
}
await rebuildClientAssociationsFromSiteResource(
@@ -393,7 +491,7 @@ export async function createSiteResource(
}
logger.info(
`Created site resource ${newSiteResource.siteResourceId} for site ${siteId}`
`Created site resource ${newSiteResource.siteResourceId} for org ${orgId}`
);
return response(res, {

View File

@@ -70,17 +70,18 @@ export async function deleteSiteResource(
.where(and(eq(siteResources.siteResourceId, siteResourceId)))
.returning();
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, removedSiteResource.siteId))
.limit(1);
// not sure why this is here...
// const [newt] = await trx
// .select()
// .from(newts)
// .where(eq(newts.siteId, removedSiteResource.siteId))
// .limit(1);
if (!newt) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
// if (!newt) {
// return next(
// createHttpError(HttpCode.NOT_FOUND, "Newt not found")
// );
// }
await rebuildClientAssociationsFromSiteResource(
removedSiteResource,

View File

@@ -17,38 +17,34 @@ const getSiteResourceParamsSchema = z.strictObject({
.transform((val) => (val ? Number(val) : undefined))
.pipe(z.int().positive().optional())
.optional(),
siteId: z.string().transform(Number).pipe(z.int().positive()),
niceId: z.string().optional(),
orgId: z.string()
});
async function query(
siteResourceId?: number,
siteId?: number,
niceId?: string,
orgId?: string
) {
if (siteResourceId && siteId && orgId) {
if (siteResourceId && orgId) {
const [siteResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
.limit(1);
return siteResource;
} else if (niceId && siteId && orgId) {
} else if (niceId && orgId) {
const [siteResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.niceId, niceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
@@ -84,7 +80,6 @@ registry.registerPath({
request: {
params: z.object({
niceId: z.string(),
siteId: z.number(),
orgId: z.string()
})
},
@@ -107,10 +102,10 @@ export async function getSiteResource(
);
}
const { siteResourceId, siteId, niceId, orgId } = parsedParams.data;
const { siteResourceId, niceId, orgId } = parsedParams.data;
// Get the site resource
const siteResource = await query(siteResourceId, siteId, niceId, orgId);
const siteResource = await query(siteResourceId, niceId, orgId);
if (!siteResource) {
return next(

View File

@@ -1,4 +1,4 @@
import { db, SiteResource, siteResources, sites } from "@server/db";
import { db, DB_TYPE, SiteResource, siteNetworks, siteResources, sites } from "@server/db";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
@@ -41,12 +41,12 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
}),
query: z.string().optional(),
mode: z
.enum(["host", "cidr"])
.enum(["host", "cidr", "http"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["host", "cidr"],
enum: ["host", "cidr", "http"],
description: "Filter site resources by mode"
}),
sort_by: z
@@ -73,22 +73,58 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
siteResources: (SiteResource & {
siteName: string;
siteNiceId: string;
siteAddress: string | null;
siteOnlines: boolean[];
siteIds: number[];
siteNames: string[];
siteNiceIds: string[];
siteAddresses: (string | null)[];
})[];
}>;
/**
* Returns an aggregation expression compatible with both SQLite and PostgreSQL.
* - SQLite: json_group_array(col) → returns a JSON array string, parsed after fetch
* - PostgreSQL: array_agg(col) → returns a native array
*/
function aggCol<T>(column: any) {
if (DB_TYPE === "sqlite") {
return sql<T>`json_group_array(${column})`;
}
return sql<T>`array_agg(${column})`;
}
/**
* For SQLite the aggregated columns come back as JSON strings; parse them into
* proper arrays. For PostgreSQL the driver already returns native arrays, so
* the row is returned unchanged.
*/
function transformSiteResourceRow(row: any) {
if (DB_TYPE !== "sqlite") {
return row;
}
return {
...row,
siteNames: JSON.parse(row.siteNames) as string[],
siteNiceIds: JSON.parse(row.siteNiceIds) as string[],
siteIds: JSON.parse(row.siteIds) as number[],
siteAddresses: JSON.parse(row.siteAddresses) as (string | null)[],
// SQLite stores booleans as 0/1 integers
siteOnlines: (JSON.parse(row.siteOnlines) as (0 | 1)[]).map(
(v) => v === 1
) as boolean[]
};
}
function querySiteResourcesBase() {
return db
.select({
siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId,
niceId: siteResources.niceId,
name: siteResources.name,
mode: siteResources.mode,
protocol: siteResources.protocol,
ssl: siteResources.ssl,
scheme: siteResources.scheme,
proxyPort: siteResources.proxyPort,
destinationPort: siteResources.destinationPort,
destination: siteResources.destination,
@@ -100,12 +136,24 @@ function querySiteResourcesBase() {
disableIcmp: siteResources.disableIcmp,
authDaemonMode: siteResources.authDaemonMode,
authDaemonPort: siteResources.authDaemonPort,
siteName: sites.name,
siteNiceId: sites.niceId,
siteAddress: sites.address
subdomain: siteResources.subdomain,
domainId: siteResources.domainId,
fullDomain: siteResources.fullDomain,
networkId: siteResources.networkId,
defaultNetworkId: siteResources.defaultNetworkId,
siteNames: aggCol<string[]>(sites.name),
siteNiceIds: aggCol<string[]>(sites.niceId),
siteIds: aggCol<number[]>(sites.siteId),
siteAddresses: aggCol<(string | null)[]>(sites.address),
siteOnlines: aggCol<boolean[]>(sites.online)
})
.from(siteResources)
.innerJoin(sites, eq(siteResources.siteId, sites.siteId));
.innerJoin(
siteNetworks,
eq(siteResources.networkId, siteNetworks.networkId)
)
.innerJoin(sites, eq(siteNetworks.siteId, sites.siteId))
.groupBy(siteResources.siteResourceId);
}
registry.registerPath({
@@ -193,10 +241,12 @@ export async function listAllSiteResourcesByOrg(
const baseQuery = querySiteResourcesBase().where(and(...conditions));
const countQuery = db.$count(
querySiteResourcesBase().where(and(...conditions)).as("filtered_site_resources")
querySiteResourcesBase()
.where(and(...conditions))
.as("filtered_site_resources")
);
const [siteResourcesList, totalCount] = await Promise.all([
const [siteResourcesRaw, totalCount] = await Promise.all([
baseQuery
.limit(pageSize)
.offset(pageSize * (page - 1))
@@ -210,6 +260,8 @@ export async function listAllSiteResourcesByOrg(
countQuery
]);
const siteResourcesList = siteResourcesRaw.map(transformSiteResourceRow);
return response<ListAllSiteResourcesByOrgResponse>(res, {
data: {
siteResources: siteResourcesList,
@@ -233,4 +285,4 @@ export async function listAllSiteResourcesByOrg(
)
);
}
}
}

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { db, networks, siteNetworks } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -108,13 +108,21 @@ export async function listSiteResources(
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Get site resources
// Get site resources by joining networks to siteResources via siteNetworks
const siteResourcesList = await db
.select()
.from(siteResources)
.from(siteNetworks)
.innerJoin(
networks,
eq(siteNetworks.networkId, networks.networkId)
)
.innerJoin(
siteResources,
eq(siteResources.networkId, networks.networkId)
)
.where(
and(
eq(siteResources.siteId, siteId),
eq(siteNetworks.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
@@ -128,6 +136,7 @@ export async function listSiteResources(
.limit(limit)
.offset(offset);
return response(res, {
data: { siteResources: siteResourcesList },
success: true,

View File

@@ -1,4 +1,3 @@
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import {
clientSiteResources,
clientSiteResourcesAssociationsCache,
@@ -7,13 +6,21 @@ import {
orgs,
roles,
roleSiteResources,
siteNetworks,
SiteResource,
siteResources,
sites,
networks,
Transaction,
userSiteResources
} from "@server/db";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
import response from "@server/lib/response";
import { eq, and, ne, inArray } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import {
generateAliasConfig,
generateRemoteSubnets,
@@ -22,12 +29,8 @@ import {
portRangeStringSchema
} from "@server/lib/ip";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import HttpCode from "@server/types/HttpCode";
import { and, eq, ne } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
@@ -40,7 +43,8 @@ const updateSiteResourceParamsSchema = z.strictObject({
const updateSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255).optional(),
siteId: z.int(),
siteIds: z.array(z.int()),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
niceId: z
.string()
.min(1)
@@ -51,10 +55,11 @@ const updateSiteResourceSchema = z
)
.optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(),
// protocol: z.enum(["tcp", "udp"]).nullish(),
mode: z.enum(["host", "cidr", "http"]).optional(),
ssl: z.boolean().optional(),
scheme: z.enum(["http", "https"]).nullish(),
// proxyPort: z.int().positive().nullish(),
// destinationPort: z.int().positive().nullish(),
destinationPort: z.int().positive().nullish(),
destination: z.string().min(1).optional(),
enabled: z.boolean().optional(),
alias: z
@@ -71,7 +76,9 @@ const updateSiteResourceSchema = z
udpPortRangeString: portRangeStringSchema,
disableIcmp: z.boolean().optional(),
authDaemonPort: z.int().positive().nullish(),
authDaemonMode: z.enum(["site", "remote"]).optional()
authDaemonMode: z.enum(["site", "remote"]).optional(),
domainId: z.string().optional(),
subdomain: z.string().optional()
})
.strict()
.refine(
@@ -118,6 +125,23 @@ const updateSiteResourceSchema = z
{
message: "Destination must be a valid CIDR notation for cidr mode"
}
)
.refine(
(data) => {
if (data.mode !== "http") return true;
return (
data.scheme !== undefined &&
data.scheme !== null &&
data.destinationPort !== undefined &&
data.destinationPort !== null &&
data.destinationPort >= 1 &&
data.destinationPort <= 65535
);
},
{
message:
"HTTP mode requires scheme (http or https) and a valid destination port"
}
);
export type UpdateSiteResourceBody = z.infer<typeof updateSiteResourceSchema>;
@@ -172,11 +196,14 @@ export async function updateSiteResource(
const { siteResourceId } = parsedParams.data;
const {
name,
siteId, // because it can change
siteIds, // because it can change
niceId,
mode,
scheme,
destination,
destinationPort,
alias,
ssl,
enabled,
userIds,
roleIds,
@@ -185,19 +212,11 @@ export async function updateSiteResource(
udpPortRangeString,
disableIcmp,
authDaemonPort,
authDaemonMode
authDaemonMode,
domainId,
subdomain
} = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists
const [existingSiteResource] = await db
.select()
@@ -211,6 +230,21 @@ export async function updateSiteResource(
);
}
if (mode == "http") {
const hasHttpFeature = await isLicensedOrSubscribed(
existingSiteResource.orgId,
tierMatrix[TierFeature.HTTPPrivateResources]
);
if (!hasHttpFeature) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"HTTP private resources are not included in your current plan. Please upgrade."
)
);
}
}
const isLicensedSshPam = await isLicensedOrSubscribed(
existingSiteResource.orgId,
tierMatrix.sshPam
@@ -237,6 +271,23 @@ export async function updateSiteResource(
);
}
// Verify the site exists and belongs to the org
const sitesToAssign = await db
.select()
.from(sites)
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.orgId, existingSiteResource.orgId)
)
);
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
// Only check if destination is an IP address
const isIp = z
.union([z.ipv4(), z.ipv6()])
@@ -254,22 +305,60 @@ export async function updateSiteResource(
);
}
let existingSite = site;
let siteChanged = false;
if (existingSiteResource.siteId !== siteId) {
siteChanged = true;
// get the existing site
[existingSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, existingSiteResource.siteId))
.limit(1);
let sitesChanged = false;
const existingSiteIds = existingSiteResource.networkId
? await db
.select()
.from(siteNetworks)
.where(
eq(siteNetworks.networkId, existingSiteResource.networkId)
)
: [];
if (!existingSite) {
const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId));
const newSiteIdSet = new Set(siteIds);
if (
existingSiteIdSet.size !== newSiteIdSet.size ||
![...existingSiteIdSet].every((id) => newSiteIdSet.has(id))
) {
sitesChanged = true;
}
let fullDomain: string | null = null;
let finalSubdomain: string | null = null;
if (domainId) {
// Validate domain and construct full domain
const domainResult = await validateAndConstructDomain(
domainId,
org.orgId,
subdomain
);
if (!domainResult.success) {
return next(
createHttpError(HttpCode.BAD_REQUEST, domainResult.error)
);
}
fullDomain = domainResult.fullDomain;
finalSubdomain = domainResult.subdomain;
// make sure the full domain is unique
const [existingDomain] = await db
.select()
.from(siteResources)
.where(eq(siteResources.fullDomain, fullDomain));
if (
existingDomain &&
existingDomain.siteResourceId !==
existingSiteResource.siteResourceId
) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Existing site not found"
HttpCode.CONFLICT,
"Resource with that domain already exists"
)
);
}
@@ -302,7 +391,7 @@ export async function updateSiteResource(
let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => {
// if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place
if (siteChanged) {
if (sitesChanged) {
// delete the existing site resource
await trx
.delete(siteResources)
@@ -343,15 +432,20 @@ export async function updateSiteResource(
.update(siteResources)
.set({
name,
siteId,
niceId,
mode,
scheme,
ssl,
destination,
destinationPort,
enabled,
alias: alias && alias.trim() ? alias : null,
alias: alias ? alias.trim() : null,
tcpPortRangeString,
udpPortRangeString,
disableIcmp,
domainId,
subdomain: finalSubdomain,
fullDomain,
...sshPamSet
})
.where(
@@ -372,6 +466,23 @@ export async function updateSiteResource(
//////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(
siteNetworks.networkId,
updatedSiteResource.networkId!
)
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
});
}
const [adminRole] = await trx
.select()
.from(roles)
@@ -447,14 +558,20 @@ export async function updateSiteResource(
.update(siteResources)
.set({
name: name,
siteId: siteId,
niceId: niceId,
mode: mode,
scheme,
ssl,
destination: destination,
destinationPort: destinationPort,
enabled: enabled,
alias: alias && alias.trim() ? alias : null,
alias: alias ? alias.trim() : null,
tcpPortRangeString: tcpPortRangeString,
udpPortRangeString: udpPortRangeString,
disableIcmp: disableIcmp,
domainId,
subdomain: finalSubdomain,
fullDomain,
...sshPamSet
})
.where(
@@ -464,6 +581,23 @@ export async function updateSiteResource(
//////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(
siteNetworks.networkId,
updatedSiteResource.networkId!
)
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
});
}
await trx
.delete(clientSiteResources)
.where(
@@ -533,14 +667,15 @@ export async function updateSiteResource(
);
}
logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}`
);
logger.info(`Updated site resource ${siteResourceId}`);
await handleMessagingForUpdatedSiteResource(
existingSiteResource,
updatedSiteResource,
{ siteId: site.siteId, orgId: site.orgId },
siteIds.map((siteId) => ({
siteId,
orgId: existingSiteResource.orgId
})),
trx
);
}
@@ -567,7 +702,7 @@ export async function updateSiteResource(
export async function handleMessagingForUpdatedSiteResource(
existingSiteResource: SiteResource | undefined,
updatedSiteResource: SiteResource,
site: { siteId: number; orgId: string },
sites: { siteId: number; orgId: string }[],
trx: Transaction
) {
logger.debug(
@@ -589,9 +724,14 @@ export async function handleMessagingForUpdatedSiteResource(
const destinationChanged =
existingSiteResource &&
existingSiteResource.destination !== updatedSiteResource.destination;
const destinationPortChanged =
existingSiteResource &&
existingSiteResource.destinationPort !==
updatedSiteResource.destinationPort;
const aliasChanged =
existingSiteResource &&
existingSiteResource.alias !== updatedSiteResource.alias;
(existingSiteResource.alias !== updatedSiteResource.alias ||
existingSiteResource.fullDomain !== updatedSiteResource.fullDomain); // because the full domain gets sent down to the stuff as an alias
const portRangesChanged =
existingSiteResource &&
(existingSiteResource.tcpPortRangeString !==
@@ -603,106 +743,122 @@ export async function handleMessagingForUpdatedSiteResource(
// if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all
if (destinationChanged || aliasChanged || portRangesChanged) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
throw new Error(
"Newt not found for site during site resource update"
);
}
// Only update targets on newt if destination changed
if (destinationChanged || portRangesChanged) {
const oldTargets = generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTargets = generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
await updateTargets(
newt.newtId,
{
oldTargets: oldTargets ? oldTargets : [],
newTargets: newTargets ? newTargets : []
},
newt.version
);
}
const olmJobs: Promise<void>[] = [];
for (const client of mergedAllClients) {
// does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet
// todo: optimize this query if needed
const oldDestinationStillInUseSites = await trx
if (
destinationChanged ||
aliasChanged ||
portRangesChanged ||
destinationPortChanged
) {
for (const site of sites) {
const [newt] = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteResources.siteId, site.siteId),
eq(
siteResources.destination,
existingSiteResource.destination
),
ne(
siteResources.siteResourceId,
existingSiteResource.siteResourceId
)
)
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
throw new Error(
"Newt not found for site during site resource update"
);
}
// Only update targets on newt if destination changed
if (
destinationChanged ||
portRangesChanged ||
destinationPortChanged
) {
const oldTargets = await generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTargets = await generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;
await updateTargets(
newt.newtId,
{
oldTargets: oldTargets ? oldTargets : [],
newTargets: newTargets ? newTargets : []
},
newt.version
);
}
// we also need to update the remote subnets on the olms for each client that has access to this site
olmJobs.push(
updatePeerData(
client.clientId,
updatedSiteResource.siteId,
destinationChanged
? {
oldRemoteSubnets: !oldDestinationStillInUseByASite
? generateRemoteSubnets([
existingSiteResource
])
: [],
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
}
: undefined,
aliasChanged
? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
: undefined
)
);
const olmJobs: Promise<void>[] = [];
for (const client of mergedAllClients) {
// does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet
// todo: optimize this query if needed
const oldDestinationStillInUseSites = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
)
)
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteNetworks.siteId, site.siteId),
eq(
siteResources.destination,
existingSiteResource.destination
),
ne(
siteResources.siteResourceId,
existingSiteResource.siteResourceId
)
)
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;
// we also need to update the remote subnets on the olms for each client that has access to this site
olmJobs.push(
updatePeerData(
client.clientId,
site.siteId,
destinationChanged
? {
oldRemoteSubnets:
!oldDestinationStillInUseByASite
? generateRemoteSubnets([
existingSiteResource
])
: [],
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
}
: undefined,
aliasChanged
? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
: undefined
)
);
}
await Promise.all(olmJobs);
}
await Promise.all(olmJobs);
}
}

View File

@@ -42,6 +42,8 @@ const createTargetSchema = z.strictObject({
hcMethod: z.string().min(1).optional().nullable(),
hcStatus: z.int().optional().nullable(),
hcTlsServerName: z.string().optional().nullable(),
hcHealthyThreshold: z.int().positive().min(1).optional().nullable(),
hcUnhealthyThreshold: z.int().positive().min(1).optional().nullable(),
path: z.string().optional().nullable(),
pathMatchType: z.enum(["exact", "prefix", "regex"]).optional().nullable(),
rewritePath: z.string().optional().nullable(),
@@ -226,7 +228,10 @@ export async function createTarget(
healthCheck = await db
.insert(targetHealthCheck)
.values({
orgId: resource.orgId,
targetId: newTarget[0].targetId,
siteId: targetData.siteId,
name: `Resource ${resource.name} - ${targetData.ip}:${targetData.port}`,
hcEnabled: targetData.hcEnabled ?? false,
hcPath: targetData.hcPath ?? null,
hcScheme: targetData.hcScheme ?? null,
@@ -241,7 +246,9 @@ export async function createTarget(
hcMethod: targetData.hcMethod ?? null,
hcStatus: targetData.hcStatus ?? null,
hcHealth: "unknown",
hcTlsServerName: targetData.hcTlsServerName ?? null
hcTlsServerName: targetData.hcTlsServerName ?? null,
hcHealthyThreshold: targetData.hcHealthyThreshold ?? null,
hcUnhealthyThreshold: targetData.hcUnhealthyThreshold ?? null
})
.returning();
@@ -271,8 +278,8 @@ export async function createTarget(
return response<CreateTargetResponse>(res, {
data: {
...newTarget[0],
...healthCheck[0]
...healthCheck[0],
...newTarget[0]
},
success: true,
error: false,

View File

@@ -15,8 +15,8 @@ const getTargetSchema = z.strictObject({
});
type GetTargetResponse = Target &
Omit<TargetHealthCheck, "hcHeaders"> & {
hcHeaders: { name: string; value: string }[] | null;
Partial<Omit<TargetHealthCheck, "hcHeaders" | "targetId">> & {
hcHeaders: { name: string; value: string }[] | null | undefined;
};
registry.registerPath({
@@ -70,20 +70,19 @@ export async function getTarget(
.limit(1);
// Parse hcHeaders from JSON string back to array
let parsedHcHeaders = null;
let parsedHcHeaders: { name: string; value: string }[] | null = null;
if (targetHc?.hcHeaders) {
try {
parsedHcHeaders = JSON.parse(targetHc.hcHeaders);
} catch (error) {
// If parsing fails, keep as string for backward compatibility
parsedHcHeaders = targetHc.hcHeaders;
// If parsing fails, keep as null for safety
}
}
return response<GetTargetResponse>(res, {
data: {
...target[0],
...targetHc,
...target[0],
hcHeaders: parsedHcHeaders
},
success: true,

View File

@@ -1,9 +1,19 @@
import { db, targets, resources, sites, targetHealthCheck } from "@server/db";
import {
db,
targets,
resources,
sites,
targetHealthCheck,
statusHistory
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { Newt } from "@server/db";
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { unknown } from "zod";
import {
fireHealthCheckHealthyAlert,
fireHealthCheckNotHealthyAlert
} from "#dynamic/lib/alerts";
interface TargetHealthStatus {
status: string;
@@ -11,7 +21,7 @@ interface TargetHealthStatus {
checkCount: number;
lastError?: string;
config: {
id: string;
id: string; // this could be the hc id or the target id, depending on the version of newt
hcEnabled: boolean;
hcPath?: string;
hcScheme?: string;
@@ -22,7 +32,11 @@ interface TargetHealthStatus {
hcUnhealthyInterval?: number;
hcTimeout?: number;
hcHeaders?: any;
hcFollowRedirects?: boolean;
hcMethod?: string;
hcTlsServerName?: string;
hcHealthyThreshold?: number;
hcUnhealthyThreshold?: number;
};
}
@@ -78,18 +92,26 @@ export const handleHealthcheckStatusMessage: MessageHandler = async (
.select({
targetId: targets.targetId,
siteId: targets.siteId,
orgId: targetHealthCheck.orgId,
targetHealthCheckId: targetHealthCheck.targetHealthCheckId,
resourceOrgId: resources.orgId,
resourceId: resources.resourceId,
name: targetHealthCheck.name,
hcStatus: targetHealthCheck.hcHealth
})
.from(targets)
.from(targetHealthCheck)
.innerJoin(
targets,
eq(targetHealthCheck.targetId, targets.targetId)
)
.innerJoin(
resources,
eq(targets.resourceId, resources.resourceId)
)
.innerJoin(sites, eq(targets.siteId, sites.siteId))
.innerJoin(targetHealthCheck, eq(targets.targetId, targetHealthCheck.targetId))
.where(
and(
eq(targets.targetId, targetIdNum),
eq(targetHealthCheck.targetHealthCheckId, targetIdNum),
eq(sites.siteId, newt.siteId)
)
)
@@ -120,8 +142,80 @@ export const handleHealthcheckStatusMessage: MessageHandler = async (
| "healthy"
| "unhealthy"
})
.where(eq(targetHealthCheck.targetId, targetIdNum))
.execute();
.where(eq(targetHealthCheck.targetId, targetCheck.targetId));
const orgId = targetCheck.orgId || targetCheck.resourceOrgId; // for backwards compatibility, check both orgId fields because the target health checks dont have the orgId
if (!orgId) {
logger.warn(
`No org ID found for target ${targetId}, skipping status history logging`
);
continue;
}
// Log the state change to status history
await db.insert(statusHistory).values({
entityType: "healthCheck",
entityId: targetCheck.targetHealthCheckId,
orgId: orgId,
status: healthStatus.status,
timestamp: Math.floor(Date.now() / 1000)
});
if (targetCheck.resourceId) {
// Log the state change to status history for the resource as well
// so we can show the resource status along with the site
// if the status is healthy we should check if ALL of the targets on the resource are currently healthy and if not then dont mark the resource as healthy yet, we want to wait until all targets are healthy to mark the resource as healthy
let status = healthStatus.status;
if (healthStatus.status === "healthy") {
const otherTargets = await db
.select({ hcHealth: targetHealthCheck.hcHealth })
.from(targets)
.innerJoin(
targetHealthCheck,
eq(targets.targetId, targetHealthCheck.targetId)
)
.where(
and(
eq(targets.resourceId, targetCheck.resourceId),
eq(targets.targetId, targetCheck.targetId) // only check the other targets, not the one we just updated
)
);
const allHealthy = otherTargets.every(
(t) => t.hcHealth === "healthy"
);
if (!allHealthy) {
logger.debug(
`Not marking resource ${targetCheck.resourceId} as healthy because not all targets are healthy`
);
status = "unhealthy";
}
}
await db.insert(statusHistory).values({
entityType: "resource",
entityId: targetCheck.resourceId,
orgId: orgId,
status: status,
timestamp: Math.floor(Date.now() / 1000)
});
}
// because we are checking above if there was a change we can fire the alert here because it changed
if (healthStatus.status === "unhealthy") {
await fireHealthCheckHealthyAlert(
orgId,
targetCheck.targetHealthCheckId,
targetCheck.name
);
} else if (healthStatus.status === "healthy") {
await fireHealthCheckNotHealthyAlert(
orgId,
targetCheck.targetHealthCheckId,
targetCheck.name
);
}
logger.debug(
`Updated health status for target ${targetId} to ${healthStatus.status}`

View File

@@ -43,6 +43,8 @@ const updateTargetBodySchema = z
hcMethod: z.string().min(1).optional().nullable(),
hcStatus: z.int().optional().nullable(),
hcTlsServerName: z.string().optional().nullable(),
hcHealthyThreshold: z.int().positive().min(1).optional().nullable(),
hcUnhealthyThreshold: z.int().positive().min(1).optional().nullable(),
path: z.string().optional().nullable(),
pathMatchType: z
.enum(["exact", "prefix", "regex"])
@@ -226,6 +228,7 @@ export async function updateTarget(
const [updatedHc] = await db
.update(targetHealthCheck)
.set({
siteId: parsedBody.data.siteId,
hcEnabled: parsedBody.data.hcEnabled || false,
hcPath: parsedBody.data.hcPath,
hcScheme: parsedBody.data.hcScheme,
@@ -240,6 +243,8 @@ export async function updateTarget(
hcMethod: parsedBody.data.hcMethod,
hcStatus: parsedBody.data.hcStatus,
hcTlsServerName: parsedBody.data.hcTlsServerName,
hcHealthyThreshold: parsedBody.data.hcHealthyThreshold,
hcUnhealthyThreshold: parsedBody.data.hcUnhealthyThreshold,
...(hcHealthValue !== undefined && { hcHealth: hcHealthValue })
})
.where(eq(targetHealthCheck.targetId, targetId))

View File

@@ -2,7 +2,7 @@ import { build } from "@server/build";
import {
handleNewtRegisterMessage,
handleReceiveBandwidthMessage,
handleGetConfigMessage,
handleNewtGetConfigMessage,
handleDockerStatusMessage,
handleDockerContainersMessage,
handleNewtPingRequestMessage,
@@ -37,7 +37,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
"newt/disconnecting": handleNewtDisconnectingMessage,
"newt/ping": handleNewtPingMessage,
"newt/wg/register": handleNewtRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/wg/get-config": handleNewtGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"newt/socket/status": handleDockerStatusMessage,
"newt/socket/containers": handleDockerContainersMessage,
@@ -47,7 +47,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
"ws/round-trip/complete": handleRoundTripMessage
};
// Start the ping accumulator for all builds it batches per-site online/lastPing
// Start the ping accumulator for all builds - it batches per-site online/lastPing
// updates into periodic bulk writes, preventing connection pool exhaustion.
startPingAccumulator();