Compare commits

...

14 Commits

Author SHA1 Message Date
Owen
05748bf8ff Merge branch 'dev' into msg-delivery 2026-01-16 12:22:23 -08:00
Owen
f8c98bf6bf Fix log messages 2026-01-16 12:19:52 -08:00
Owen
a1ea3f74b3 Move the query into the sync 2026-01-15 22:00:13 -08:00
Owen
65e8bfc93e Message syncing works 2026-01-15 21:26:13 -08:00
Owen
d52bd65d21 Fix build 2026-01-14 17:54:34 -08:00
Owen
eb0cdda0f9 Merge branch 'dev' into msg-delivery 2026-01-12 21:17:38 -08:00
Owen
eba25fcc4d Add increment options and slight cleanup 2026-01-12 20:48:18 -08:00
Owen
0ccd5714f9 Seperating out functions 2025-12-24 11:50:27 -05:00
Owen
e2dfc3eb20 Merge branch 'dev' into msg-delivery 2025-12-24 11:33:41 -05:00
Owen
446eba8bc9 Orging how we are going to make the sync 2025-12-24 10:38:44 -05:00
Owen
18579c0647 Merge branch 'dev' into msg-delivery 2025-12-23 16:57:17 -05:00
Owen
0d37e08638 Merge branch 'dev' into msg-delivery 2025-12-23 16:56:50 -05:00
Owen
75b9703793 Seperate config gen into functions 2025-12-20 11:41:23 -05:00
Owen
322f3bfb1d Add version and send it down 2025-12-19 16:44:57 -05:00
21 changed files with 1116 additions and 500 deletions

View File

@@ -26,6 +26,7 @@ export function initLogCleanupInterval() {
) )
); );
// TODO: handle when there are multiple nodes doing this clearing using redis
for (const org of orgsToClean) { for (const org of orgsToClean) {
const { const {
orgId, orgId,

View File

@@ -50,10 +50,14 @@ export async function sendToExitNode(
); );
} }
return sendToClient(remoteExitNode.remoteExitNodeId, { return sendToClient(
type: request.remoteType, remoteExitNode.remoteExitNodeId,
data: request.data {
}); type: request.remoteType,
data: request.data
},
{ incrementConfigVersion: true }
);
} else { } else {
let hostname = exitNode.reachableAt; let hostname = exitNode.reachableAt;

View File

@@ -573,6 +573,20 @@ class RedisManager {
} }
} }
public async incr(key: string): Promise<number> {
if (!this.isRedisEnabled() || !this.writeClient) return 0;
try {
return await this.executeWithRetry(
() => this.writeClient!.incr(key),
"Redis INCR"
);
} catch (error) {
logger.error("Redis INCR error:", error);
return 0;
}
}
public async sadd(key: string, member: string): Promise<boolean> { public async sadd(key: string, member: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.writeClient) return false; if (!this.isRedisEnabled() || !this.writeClient) return false;

View File

@@ -43,7 +43,8 @@ import {
WSMessage, WSMessage,
TokenPayload, TokenPayload,
WebSocketRequest, WebSocketRequest,
RedisMessage RedisMessage,
SendMessageOptions
} from "@server/routers/ws"; } from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -118,12 +119,21 @@ const processMessage = async (
if (response.broadcast) { if (response.broadcast) {
await broadcastToAllExcept( await broadcastToAllExcept(
response.message, response.message,
response.excludeSender ? clientId : undefined response.excludeSender ? clientId : undefined,
response.options
); );
} else if (response.targetClientId) { } else if (response.targetClientId) {
await sendToClient(response.targetClientId, response.message); await sendToClient(
response.targetClientId,
response.message,
response.options
);
} else { } else {
ws.send(JSON.stringify(response.message)); await sendToClient(
clientId,
response.message,
response.options
);
} }
} }
} catch (error) { } catch (error) {
@@ -172,6 +182,9 @@ const REDIS_CHANNEL = "websocket_messages";
// Client tracking map (local to this node) // Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map(); const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (local to this node, resets on server restart)
const clientConfigVersions: Map<string, number> = new Map();
// Recovery tracking // Recovery tracking
let isRedisRecoveryInProgress = false; let isRedisRecoveryInProgress = false;
@@ -182,6 +195,8 @@ const getClientMapKey = (clientId: string) => clientId;
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`; const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
const getNodeConnectionsKey = (nodeId: string, clientId: string) => const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
`ws:node:${nodeId}:${clientId}`; `ws:node:${nodeId}:${clientId}`;
const getConfigVersionKey = (clientId: string) =>
`ws:configVersion:${clientId}`;
// Initialize Redis subscription for cross-node messaging // Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => { const initializeRedisSubscription = async (): Promise<void> => {
@@ -304,6 +319,45 @@ const addClient = async (
existingClients.push(ws); existingClients.push(ws);
connectedClients.set(mapKey, existingClients); connectedClients.set(mapKey, existingClients);
// Get or initialize config version
let configVersion = 0;
// Check Redis first if enabled
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10);
// Sync to local cache
clientConfigVersions.set(clientId, configVersion);
} else if (!clientConfigVersions.has(clientId)) {
// No version in Redis or local cache, initialize to 0
await redisManager.set(getConfigVersionKey(clientId), "0");
clientConfigVersions.set(clientId, 0);
} else {
// Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
}
} catch (error) {
logger.error("Failed to get/set config version in Redis:", error);
// Fall back to local cache
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
configVersion = clientConfigVersions.get(clientId) || 0;
}
} else {
// Redis not enabled, use local cache only
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
configVersion = clientConfigVersions.get(clientId) || 0;
}
// Set config version on websocket
ws.configVersion = configVersion;
// Add to Redis tracking if enabled // Add to Redis tracking if enabled
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
@@ -322,7 +376,7 @@ const addClient = async (
} }
logger.info( logger.info(
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}` `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}, Config version: ${configVersion}`
); );
}; };
@@ -377,53 +431,133 @@ const removeClient = async (
} }
}; };
// Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
// Try Redis first if available
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) {
const version = parseInt(redisVersion, 10);
// Sync local cache with Redis
clientConfigVersions.set(clientId, version);
return version;
}
} catch (error) {
logger.error("Failed to get config version from Redis:", error);
}
}
// Fall back to local cache
return clientConfigVersions.get(clientId);
};
// Helper to increment and get the new config version for a client
const incrementClientConfigVersion = async (
clientId: string
): Promise<number> => {
let newVersion: number;
if (redisManager.isRedisEnabled()) {
try {
// Use Redis INCR for atomic increment across nodes
newVersion = await redisManager.incr(getConfigVersionKey(clientId));
// Sync local cache
clientConfigVersions.set(clientId, newVersion);
return newVersion;
} catch (error) {
logger.error("Failed to increment config version in Redis:", error);
// Fall through to local increment
}
}
// Local increment
const currentVersion = clientConfigVersions.get(clientId) || 0;
newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
return newVersion;
};
// Local message sending (within this node) // Local message sending (within this node)
const sendToClientLocal = async ( const sendToClientLocal = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
const mapKey = getClientMapKey(clientId); const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey); const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) { if (!clients || clients.length === 0) {
return false; return false;
} }
const messageString = JSON.stringify(message);
// Handle config version
let configVersion = await getClientConfigVersion(clientId);
// Add config version to message
const messageWithVersion = {
...message,
configVersion
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(messageString);
} }
}); });
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
);
return true; return true;
}; };
const broadcastToAllExceptLocal = async ( const broadcastToAllExceptLocal = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
connectedClients.forEach((clients, mapKey) => { for (const [mapKey, clients] of connectedClients.entries()) {
const [type, id] = mapKey.split(":"); const [type, id] = mapKey.split(":");
if (!(excludeClientId && id === excludeClientId)) { const clientId = mapKey; // mapKey is the clientId
if (!(excludeClientId && clientId === excludeClientId)) {
// Handle config version per client
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
// Add config version to message
const messageWithVersion = {
...message,
configVersion
};
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message)); client.send(JSON.stringify(messageWithVersion));
} }
}); });
} }
}); }
}; };
// Cross-node message sending (via Redis) // Cross-node message sending (via Redis)
const sendToClient = async ( const sendToClient = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId} (new configVersion: ${configVersion})`
);
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal(clientId, message); const localSent = await sendToClientLocal(clientId, message, options);
// Only send via Redis if the client is not connected locally and Redis is enabled // Only send via Redis if the client is not connected locally and Redis is enabled
if (!localSent && redisManager.isRedisEnabled()) { if (!localSent && redisManager.isRedisEnabled()) {
@@ -431,7 +565,10 @@ const sendToClient = async (
const redisMessage: RedisMessage = { const redisMessage: RedisMessage = {
type: "direct", type: "direct",
targetClientId: clientId, targetClientId: clientId,
message, message: {
...message,
configVersion
},
fromNodeId: NODE_ID fromNodeId: NODE_ID
}; };
@@ -458,19 +595,22 @@ const sendToClient = async (
const broadcastToAllExcept = async ( const broadcastToAllExcept = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
// Broadcast locally // Broadcast locally
await broadcastToAllExceptLocal(message, excludeClientId); await broadcastToAllExceptLocal(message, excludeClientId, options);
// If Redis is enabled, also broadcast via Redis pub/sub to other nodes // If Redis is enabled, also broadcast via Redis pub/sub to other nodes
// Note: For broadcasts, we include the options so remote nodes can handle versioning
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
const redisMessage: RedisMessage = { const redisMessage: RedisMessage = {
type: "broadcast", type: "broadcast",
excludeClientId, excludeClientId,
message, message,
fromNodeId: NODE_ID fromNodeId: NODE_ID,
options
}; };
await redisManager.publish( await redisManager.publish(
@@ -936,5 +1076,6 @@ export {
getActiveNodes, getActiveNodes,
disconnectClient, disconnectClient,
NODE_ID, NODE_ID,
cleanup cleanup,
getClientConfigVersion
}; };

View File

@@ -28,7 +28,7 @@ export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
await sendToClient(newtId, { await sendToClient(newtId, {
type: `newt/wg/targets/add`, type: `newt/wg/targets/add`,
data: batches[i] data: batches[i]
}); }, { incrementConfigVersion: true });
} }
} }
@@ -44,7 +44,7 @@ export async function removeTargets(
await sendToClient(newtId, { await sendToClient(newtId, {
type: `newt/wg/targets/remove`, type: `newt/wg/targets/remove`,
data: batches[i] data: batches[i]
}); },{ incrementConfigVersion: true });
} }
} }
@@ -69,7 +69,7 @@ export async function updateTargets(
oldTargets: oldBatches[i] || [], oldTargets: oldBatches[i] || [],
newTargets: newBatches[i] || [] newTargets: newBatches[i] || []
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -101,7 +101,7 @@ export async function addPeerData(
remoteSubnets: remoteSubnets, remoteSubnets: remoteSubnets,
aliases: aliases aliases: aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -132,7 +132,7 @@ export async function removePeerData(
remoteSubnets: remoteSubnets, remoteSubnets: remoteSubnets,
aliases: aliases aliases: aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -173,7 +173,7 @@ export async function updatePeerData(
...remoteSubnets, ...remoteSubnets,
...aliases ...aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }

View File

@@ -0,0 +1,278 @@
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db";
import logger from "@server/logger";
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { eq, and } from "drizzle-orm";
import config from "@server/lib/config";
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
export async function buildClientConfigurationForNewtClient(
site: Site,
exitNode?: ExitNode
) {
const siteId = site.siteId;
// Get all clients connected to this site
const clientsRes = await db
.select()
.from(clients)
.innerJoin(
clientSitesAssociationsCache,
eq(clients.clientId, clientSitesAssociationsCache.clientId)
)
.where(eq(clientSitesAssociationsCache.siteId, siteId));
let peers: Array<{
publicKey: string;
allowedIps: string[];
endpoint?: string;
}> = [];
if (site.publicKey && site.endpoint && exitNode) {
// Prepare peers data for the response
peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
logger.warn(
`Client ${client.clients.clientId} has no public key, skipping`
);
return false;
}
if (!client.clients.subnet) {
logger.warn(
`Client ${client.clients.clientId} has no subnet, skipping`
);
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
// const allSiteResources = await db // only get the site resources that this client has access to
// .select()
// .from(siteResources)
// .innerJoin(
// clientSiteResourcesAssociationsCache,
// eq(
// siteResources.siteResourceId,
// clientSiteResourcesAssociationsCache.siteResourceId
// )
// )
// .where(
// and(
// eq(siteResources.siteId, site.siteId),
// eq(
// clientSiteResourcesAssociationsCache.clientId,
// client.clients.clientId
// )
// )
// );
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint!,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey!,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clients.clientId,
{
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
}
);
return {
publicKey: client.clients.pubKey!,
allowedIps: [
`${client.clients.subnet.split("/")[0]}/32`
], // we want to only allow from that client
endpoint: client.clientSitesAssociationsCache.isRelayed
? ""
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
};
})
);
}
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled site resources for this site
const allSiteResources = await db
.select()
.from(siteResources)
.where(eq(siteResources.siteId, siteId));
const targetsToSend: SubnetProxyTarget[] = [];
for (const resource of allSiteResources) {
// Get clients associated with this specific resource
const resourceClients = await db
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clients.clientId,
clientSiteResourcesAssociationsCache.clientId
)
)
.where(
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
resource.siteResourceId
)
);
const resourceTargets = generateSubnetProxyTargets(
resource,
resourceClients
);
targetsToSend.push(...resourceTargets);
}
return {
peers: validPeers,
targets: targetsToSend
};
}
export async function buildTargetConfigurationForNewtClient(siteId: number) {
// Get all enabled targets with their resource protocol information
const allTargets = await db
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled,
protocol: resources.protocol,
hcEnabled: targetHealthCheck.hcEnabled,
hcPath: targetHealthCheck.hcPath,
hcScheme: targetHealthCheck.hcScheme,
hcMode: targetHealthCheck.hcMode,
hcHostname: targetHealthCheck.hcHostname,
hcPort: targetHealthCheck.hcPort,
hcInterval: targetHealthCheck.hcInterval,
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
hcTimeout: targetHealthCheck.hcTimeout,
hcHeaders: targetHealthCheck.hcHeaders,
hcMethod: targetHealthCheck.hcMethod,
hcTlsServerName: targetHealthCheck.hcTlsServerName
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
.leftJoin(
targetHealthCheck,
eq(targets.targetId, targetHealthCheck.targetId)
)
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
const { tcpTargets, udpTargets } = allTargets.reduce(
(acc, target) => {
// Filter out invalid targets
if (!target.internalPort || !target.ip || !target.port) {
return acc;
}
// Format target into string
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
// Add to the appropriate protocol array
if (target.protocol === "tcp") {
acc.tcpTargets.push(formattedTarget);
} else {
acc.udpTargets.push(formattedTarget);
}
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
const healthCheckTargets = allTargets.map((target) => {
// make sure the stuff is defined
if (
!target.hcPath ||
!target.hcHostname ||
!target.hcPort ||
!target.hcInterval ||
!target.hcMethod
) {
logger.debug(
`Skipping target ${target.targetId} due to missing health check fields`
);
return null; // Skip targets with missing health check fields
}
// parse headers
const hcHeadersParse = target.hcHeaders
? JSON.parse(target.hcHeaders)
: null;
const hcHeadersSend: { [key: string]: string } = {};
if (hcHeadersParse) {
hcHeadersParse.forEach(
(header: { name: string; value: string }) => {
hcHeadersSend[header.name] = header.value;
}
);
}
return {
id: target.targetId,
hcEnabled: target.hcEnabled,
hcPath: target.hcPath,
hcScheme: target.hcScheme,
hcMode: target.hcMode,
hcHostname: target.hcHostname,
hcPort: target.hcPort,
hcInterval: target.hcInterval, // in seconds
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
hcTimeout: target.hcTimeout, // in seconds
hcHeaders: hcHeadersSend,
hcMethod: target.hcMethod,
hcTlsServerName: target.hcTlsServerName
};
});
// Filter out any null values from health check targets
const validHealthCheckTargets = healthCheckTargets.filter(
(target) => target !== null
);
return {
validHealthCheckTargets,
tcpTargets,
udpTargets
};
}

View File

@@ -2,19 +2,10 @@ import { z } from "zod";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
db,
ExitNode,
exitNodes,
siteResources,
clientSiteResourcesAssociationsCache
} from "@server/db";
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes"; import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip"; import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
import config from "@server/lib/config";
const inputSchema = z.object({ const inputSchema = z.object({
publicKey: z.string(), publicKey: z.string(),
@@ -130,167 +121,18 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
} }
} }
// Get all clients connected to this site const { peers, targets } = await buildClientConfigurationForNewtClient(
const clientsRes = await db site,
.select() exitNode
.from(clients) );
.innerJoin(
clientSitesAssociationsCache,
eq(clients.clientId, clientSitesAssociationsCache.clientId)
)
.where(eq(clientSitesAssociationsCache.siteId, siteId));
let peers: Array<{
publicKey: string;
allowedIps: string[];
endpoint?: string;
}> = [];
if (site.publicKey && site.endpoint && exitNode) {
// Prepare peers data for the response
peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
logger.warn(
`Client ${client.clients.clientId} has no public key, skipping`
);
return false;
}
if (!client.clients.subnet) {
logger.warn(
`Client ${client.clients.clientId} has no subnet, skipping`
);
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
// const allSiteResources = await db // only get the site resources that this client has access to
// .select()
// .from(siteResources)
// .innerJoin(
// clientSiteResourcesAssociationsCache,
// eq(
// siteResources.siteResourceId,
// clientSiteResourcesAssociationsCache.siteResourceId
// )
// )
// .where(
// and(
// eq(siteResources.siteId, site.siteId),
// eq(
// clientSiteResourcesAssociationsCache.clientId,
// client.clients.clientId
// )
// )
// );
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint!,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey!,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clients.clientId,
{
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
}
);
return {
publicKey: client.clients.pubKey!,
allowedIps: [
`${client.clients.subnet.split("/")[0]}/32`
], // we want to only allow from that client
endpoint: client.clientSitesAssociationsCache.isRelayed
? ""
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
};
})
);
}
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled site resources for this site
const allSiteResources = await db
.select()
.from(siteResources)
.where(eq(siteResources.siteId, siteId));
const targetsToSend: SubnetProxyTarget[] = [];
for (const resource of allSiteResources) {
// Get clients associated with this specific resource
const resourceClients = await db
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clients.clientId,
clientSiteResourcesAssociationsCache.clientId
)
)
.where(
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
resource.siteResourceId
)
);
const resourceTargets = generateSubnetProxyTargets(
resource,
resourceClients
);
targetsToSend.push(...resourceTargets);
}
// Build the configuration response
const configResponse = {
ipAddress: site.address,
peers: validPeers,
targets: targetsToSend
};
logger.debug("Sending config: ", configResponse);
return { return {
message: { message: {
type: "newt/wg/receive-config", type: "newt/wg/receive-config",
data: { data: {
...configResponse ipAddress: site.address,
peers,
targets
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -0,0 +1,163 @@
import { db, sites } from "@server/db";
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { clients, Newt } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { sendNewtSyncMessage } from "./sync";
// Track if the offline checker interval is running
// let offlineCheckerInterval: NodeJS.Timeout | null = null;
// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
*/
// export const startNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) {
// return; // Already running
// }
// offlineCheckerInterval = setInterval(async () => {
// try {
// const twoMinutesAgo = Math.floor(
// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
// );
// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// // Find clients that haven't pinged in the last 2 minutes and mark them as offline
// const offlineClients = await db
// .update(clients)
// .set({ online: false })
// .where(
// and(
// eq(clients.online, true),
// or(
// lt(clients.lastPing, twoMinutesAgo),
// isNull(clients.lastPing)
// )
// )
// )
// .returning();
// for (const offlineClient of offlineClients) {
// logger.info(
// `Kicking offline newt client ${offlineClient.clientId} due to inactivity`
// );
// if (!offlineClient.newtId) {
// logger.warn(
// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect`
// );
// continue;
// }
// // Send a disconnect message to the client if connected
// try {
// await sendTerminateClient(
// offlineClient.clientId,
// offlineClient.newtId
// ); // terminate first
// // wait a moment to ensure the message is sent
// await new Promise((resolve) => setTimeout(resolve, 1000));
// await disconnectClient(offlineClient.newtId);
// } catch (error) {
// logger.error(
// `Error sending disconnect to offline newt ${offlineClient.clientId}`,
// { error }
// );
// }
// }
// } catch (error) {
// logger.error("Error in offline checker interval", { error });
// }
// }, OFFLINE_CHECK_INTERVAL);
// logger.debug("Started offline checker interval");
// };
/**
* Stops the background interval that checks for offline clients
*/
// export const stopNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) {
// clearInterval(offlineCheckerInterval);
// offlineCheckerInterval = null;
// logger.info("Stopped offline checker interval");
// }
// };
/**
* Handles ping messages from clients and responds with pong
*/
export const handleNewtPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const newt = c as Newt;
if (!newt) {
logger.warn("Newt ping message: Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt ping message: has no site ID");
return;
}
// get the version
const configVersion = await getClientConfigVersion(newt.newtId);
if (message.configVersion && configVersion != null && configVersion != message.configVersion) {
logger.warn(
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
);
// get the site
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, newt.siteId))
.limit(1);
if (!site) {
logger.warn(
`Newt ping message: site with ID ${newt.siteId} not found`
);
return;
}
await sendNewtSyncMessage(newt, site);
}
// try {
// // Update the client's last ping timestamp
// await db
// .update(clients)
// .set({
// lastPing: Math.floor(Date.now() / 1000),
// online: true
// })
// .where(eq(clients.clientId, newt.clientId));
// } catch (error) {
// logger.error("Error handling ping message", { error });
// }
return {
message: {
type: "pong",
data: {
timestamp: new Date().toISOString()
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -18,6 +18,7 @@ import {
} from "#dynamic/lib/exitNodes"; } from "#dynamic/lib/exitNodes";
import { fetchContainers } from "./dockerSocket"; import { fetchContainers } from "./dockerSocket";
import { lockManager } from "#dynamic/lib/lock"; import { lockManager } from "#dynamic/lib/lock";
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
export type ExitNodePingResult = { export type ExitNodePingResult = {
exitNodeId: number; exitNodeId: number;
@@ -233,109 +234,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
.where(eq(newts.newtId, newt.newtId)); .where(eq(newts.newtId, newt.newtId));
} }
// Get all enabled targets with their resource protocol information const { tcpTargets, udpTargets, validHealthCheckTargets } =
const allTargets = await db await buildTargetConfigurationForNewtClient(siteId);
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled,
protocol: resources.protocol,
hcEnabled: targetHealthCheck.hcEnabled,
hcPath: targetHealthCheck.hcPath,
hcScheme: targetHealthCheck.hcScheme,
hcMode: targetHealthCheck.hcMode,
hcHostname: targetHealthCheck.hcHostname,
hcPort: targetHealthCheck.hcPort,
hcInterval: targetHealthCheck.hcInterval,
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
hcTimeout: targetHealthCheck.hcTimeout,
hcHeaders: targetHealthCheck.hcHeaders,
hcMethod: targetHealthCheck.hcMethod,
hcTlsServerName: targetHealthCheck.hcTlsServerName
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
.leftJoin(
targetHealthCheck,
eq(targets.targetId, targetHealthCheck.targetId)
)
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
const { tcpTargets, udpTargets } = allTargets.reduce(
(acc, target) => {
// Filter out invalid targets
if (!target.internalPort || !target.ip || !target.port) {
return acc;
}
// Format target into string
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
// Add to the appropriate protocol array
if (target.protocol === "tcp") {
acc.tcpTargets.push(formattedTarget);
} else {
acc.udpTargets.push(formattedTarget);
}
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
const healthCheckTargets = allTargets.map((target) => {
// make sure the stuff is defined
if (
!target.hcPath ||
!target.hcHostname ||
!target.hcPort ||
!target.hcInterval ||
!target.hcMethod
) {
logger.debug(
`Skipping target ${target.targetId} due to missing health check fields`
);
return null; // Skip targets with missing health check fields
}
// parse headers
const hcHeadersParse = target.hcHeaders
? JSON.parse(target.hcHeaders)
: null;
const hcHeadersSend: { [key: string]: string } = {};
if (hcHeadersParse) {
hcHeadersParse.forEach(
(header: { name: string; value: string }) => {
hcHeadersSend[header.name] = header.value;
}
);
}
return {
id: target.targetId,
hcEnabled: target.hcEnabled,
hcPath: target.hcPath,
hcScheme: target.hcScheme,
hcMode: target.hcMode,
hcHostname: target.hcHostname,
hcPort: target.hcPort,
hcInterval: target.hcInterval, // in seconds
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
hcTimeout: target.hcTimeout, // in seconds
hcHeaders: hcHeadersSend,
hcMethod: target.hcMethod,
hcTlsServerName: target.hcTlsServerName
};
});
// Filter out any null values from health check targets
const validHealthCheckTargets = healthCheckTargets.filter(
(target) => target !== null
);
logger.debug( logger.debug(
`Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}` `Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}`

View File

@@ -6,3 +6,4 @@ export * from "./handleGetConfigMessage";
export * from "./handleSocketMessages"; export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage"; export * from "./handleNewtPingRequestMessage";
export * from "./handleApplyBlueprintMessage"; export * from "./handleApplyBlueprintMessage";
export * from "./handleNewtPingMessage";

View File

@@ -39,7 +39,7 @@ export async function addPeer(
await sendToClient(newtId, { await sendToClient(newtId, {
type: "newt/wg/peer/add", type: "newt/wg/peer/add",
data: peer data: peer
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -81,7 +81,7 @@ export async function deletePeer(
data: { data: {
publicKey publicKey
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -128,7 +128,7 @@ export async function updatePeer(
publicKey, publicKey,
...peer ...peer
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });

View File

@@ -0,0 +1,41 @@
import { ExitNode, exitNodes, Newt, Site, db } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import {
buildClientConfigurationForNewtClient,
buildTargetConfigurationForNewtClient
} from "./buildConfiguration";
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
const { tcpTargets, udpTargets, validHealthCheckTargets } =
await buildTargetConfigurationForNewtClient(site.siteId);
let exitNode: ExitNode | undefined;
if (site.exitNodeId) {
[exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
}
const { peers, targets } = await buildClientConfigurationForNewtClient(
site,
exitNode
);
await sendToClient(newt.newtId, {
type: "newt/sync",
data: {
proxyTargets: {
udp: udpTargets,
tcp: tcpTargets
},
healthCheckTargets: validHealthCheckTargets,
peers: peers,
clientTargets: targets
}
}).catch((error) => {
logger.warn(`Error sending newt sync message:`, error);
});
}

View File

@@ -22,7 +22,7 @@ export async function addTargets(
data: { data: {
targets: payloadTargets targets: payloadTargets
} }
}); }, { incrementConfigVersion: true });
// Create a map for quick lookup // Create a map for quick lookup
const healthCheckMap = new Map<number, TargetHealthCheck>(); const healthCheckMap = new Map<number, TargetHealthCheck>();
@@ -103,7 +103,7 @@ export async function addTargets(
data: { data: {
targets: validHealthCheckTargets targets: validHealthCheckTargets
} }
}); }, { incrementConfigVersion: true });
} }
export async function removeTargets( export async function removeTargets(
@@ -124,7 +124,7 @@ export async function removeTargets(
data: { data: {
targets: payloadTargets targets: payloadTargets
} }
}); }, { incrementConfigVersion: true });
const healthCheckTargets = targets.map((target) => { const healthCheckTargets = targets.map((target) => {
return target.targetId; return target.targetId;
@@ -135,5 +135,5 @@ export async function removeTargets(
data: { data: {
ids: healthCheckTargets ids: healthCheckTargets
} }
}); }, { incrementConfigVersion: true });
} }

View File

@@ -0,0 +1,145 @@
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db";
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import config from "@server/lib/config";
export async function buildSiteConfigurationForOlmClient(
client: Client,
publicKey: string | null,
relay: boolean
) {
const siteConfigurations = [];
// Get all sites data
const sitesData = await db
.select()
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Process each site
for (const {
sites: site,
clientSitesAssociationsCache: association
} of sitesData) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
);
continue;
}
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(
`In olm register: site ${site.siteId} has no endpoint, skipping`
);
continue;
}
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
// logger.warn(
// `Site ${site.siteId} last hole punch is too old, skipping`
// );
// continue;
// }
// If public key changed, delete old peer from this site
if (client.pubKey && client.pubKey != publicKey) {
logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
}
if (!site.subnet) {
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
continue;
}
const [clientSite] = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.limit(1);
// Add the peer to the exit node for this site
if (clientSite.endpoint && publicKey) {
logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
);
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: relay ? "" : clientSite.endpoint
});
} else {
logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition`
);
}
let relayEndpoint: string | undefined = undefined;
if (relay) {
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`);
continue;
}
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
}
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
name: site.name,
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(({ siteResources }) => siteResources)
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
});
}
return siteConfigurations;
}

View File

@@ -1,5 +1,5 @@
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { clientPostureSnapshots, db, fingerprints } from "@server/db"; import { clientPostureSnapshots, db, fingerprints } from "@server/db";
import { disconnectClient } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, olms, Olm } from "@server/db"; import { clients, olms, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm"; import { eq, lt, isNull, and, or } from "drizzle-orm";
@@ -9,6 +9,7 @@ import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate"; import { sendTerminateClient } from "../client/terminate";
import { encodeHexLowerCase } from "@oslojs/encoding"; import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2"; import { sha256 } from "@oslojs/crypto/sha2";
import { sendOlmSyncMessage } from "./sync";
// Track if the offline checker interval is running // Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null; let offlineCheckerInterval: NodeJS.Timeout | null = null;
@@ -128,7 +129,9 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
if (client.blocked) { if (client.blocked) {
// NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them // NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them
logger.debug(`Blocked client ${client.clientId} attempted olm ping`); logger.debug(
`Blocked client ${client.clientId} attempted olm ping`
);
return; return;
} }
@@ -167,6 +170,23 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
} }
} }
// get the version
logger.debug(`handleOlmPingMessage: About to get config version for olmId: ${olm.olmId}`);
const configVersion = await getClientConfigVersion(olm.olmId);
logger.debug(`handleOlmPingMessage: Got config version: ${configVersion} (type: ${typeof configVersion})`);
if (configVersion == null || configVersion === undefined) {
logger.debug(`handleOlmPingMessage: could not get config version from server for olmId: ${olm.olmId}`)
}
if (message.configVersion != null && configVersion != null && configVersion != message.configVersion) {
logger.debug(
`handleOlmPingMessage: Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
);
await sendOlmSyncMessage(olm, client);
}
// Update the client's last ping timestamp
await db await db
.update(clients) .update(clients)
.set({ .set({

View File

@@ -1,4 +1,5 @@
import { import {
Client,
clientPostureSnapshots, clientPostureSnapshots,
clientSiteResourcesAssociationsCache, clientSiteResourcesAssociationsCache,
db, db,
@@ -15,7 +16,7 @@ import {
olms, olms,
sites sites
} from "@server/db"; } from "@server/db";
import { and, eq, inArray, isNull } from "drizzle-orm"; import { and, count, eq, inArray, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers"; import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger"; import logger from "@server/logger";
import { generateAliasConfig } from "@server/lib/ip"; import { generateAliasConfig } from "@server/lib/ip";
@@ -25,6 +26,7 @@ import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { encodeHexLowerCase } from "@oslojs/encoding"; import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2"; import { sha256 } from "@oslojs/crypto/sha2";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("Handling register olm message!");
@@ -163,8 +165,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
// Get all sites data // Get all sites data
const sitesData = await db const sitesCountResult = await db
.select() .select({ count: count() })
.from(sites) .from(sites)
.innerJoin( .innerJoin(
clientSitesAssociationsCache, clientSitesAssociationsCache,
@@ -172,140 +174,29 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
) )
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations // Prepare an array to store site configurations
const siteConfigurations = []; logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
logger.debug(
`Found ${sitesData.length} sites for client ${client.clientId}`
);
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking // the olm will pump the register so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) { if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn( logger.warn(
"Client last hole punch is too old and we have sites to send; skipping this register" "Client last hole punch is too old and we have sites to send; skipping this register"
); );
return; return;
} }
// Process each site // NOTE: its important that the client here is the old client and the public key is the new key
for (const { const siteConfigurations = await buildSiteConfigurationForOlmClient(
sites: site, client,
clientSitesAssociationsCache: association publicKey,
} of sitesData) { relay
if (!site.exitNodeId) { );
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
);
continue;
}
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(
`In olm register: site ${site.siteId} has no endpoint, skipping`
);
continue;
}
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
// logger.warn(
// `Site ${site.siteId} last hole punch is too old, skipping`
// );
// continue;
// }
// If public key changed, delete old peer from this site
if (client.pubKey && client.pubKey != publicKey) {
logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
}
if (!site.subnet) {
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
continue;
}
const [clientSite] = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.limit(1);
// Add the peer to the exit node for this site
if (clientSite.endpoint) {
logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
);
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: relay ? "" : clientSite.endpoint
});
} else {
logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition`
);
}
let relayEndpoint: string | undefined = undefined;
if (relay) {
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`);
continue;
}
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
}
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
name: site.name,
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(({ siteResources }) => siteResources)
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
});
}
if (fingerprint) { if (fingerprint) {
const [existingFingerprint] = await db const [existingFingerprint] = await db

View File

@@ -32,20 +32,24 @@ export async function addPeer(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/add", olmId,
data: { {
siteId: peer.siteId, type: "olm/wg/peer/add",
name: peer.name, data: {
publicKey: peer.publicKey, siteId: peer.siteId,
endpoint: peer.endpoint, name: peer.name,
relayEndpoint: peer.relayEndpoint, publicKey: peer.publicKey,
serverIP: peer.serverIP, endpoint: peer.endpoint,
serverPort: peer.serverPort, relayEndpoint: peer.relayEndpoint,
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access serverIP: peer.serverIP,
aliases: peer.aliases serverPort: peer.serverPort,
} remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
}).catch((error) => { aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -70,13 +74,17 @@ export async function deletePeer(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/remove", olmId,
data: { {
publicKey, type: "olm/wg/peer/remove",
siteId: siteId data: {
} publicKey,
}).catch((error) => { siteId: siteId
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -109,19 +117,23 @@ export async function updatePeer(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/update", olmId,
data: { {
siteId: peer.siteId, type: "olm/wg/peer/update",
publicKey: peer.publicKey, data: {
endpoint: peer.endpoint, siteId: peer.siteId,
relayEndpoint: peer.relayEndpoint, publicKey: peer.publicKey,
serverIP: peer.serverIP, endpoint: peer.endpoint,
serverPort: peer.serverPort, relayEndpoint: peer.relayEndpoint,
remoteSubnets: peer.remoteSubnets, serverIP: peer.serverIP,
aliases: peer.aliases serverPort: peer.serverPort,
} remoteSubnets: peer.remoteSubnets,
}).catch((error) => { aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -151,17 +163,21 @@ export async function initPeerAddHandshake(
olmId = olm.olmId; olmId = olm.olmId;
} }
await sendToClient(olmId, { await sendToClient(
type: "olm/wg/peer/holepunch/site/add", olmId,
data: { {
siteId: peer.siteId, type: "olm/wg/peer/holepunch/site/add",
exitNode: { data: {
publicKey: peer.exitNode.publicKey, siteId: peer.siteId,
relayPort: config.getRawConfig().gerbil.clients_start_port, exitNode: {
endpoint: peer.exitNode.endpoint publicKey: peer.exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint
}
} }
} },
}).catch((error) => { { incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });

View File

@@ -0,0 +1,80 @@
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
client.pubKey,
false
);
// Get all exit nodes from sites where the client has peers
const clientSites = await db
.select()
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract unique exit node IDs
const exitNodeIds = Array.from(
new Set(
clientSites
.map(({ sites: site }) => site.exitNodeId)
.filter((id): id is number => id !== null)
)
);
let exitNodesData: {
publicKey: string;
relayPort: number;
endpoint: string;
siteIds: number[];
}[] = [];
if (exitNodeIds.length > 0) {
const allExitNodes = await db
.select()
.from(exitNodes)
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
// Map exitNodeId to siteIds
const exitNodeIdToSiteIds: Record<number, number[]> = {};
for (const { sites: site } of clientSites) {
if (site.exitNodeId !== null) {
if (!exitNodeIdToSiteIds[site.exitNodeId]) {
exitNodeIdToSiteIds[site.exitNodeId] = [];
}
exitNodeIdToSiteIds[site.exitNodeId].push(site.siteId);
}
}
exitNodesData = allExitNodes.map((exitNode) => {
return {
publicKey: exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: exitNode.endpoint,
siteIds: exitNodeIdToSiteIds[exitNode.exitNodeId] ?? []
};
});
}
logger.debug("sendOlmSyncMessage: sending sync message")
await sendToClient(olm.olmId, {
type: "olm/sync",
data: {
sites: siteConfigurations,
exitNodes: exitNodesData
}
}).catch((error) => {
logger.warn(`Error sending olm sync message:`, error);
});
}

View File

@@ -5,7 +5,8 @@ import {
handleDockerStatusMessage, handleDockerStatusMessage,
handleDockerContainersMessage, handleDockerContainersMessage,
handleNewtPingRequestMessage, handleNewtPingRequestMessage,
handleApplyBlueprintMessage handleApplyBlueprintMessage,
handleNewtPingMessage
} from "../newt"; } from "../newt";
import { import {
handleOlmRegisterMessage, handleOlmRegisterMessage,
@@ -24,6 +25,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
"olm/wg/relay": handleOlmRelayMessage, "olm/wg/relay": handleOlmRelayMessage,
"olm/wg/unrelay": handleOlmUnRelayMessage, "olm/wg/unrelay": handleOlmUnRelayMessage,
"olm/ping": handleOlmPingMessage, "olm/ping": handleOlmPingMessage,
"newt/ping": handleNewtPingMessage,
"newt/wg/register": handleNewtRegisterMessage, "newt/wg/register": handleNewtRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage, "newt/wg/get-config": handleGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage, "newt/receive-bandwidth": handleReceiveBandwidthMessage,

View File

@@ -25,6 +25,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
connectionId?: string; connectionId?: string;
isFullyConnected?: boolean; isFullyConnected?: boolean;
pendingMessages?: Buffer[]; pendingMessages?: Buffer[];
configVersion?: number;
} }
export interface TokenPayload { export interface TokenPayload {
@@ -36,6 +37,7 @@ export interface TokenPayload {
export interface WSMessage { export interface WSMessage {
type: string; type: string;
data: any; data: any;
configVersion?: number;
} }
export interface HandlerResponse { export interface HandlerResponse {
@@ -43,6 +45,7 @@ export interface HandlerResponse {
broadcast?: boolean; broadcast?: boolean;
excludeSender?: boolean; excludeSender?: boolean;
targetClientId?: string; targetClientId?: string;
options?: SendMessageOptions;
} }
export interface HandlerContext { export interface HandlerContext {
@@ -50,10 +53,15 @@ export interface HandlerContext {
senderWs: WebSocket; senderWs: WebSocket;
client: Newt | Olm | RemoteExitNode | undefined; client: Newt | Olm | RemoteExitNode | undefined;
clientType: ClientType; clientType: ClientType;
sendToClient: (clientId: string, message: WSMessage) => Promise<boolean>; sendToClient: (
clientId: string,
message: WSMessage,
options?: SendMessageOptions
) => Promise<boolean>;
broadcastToAllExcept: ( broadcastToAllExcept: (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options?: SendMessageOptions
) => Promise<void>; ) => Promise<void>;
connectedClients: Map<string, WebSocket[]>; connectedClients: Map<string, WebSocket[]>;
} }
@@ -62,6 +70,11 @@ export type MessageHandler = (
context: HandlerContext context: HandlerContext
) => Promise<HandlerResponse | void>; ) => Promise<HandlerResponse | void>;
// Options for sending messages with config version tracking
export interface SendMessageOptions {
incrementConfigVersion?: boolean;
}
// Redis message type for cross-node communication // Redis message type for cross-node communication
export interface RedisMessage { export interface RedisMessage {
type: "direct" | "broadcast"; type: "direct" | "broadcast";
@@ -69,4 +82,5 @@ export interface RedisMessage {
excludeClientId?: string; excludeClientId?: string;
message: WSMessage; message: WSMessage;
fromNodeId: string; fromNodeId: string;
options?: SendMessageOptions;
} }

View File

@@ -15,7 +15,8 @@ import {
TokenPayload, TokenPayload,
WebSocketRequest, WebSocketRequest,
WSMessage, WSMessage,
AuthenticatedWebSocket AuthenticatedWebSocket,
SendMessageOptions
} from "./types"; } from "./types";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -34,6 +35,8 @@ const NODE_ID = uuidv4();
// Client tracking map (local to this node) // Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map(); const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (clientId -> version)
const clientConfigVersions: Map<string, number> = new Map();
// Helper to get map key // Helper to get map key
const getClientMapKey = (clientId: string) => clientId; const getClientMapKey = (clientId: string) => clientId;
@@ -53,6 +56,13 @@ const addClient = async (
existingClients.push(ws); existingClients.push(ws);
connectedClients.set(mapKey, existingClients); connectedClients.set(mapKey, existingClients);
// Initialize config version to 0 if not already set, otherwise use existing
if (!clientConfigVersions.has(clientId)) {
clientConfigVersions.set(clientId, 0);
}
// Set the current config version on the websocket
ws.configVersion = clientConfigVersions.get(clientId) || 0;
logger.info( logger.info(
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}` `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
); );
@@ -84,14 +94,28 @@ const removeClient = async (
// Local message sending (within this node) // Local message sending (within this node)
const sendToClientLocal = async ( const sendToClientLocal = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
const mapKey = getClientMapKey(clientId); const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey); const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) { if (!clients || clients.length === 0) {
return false; return false;
} }
const messageString = JSON.stringify(message);
// Include config version in message
const configVersion = clientConfigVersions.get(clientId) || 0;
// Update version on all client connections
clients.forEach((client) => {
client.configVersion = configVersion;
});
const messageWithVersion = {
...message,
configVersion
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(messageString);
@@ -102,14 +126,30 @@ const sendToClientLocal = async (
const broadcastToAllExceptLocal = async ( const broadcastToAllExceptLocal = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
connectedClients.forEach((clients, mapKey) => { connectedClients.forEach((clients, mapKey) => {
const [type, id] = mapKey.split(":"); const clientId = mapKey; // mapKey is the clientId
if (!(excludeClientId && id === excludeClientId)) { if (!(excludeClientId && clientId === excludeClientId)) {
// Handle config version per client
if (options.incrementConfigVersion) {
const currentVersion = clientConfigVersions.get(clientId) || 0;
const newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
clients.forEach((client) => {
client.configVersion = newVersion;
});
}
// Include config version in message for this client
const configVersion = clientConfigVersions.get(clientId) || 0;
const messageWithVersion = {
...message,
configVersion
};
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message)); client.send(JSON.stringify(messageWithVersion));
} }
}); });
} }
@@ -119,10 +159,18 @@ const broadcastToAllExceptLocal = async (
// Cross-node message sending // Cross-node message sending
const sendToClient = async ( const sendToClient = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
// Increment config version if requested
if (options.incrementConfigVersion) {
const currentVersion = clientConfigVersions.get(clientId) || 0;
const newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
}
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal(clientId, message); const localSent = await sendToClientLocal(clientId, message, options);
logger.debug( logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}` `sendToClient: Message type ${message.type} sent to clientId ${clientId}`
@@ -133,10 +181,11 @@ const sendToClient = async (
const broadcastToAllExcept = async ( const broadcastToAllExcept = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
// Broadcast locally // Broadcast locally
await broadcastToAllExceptLocal(message, excludeClientId); await broadcastToAllExceptLocal(message, excludeClientId, options);
}; };
// Check if a client has active connections across all nodes // Check if a client has active connections across all nodes
@@ -146,6 +195,13 @@ const hasActiveConnections = async (clientId: string): Promise<boolean> => {
return !!(clients && clients.length > 0); return !!(clients && clients.length > 0);
}; };
// Get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
const version = clientConfigVersions.get(clientId);
logger.debug(`getClientConfigVersion called for clientId: ${clientId}, returning: ${version} (type: ${typeof version})`);
return version;
};
// Get all active nodes for a client // Get all active nodes for a client
const getActiveNodes = async ( const getActiveNodes = async (
clientType: ClientType, clientType: ClientType,
@@ -259,15 +315,21 @@ const setupConnection = async (
if (response.broadcast) { if (response.broadcast) {
await broadcastToAllExcept( await broadcastToAllExcept(
response.message, response.message,
response.excludeSender ? clientId : undefined response.excludeSender ? clientId : undefined,
response.options
); );
} else if (response.targetClientId) { } else if (response.targetClientId) {
await sendToClient( await sendToClient(
response.targetClientId, response.targetClientId,
response.message response.message,
response.options
); );
} else { } else {
ws.send(JSON.stringify(response.message)); await sendToClient(
clientId,
response.message,
response.options
);
} }
} }
} catch (error) { } catch (error) {
@@ -434,5 +496,6 @@ export {
getActiveNodes, getActiveNodes,
disconnectClient, disconnectClient,
NODE_ID, NODE_ID,
cleanup cleanup,
getClientConfigVersion
}; };