mirror of https://github.com/fosrl/pangolin.git (synced 2026-05-05 20:13:58 +00:00)

Compare commits: 06aaa7c680...msg-delive (14 commits)

Commits in this range:
05748bf8ff
f8c98bf6bf
a1ea3f74b3
65e8bfc93e
d52bd65d21
eb0cdda0f9
eba25fcc4d
0ccd5714f9
e2dfc3eb20
446eba8bc9
18579c0647
0d37e08638
75b9703793
322f3bfb1d
@@ -26,6 +26,7 @@ export function initLogCleanupInterval() {
             )
         );
 
+        // TODO: handle when there are multiple nodes doing this clearing using redis
         for (const org of orgsToClean) {
             const {
                 orgId,
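The TODO above notes that on a multi-node deployment every node would run this cleanup pass concurrently. One conventional way to satisfy it is a short-lived Redis lock so only one node proceeds; a minimal sketch, assuming an ioredis-style client (the key name and TTL are illustrative, not part of this changeset):

    import Redis from "ioredis";

    async function runLogCleanupOnce(redis: Redis, doCleanup: () => Promise<void>) {
        // SET NX EX is atomic: only the first node to create the key wins the lock.
        const acquired = await redis.set("lock:logCleanup", "holder", "EX", 60, "NX");
        if (acquired !== "OK") return; // another node is already cleaning
        try {
            await doCleanup();
        } finally {
            await redis.del("lock:logCleanup");
        }
    }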
@@ -50,10 +50,14 @@ export async function sendToExitNode(
            );
        }

-        return sendToClient(remoteExitNode.remoteExitNodeId, {
-            type: request.remoteType,
-            data: request.data
-        });
+        return sendToClient(
+            remoteExitNode.remoteExitNodeId,
+            {
+                type: request.remoteType,
+                data: request.data
+            },
+            { incrementConfigVersion: true }
+        );
    } else {
        let hostname = exitNode.reachableAt;

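sendToClient now takes the message and an options object separately, with { incrementConfigVersion: true } marking this send as a config-changing event. SendMessageOptions is imported from @server/routers/ws later in this compare, but its definition is not shown; from its one visible member it presumably looks like:

    // Assumed shape; only incrementConfigVersion appears anywhere in these hunks.
    export interface SendMessageOptions {
        incrementConfigVersion?: boolean;
    }

    // New signature used throughout this changeset:
    // sendToClient(clientId: string, message: WSMessage, options?: SendMessageOptions): Promise<boolean>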
@@ -573,6 +573,20 @@ class RedisManager {
        }
    }

+    public async incr(key: string): Promise<number> {
+        if (!this.isRedisEnabled() || !this.writeClient) return 0;
+
+        try {
+            return await this.executeWithRetry(
+                () => this.writeClient!.incr(key),
+                "Redis INCR"
+            );
+        } catch (error) {
+            logger.error("Redis INCR error:", error);
+            return 0;
+        }
+    }
+
    public async sadd(key: string, member: string): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

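The new incr wrapper matters because Redis INCR executes atomically inside the Redis server, so concurrent version bumps from different Pangolin nodes can never observe the same value. Note the 0 fallback on error or when Redis is disabled; callers below treat that as "no usable counter" and fall back to the local map. A small illustration (the key format comes from getConfigVersionKey below; the client id is made up):

    // Concurrent increments from two nodes are serialized inside Redis:
    const a = await redisManager.incr("ws:configVersion:newt-abc"); // e.g. 5
    const b = await redisManager.incr("ws:configVersion:newt-abc"); // e.g. 6 — never also 5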
@@ -43,7 +43,8 @@ import {
    WSMessage,
    TokenPayload,
    WebSocketRequest,
-    RedisMessage
+    RedisMessage,
+    SendMessageOptions
} from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app";

@@ -118,12 +119,21 @@ const processMessage = async (
        if (response.broadcast) {
            await broadcastToAllExcept(
                response.message,
-                response.excludeSender ? clientId : undefined
+                response.excludeSender ? clientId : undefined,
+                response.options
            );
        } else if (response.targetClientId) {
-            await sendToClient(response.targetClientId, response.message);
+            await sendToClient(
+                response.targetClientId,
+                response.message,
+                response.options
+            );
        } else {
-            ws.send(JSON.stringify(response.message));
+            await sendToClient(
+                clientId,
+                response.message,
+                response.options
+            );
        }
    }
} catch (error) {
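All three branches now funnel through sendToClient, so even the direct reply to the sender carries a config version. Based on the handler return values visible later in this compare, the response object a message handler produces presumably carries these fields:

    // Inferred from this diff; only the fields used here are listed.
    interface HandlerResponse {
        message: WSMessage;
        broadcast: boolean;            // fan out to all connected clients
        excludeSender?: boolean;       // skip the originator on broadcast
        targetClientId?: string;       // direct a message at one client
        options?: SendMessageOptions;  // e.g. { incrementConfigVersion: true }
    }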
@@ -172,6 +182,9 @@ const REDIS_CHANNEL = "websocket_messages";
// Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();

+// Config version tracking map (local to this node, resets on server restart)
+const clientConfigVersions: Map<string, number> = new Map();
+
// Recovery tracking
let isRedisRecoveryInProgress = false;

@@ -182,6 +195,8 @@ const getClientMapKey = (clientId: string) => clientId;
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
    `ws:node:${nodeId}:${clientId}`;
+const getConfigVersionKey = (clientId: string) =>
+    `ws:configVersion:${clientId}`;

// Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => {
@@ -304,6 +319,45 @@ const addClient = async (
    existingClients.push(ws);
    connectedClients.set(mapKey, existingClients);

+    // Get or initialize config version
+    let configVersion = 0;
+
+    // Check Redis first if enabled
+    if (redisManager.isRedisEnabled()) {
+        try {
+            const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
+            if (redisVersion !== null) {
+                configVersion = parseInt(redisVersion, 10);
+                // Sync to local cache
+                clientConfigVersions.set(clientId, configVersion);
+            } else if (!clientConfigVersions.has(clientId)) {
+                // No version in Redis or local cache, initialize to 0
+                await redisManager.set(getConfigVersionKey(clientId), "0");
+                clientConfigVersions.set(clientId, 0);
+            } else {
+                // Use local cache version and sync to Redis
+                configVersion = clientConfigVersions.get(clientId) || 0;
+                await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
+            }
+        } catch (error) {
+            logger.error("Failed to get/set config version in Redis:", error);
+            // Fall back to local cache
+            if (!clientConfigVersions.has(clientId)) {
+                clientConfigVersions.set(clientId, 0);
+            }
+            configVersion = clientConfigVersions.get(clientId) || 0;
+        }
+    } else {
+        // Redis not enabled, use local cache only
+        if (!clientConfigVersions.has(clientId)) {
+            clientConfigVersions.set(clientId, 0);
+        }
+        configVersion = clientConfigVersions.get(clientId) || 0;
+    }
+
+    // Set config version on websocket
+    ws.configVersion = configVersion;
+
    // Add to Redis tracking if enabled
    if (redisManager.isRedisEnabled()) {
        try {
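The branching above reduces to one precedence rule: a value already in Redis wins, otherwise the local map seeds Redis, otherwise both start at 0. A condensed sketch of the same rule (a summary, not code from the diff):

    async function resolveInitialVersion(clientId: string): Promise<number> {
        const fromRedis = await redisManager.get(getConfigVersionKey(clientId)); // Redis is authoritative
        if (fromRedis !== null) return parseInt(fromRedis, 10);
        const local = clientConfigVersions.get(clientId) ?? 0; // otherwise seed Redis from the local cache
        await redisManager.set(getConfigVersionKey(clientId), local.toString());
        return local;
    }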
@@ -322,7 +376,7 @@ const addClient = async (
    }

    logger.info(
-        `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
+        `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}, Config version: ${configVersion}`
    );
};

@@ -377,53 +431,133 @@ const removeClient = async (
    }
};

+// Helper to get the current config version for a client
+const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
+    // Try Redis first if available
+    if (redisManager.isRedisEnabled()) {
+        try {
+            const redisVersion = await redisManager.get(
+                getConfigVersionKey(clientId)
+            );
+            if (redisVersion !== null) {
+                const version = parseInt(redisVersion, 10);
+                // Sync local cache with Redis
+                clientConfigVersions.set(clientId, version);
+                return version;
+            }
+        } catch (error) {
+            logger.error("Failed to get config version from Redis:", error);
+        }
+    }
+
+    // Fall back to local cache
+    return clientConfigVersions.get(clientId);
+};
+
+// Helper to increment and get the new config version for a client
+const incrementClientConfigVersion = async (
+    clientId: string
+): Promise<number> => {
+    let newVersion: number;
+
+    if (redisManager.isRedisEnabled()) {
+        try {
+            // Use Redis INCR for atomic increment across nodes
+            newVersion = await redisManager.incr(getConfigVersionKey(clientId));
+            // Sync local cache
+            clientConfigVersions.set(clientId, newVersion);
+            return newVersion;
+        } catch (error) {
+            logger.error("Failed to increment config version in Redis:", error);
+            // Fall through to local increment
+        }
+    }
+
+    // Local increment
+    const currentVersion = clientConfigVersions.get(clientId) || 0;
+    newVersion = currentVersion + 1;
+    clientConfigVersions.set(clientId, newVersion);
+    return newVersion;
+};
+
// Local message sending (within this node)
const sendToClientLocal = async (
    clientId: string,
-    message: WSMessage
+    message: WSMessage,
+    options: SendMessageOptions = {}
): Promise<boolean> => {
    const mapKey = getClientMapKey(clientId);
    const clients = connectedClients.get(mapKey);
    if (!clients || clients.length === 0) {
        return false;
    }
-    const messageString = JSON.stringify(message);
+
+    // Handle config version
+    let configVersion = await getClientConfigVersion(clientId);
+
+    // Add config version to message
+    const messageWithVersion = {
+        ...message,
+        configVersion
+    };
+
+    const messageString = JSON.stringify(messageWithVersion);
    clients.forEach((client) => {
        if (client.readyState === WebSocket.OPEN) {
            client.send(messageString);
        }
    });

-    logger.debug(
-        `sendToClient: Message type ${message.type} sent to clientId ${clientId}`
-    );
-
    return true;
};

const broadcastToAllExceptLocal = async (
    message: WSMessage,
-    excludeClientId?: string
+    excludeClientId?: string,
+    options: SendMessageOptions = {}
): Promise<void> => {
-    connectedClients.forEach((clients, mapKey) => {
+    for (const [mapKey, clients] of connectedClients.entries()) {
        const [type, id] = mapKey.split(":");
-        if (!(excludeClientId && id === excludeClientId)) {
+        const clientId = mapKey; // mapKey is the clientId
+        if (!(excludeClientId && clientId === excludeClientId)) {
+            // Handle config version per client
+            let configVersion = await getClientConfigVersion(clientId);
+            if (options.incrementConfigVersion) {
+                configVersion = await incrementClientConfigVersion(clientId);
+            }
+
+            // Add config version to message
+            const messageWithVersion = {
+                ...message,
+                configVersion
+            };
+
            clients.forEach((client) => {
                if (client.readyState === WebSocket.OPEN) {
-                    client.send(JSON.stringify(message));
+                    client.send(JSON.stringify(messageWithVersion));
                }
            });
        }
-    });
+    }
};

// Cross-node message sending (via Redis)
const sendToClient = async (
    clientId: string,
-    message: WSMessage
+    message: WSMessage,
+    options: SendMessageOptions = {}
): Promise<boolean> => {
+    let configVersion = await getClientConfigVersion(clientId);
+    if (options.incrementConfigVersion) {
+        configVersion = await incrementClientConfigVersion(clientId);
+    }
+
+    logger.debug(
+        `sendToClient: Message type ${message.type} sent to clientId ${clientId} (new configVersion: ${configVersion})`
+    );
+
    // Try to send locally first
-    const localSent = await sendToClientLocal(clientId, message);
+    const localSent = await sendToClientLocal(clientId, message, options);

    // Only send via Redis if the client is not connected locally and Redis is enabled
    if (!localSent && redisManager.isRedisEnabled()) {
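The net effect of these helpers is that every frame delivered to a client carries a configVersion field spliced in next to the original payload, so the client can compare it against the version of the config it last applied. An illustrative wire frame (field values made up):

    {
        "type": "newt/wg/targets/add",
        "data": { "targets": ["8080:10.0.0.5:80"] },
        "configVersion": 7
    }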
@@ -431,7 +565,10 @@ const sendToClient = async (
        const redisMessage: RedisMessage = {
            type: "direct",
            targetClientId: clientId,
-            message,
+            message: {
+                ...message,
+                configVersion
+            },
            fromNodeId: NODE_ID
        };

@@ -458,19 +595,22 @@ const sendToClient = async (

const broadcastToAllExcept = async (
    message: WSMessage,
-    excludeClientId?: string
+    excludeClientId?: string,
+    options: SendMessageOptions = {}
): Promise<void> => {
    // Broadcast locally
-    await broadcastToAllExceptLocal(message, excludeClientId);
+    await broadcastToAllExceptLocal(message, excludeClientId, options);

    // If Redis is enabled, also broadcast via Redis pub/sub to other nodes
+    // Note: For broadcasts, we include the options so remote nodes can handle versioning
    if (redisManager.isRedisEnabled()) {
        try {
            const redisMessage: RedisMessage = {
                type: "broadcast",
                excludeClientId,
                message,
-                fromNodeId: NODE_ID
+                fromNodeId: NODE_ID,
+                options
            };

            await redisManager.publish(
@@ -936,5 +1076,6 @@ export {
    getActiveNodes,
    disconnectClient,
    NODE_ID,
-    cleanup
+    cleanup,
+    getClientConfigVersion
};
@@ -28,7 +28,7 @@ export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
        await sendToClient(newtId, {
            type: `newt/wg/targets/add`,
            data: batches[i]
-        });
+        }, { incrementConfigVersion: true });
    }
}

@@ -44,7 +44,7 @@ export async function removeTargets(
        await sendToClient(newtId, {
            type: `newt/wg/targets/remove`,
            data: batches[i]
-        });
+        },{ incrementConfigVersion: true });
    }
}

@@ -69,7 +69,7 @@ export async function updateTargets(
                oldTargets: oldBatches[i] || [],
                newTargets: newBatches[i] || []
            }
-        }).catch((error) => {
+        }, { incrementConfigVersion: true }).catch((error) => {
            logger.warn(`Error sending message:`, error);
        });
    }
@@ -101,7 +101,7 @@ export async function addPeerData(
                remoteSubnets: remoteSubnets,
                aliases: aliases
            }
-        }).catch((error) => {
+        }, { incrementConfigVersion: true }).catch((error) => {
            logger.warn(`Error sending message:`, error);
        });
    }
@@ -132,7 +132,7 @@ export async function removePeerData(
                remoteSubnets: remoteSubnets,
                aliases: aliases
            }
-        }).catch((error) => {
+        }, { incrementConfigVersion: true }).catch((error) => {
            logger.warn(`Error sending message:`, error);
        });
    }
@@ -173,7 +173,7 @@ export async function updatePeerData(
                ...remoteSubnets,
                ...aliases
            }
-        }).catch((error) => {
+        }, { incrementConfigVersion: true }).catch((error) => {
            logger.warn(`Error sending message:`, error);
        });
    }
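The six hunks above follow one pattern: every send that mutates WireGuard state on the client opts into a version bump, while purely informational sends (such as the pong reply later in this compare) pass no options, so the frame only echoes the current version. Calling the new API directly mirrors those hunks:

    // State-changing send: atomically bumps the client's config version.
    await sendToClient(newtId, {
        type: "newt/wg/targets/add",
        data: batches[i]
    }, { incrementConfigVersion: true });

    // Informational send: the version is attached to the frame but not incremented.
    await sendToClient(newtId, { type: "pong", data: {} });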
server/routers/newt/buildConfiguration.ts (new file, 278 lines)
@@ -0,0 +1,278 @@
+import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db";
+import logger from "@server/logger";
+import { initPeerAddHandshake, updatePeer } from "../olm/peers";
+import { eq, and } from "drizzle-orm";
+import config from "@server/lib/config";
+import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
+
+export async function buildClientConfigurationForNewtClient(
+    site: Site,
+    exitNode?: ExitNode
+) {
+    const siteId = site.siteId;
+
+    // Get all clients connected to this site
+    const clientsRes = await db
+        .select()
+        .from(clients)
+        .innerJoin(
+            clientSitesAssociationsCache,
+            eq(clients.clientId, clientSitesAssociationsCache.clientId)
+        )
+        .where(eq(clientSitesAssociationsCache.siteId, siteId));
+
+    let peers: Array<{
+        publicKey: string;
+        allowedIps: string[];
+        endpoint?: string;
+    }> = [];
+
+    if (site.publicKey && site.endpoint && exitNode) {
+        // Prepare peers data for the response
+        peers = await Promise.all(
+            clientsRes
+                .filter((client) => {
+                    if (!client.clients.pubKey) {
+                        logger.warn(
+                            `Client ${client.clients.clientId} has no public key, skipping`
+                        );
+                        return false;
+                    }
+                    if (!client.clients.subnet) {
+                        logger.warn(
+                            `Client ${client.clients.clientId} has no subnet, skipping`
+                        );
+                        return false;
+                    }
+                    return true;
+                })
+                .map(async (client) => {
+                    // Add or update this peer on the olm if it is connected
+
+                    // const allSiteResources = await db // only get the site resources that this client has access to
+                    //     .select()
+                    //     .from(siteResources)
+                    //     .innerJoin(
+                    //         clientSiteResourcesAssociationsCache,
+                    //         eq(
+                    //             siteResources.siteResourceId,
+                    //             clientSiteResourcesAssociationsCache.siteResourceId
+                    //         )
+                    //     )
+                    //     .where(
+                    //         and(
+                    //             eq(siteResources.siteId, site.siteId),
+                    //             eq(
+                    //                 clientSiteResourcesAssociationsCache.clientId,
+                    //                 client.clients.clientId
+                    //             )
+                    //         )
+                    //     );
+
+                    // update the peer info on the olm
+                    // if the peer has not been added yet this will be a no-op
+                    await updatePeer(client.clients.clientId, {
+                        siteId: site.siteId,
+                        endpoint: site.endpoint!,
+                        relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
+                        publicKey: site.publicKey!,
+                        serverIP: site.address,
+                        serverPort: site.listenPort
+                        // remoteSubnets: generateRemoteSubnets(
+                        //     allSiteResources.map(
+                        //         ({ siteResources }) => siteResources
+                        //     )
+                        // ),
+                        // aliases: generateAliasConfig(
+                        //     allSiteResources.map(
+                        //         ({ siteResources }) => siteResources
+                        //     )
+                        // )
+                    });
+
+                    // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
+                    // if it has already been added this will be a no-op
+                    await initPeerAddHandshake(
+                        // this will kick off the add peer process for the client
+                        client.clients.clientId,
+                        {
+                            siteId,
+                            exitNode: {
+                                publicKey: exitNode.publicKey,
+                                endpoint: exitNode.endpoint
+                            }
+                        }
+                    );
+
+                    return {
+                        publicKey: client.clients.pubKey!,
+                        allowedIps: [
+                            `${client.clients.subnet.split("/")[0]}/32`
+                        ], // we want to only allow from that client
+                        endpoint: client.clientSitesAssociationsCache.isRelayed
+                            ? ""
+                            : client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
+                    };
+                })
+        );
+    }
+
+    // Filter out any null values from peers that didn't have an olm
+    const validPeers = peers.filter((peer) => peer !== null);
+
+    // Get all enabled site resources for this site
+    const allSiteResources = await db
+        .select()
+        .from(siteResources)
+        .where(eq(siteResources.siteId, siteId));
+
+    const targetsToSend: SubnetProxyTarget[] = [];
+
+    for (const resource of allSiteResources) {
+        // Get clients associated with this specific resource
+        const resourceClients = await db
+            .select({
+                clientId: clients.clientId,
+                pubKey: clients.pubKey,
+                subnet: clients.subnet
+            })
+            .from(clients)
+            .innerJoin(
+                clientSiteResourcesAssociationsCache,
+                eq(
+                    clients.clientId,
+                    clientSiteResourcesAssociationsCache.clientId
+                )
+            )
+            .where(
+                eq(
+                    clientSiteResourcesAssociationsCache.siteResourceId,
+                    resource.siteResourceId
+                )
+            );
+
+        const resourceTargets = generateSubnetProxyTargets(
+            resource,
+            resourceClients
+        );
+
+        targetsToSend.push(...resourceTargets);
+    }
+
+    return {
+        peers: validPeers,
+        targets: targetsToSend
+    };
+}
+
+export async function buildTargetConfigurationForNewtClient(siteId: number) {
+    // Get all enabled targets with their resource protocol information
+    const allTargets = await db
+        .select({
+            resourceId: targets.resourceId,
+            targetId: targets.targetId,
+            ip: targets.ip,
+            method: targets.method,
+            port: targets.port,
+            internalPort: targets.internalPort,
+            enabled: targets.enabled,
+            protocol: resources.protocol,
+            hcEnabled: targetHealthCheck.hcEnabled,
+            hcPath: targetHealthCheck.hcPath,
+            hcScheme: targetHealthCheck.hcScheme,
+            hcMode: targetHealthCheck.hcMode,
+            hcHostname: targetHealthCheck.hcHostname,
+            hcPort: targetHealthCheck.hcPort,
+            hcInterval: targetHealthCheck.hcInterval,
+            hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
+            hcTimeout: targetHealthCheck.hcTimeout,
+            hcHeaders: targetHealthCheck.hcHeaders,
+            hcMethod: targetHealthCheck.hcMethod,
+            hcTlsServerName: targetHealthCheck.hcTlsServerName
+        })
+        .from(targets)
+        .innerJoin(resources, eq(targets.resourceId, resources.resourceId))
+        .leftJoin(
+            targetHealthCheck,
+            eq(targets.targetId, targetHealthCheck.targetId)
+        )
+        .where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
+
+    const { tcpTargets, udpTargets } = allTargets.reduce(
+        (acc, target) => {
+            // Filter out invalid targets
+            if (!target.internalPort || !target.ip || !target.port) {
+                return acc;
+            }
+
+            // Format target into string
+            const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
+
+            // Add to the appropriate protocol array
+            if (target.protocol === "tcp") {
+                acc.tcpTargets.push(formattedTarget);
+            } else {
+                acc.udpTargets.push(formattedTarget);
+            }
+
+            return acc;
+        },
+        { tcpTargets: [] as string[], udpTargets: [] as string[] }
+    );
+
+    const healthCheckTargets = allTargets.map((target) => {
+        // make sure the stuff is defined
+        if (
+            !target.hcPath ||
+            !target.hcHostname ||
+            !target.hcPort ||
+            !target.hcInterval ||
+            !target.hcMethod
+        ) {
+            logger.debug(
+                `Skipping target ${target.targetId} due to missing health check fields`
+            );
+            return null; // Skip targets with missing health check fields
+        }
+
+        // parse headers
+        const hcHeadersParse = target.hcHeaders
+            ? JSON.parse(target.hcHeaders)
+            : null;
+        const hcHeadersSend: { [key: string]: string } = {};
+        if (hcHeadersParse) {
+            hcHeadersParse.forEach(
+                (header: { name: string; value: string }) => {
+                    hcHeadersSend[header.name] = header.value;
+                }
+            );
+        }
+
+        return {
+            id: target.targetId,
+            hcEnabled: target.hcEnabled,
+            hcPath: target.hcPath,
+            hcScheme: target.hcScheme,
+            hcMode: target.hcMode,
+            hcHostname: target.hcHostname,
+            hcPort: target.hcPort,
+            hcInterval: target.hcInterval, // in seconds
+            hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
+            hcTimeout: target.hcTimeout, // in seconds
+            hcHeaders: hcHeadersSend,
+            hcMethod: target.hcMethod,
+            hcTlsServerName: target.hcTlsServerName
+        };
+    });
+
+    // Filter out any null values from health check targets
+    const validHealthCheckTargets = healthCheckTargets.filter(
+        (target) => target !== null
+    );
+
+    return {
+        validHealthCheckTargets,
+        tcpTargets,
+        udpTargets
+    };
+}
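Factoring these two builders out of the register and get-config handlers means the new ping-driven sync path (sync.ts below) and the existing handlers share one source of truth for a newt's configuration. Both return plain data and leave delivery to the caller; typical use, as the later hunks show:

    const { peers, targets } = await buildClientConfigurationForNewtClient(site, exitNode);
    const { tcpTargets, udpTargets, validHealthCheckTargets } =
        await buildTargetConfigurationForNewtClient(site.siteId);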
@@ -2,19 +2,10 @@ import { z } from "zod";
import { MessageHandler } from "@server/routers/ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
-import {
-    db,
-    ExitNode,
-    exitNodes,
-    siteResources,
-    clientSiteResourcesAssociationsCache
-} from "@server/db";
-import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
+import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
-import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
-import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
-import config from "@server/lib/config";
+import { buildClientConfigurationForNewtClient } from "./buildConfiguration";

const inputSchema = z.object({
    publicKey: z.string(),
@@ -130,167 +121,18 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
        }
    }

-    // Get all clients connected to this site
-    const clientsRes = await db
-        .select()
-        .from(clients)
-        .innerJoin(
-            clientSitesAssociationsCache,
-            eq(clients.clientId, clientSitesAssociationsCache.clientId)
-        )
-        .where(eq(clientSitesAssociationsCache.siteId, siteId));
-
-    let peers: Array<{
-        publicKey: string;
-        allowedIps: string[];
-        endpoint?: string;
-    }> = [];
-
-    if (site.publicKey && site.endpoint && exitNode) {
-        // Prepare peers data for the response
-        peers = await Promise.all(
-            clientsRes
-                .filter((client) => {
-                    if (!client.clients.pubKey) {
-                        logger.warn(
-                            `Client ${client.clients.clientId} has no public key, skipping`
-                        );
-                        return false;
-                    }
-                    if (!client.clients.subnet) {
-                        logger.warn(
-                            `Client ${client.clients.clientId} has no subnet, skipping`
-                        );
-                        return false;
-                    }
-                    return true;
-                })
-                .map(async (client) => {
-                    // Add or update this peer on the olm if it is connected
-
-                    // const allSiteResources = await db // only get the site resources that this client has access to
-                    //     .select()
-                    //     .from(siteResources)
-                    //     .innerJoin(
-                    //         clientSiteResourcesAssociationsCache,
-                    //         eq(
-                    //             siteResources.siteResourceId,
-                    //             clientSiteResourcesAssociationsCache.siteResourceId
-                    //         )
-                    //     )
-                    //     .where(
-                    //         and(
-                    //             eq(siteResources.siteId, site.siteId),
-                    //             eq(
-                    //                 clientSiteResourcesAssociationsCache.clientId,
-                    //                 client.clients.clientId
-                    //             )
-                    //         )
-                    //     );
-
-                    // update the peer info on the olm
-                    // if the peer has not been added yet this will be a no-op
-                    await updatePeer(client.clients.clientId, {
-                        siteId: site.siteId,
-                        endpoint: site.endpoint!,
-                        relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
-                        publicKey: site.publicKey!,
-                        serverIP: site.address,
-                        serverPort: site.listenPort
-                        // remoteSubnets: generateRemoteSubnets(
-                        //     allSiteResources.map(
-                        //         ({ siteResources }) => siteResources
-                        //     )
-                        // ),
-                        // aliases: generateAliasConfig(
-                        //     allSiteResources.map(
-                        //         ({ siteResources }) => siteResources
-                        //     )
-                        // )
-                    });
-
-                    // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
-                    // if it has already been added this will be a no-op
-                    await initPeerAddHandshake(
-                        // this will kick off the add peer process for the client
-                        client.clients.clientId,
-                        {
-                            siteId,
-                            exitNode: {
-                                publicKey: exitNode.publicKey,
-                                endpoint: exitNode.endpoint
-                            }
-                        }
-                    );
-
-                    return {
-                        publicKey: client.clients.pubKey!,
-                        allowedIps: [
-                            `${client.clients.subnet.split("/")[0]}/32`
-                        ], // we want to only allow from that client
-                        endpoint: client.clientSitesAssociationsCache.isRelayed
-                            ? ""
-                            : client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
-                    };
-                })
-        );
-    }
-
-    // Filter out any null values from peers that didn't have an olm
-    const validPeers = peers.filter((peer) => peer !== null);
-
-    // Get all enabled site resources for this site
-    const allSiteResources = await db
-        .select()
-        .from(siteResources)
-        .where(eq(siteResources.siteId, siteId));
-
-    const targetsToSend: SubnetProxyTarget[] = [];
-
-    for (const resource of allSiteResources) {
-        // Get clients associated with this specific resource
-        const resourceClients = await db
-            .select({
-                clientId: clients.clientId,
-                pubKey: clients.pubKey,
-                subnet: clients.subnet
-            })
-            .from(clients)
-            .innerJoin(
-                clientSiteResourcesAssociationsCache,
-                eq(
-                    clients.clientId,
-                    clientSiteResourcesAssociationsCache.clientId
-                )
-            )
-            .where(
-                eq(
-                    clientSiteResourcesAssociationsCache.siteResourceId,
-                    resource.siteResourceId
-                )
-            );
-
-        const resourceTargets = generateSubnetProxyTargets(
-            resource,
-            resourceClients
-        );
-
-        targetsToSend.push(...resourceTargets);
-    }
-
-    // Build the configuration response
-    const configResponse = {
-        ipAddress: site.address,
-        peers: validPeers,
-        targets: targetsToSend
-    };
-
-    logger.debug("Sending config: ", configResponse);
+    const { peers, targets } = await buildClientConfigurationForNewtClient(
+        site,
+        exitNode
+    );

    return {
        message: {
            type: "newt/wg/receive-config",
            data: {
-                ...configResponse
+                ipAddress: site.address,
+                peers,
+                targets
            }
        },
        broadcast: false,
server/routers/newt/handleNewtPingMessage.ts (new file, 163 lines)
@@ -0,0 +1,163 @@
+import { db, sites } from "@server/db";
+import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
+import { MessageHandler } from "@server/routers/ws";
+import { clients, Newt } from "@server/db";
+import { eq, lt, isNull, and, or } from "drizzle-orm";
+import logger from "@server/logger";
+import { validateSessionToken } from "@server/auth/sessions/app";
+import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
+import { sendTerminateClient } from "../client/terminate";
+import { encodeHexLowerCase } from "@oslojs/encoding";
+import { sha256 } from "@oslojs/crypto/sha2";
+import { sendNewtSyncMessage } from "./sync";
+
+// Track if the offline checker interval is running
+// let offlineCheckerInterval: NodeJS.Timeout | null = null;
+// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
+// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
+
+/**
+ * Starts the background interval that checks for clients that haven't pinged recently
+ * and marks them as offline
+ */
+// export const startNewtOfflineChecker = (): void => {
+//     if (offlineCheckerInterval) {
+//         return; // Already running
+//     }
+
+//     offlineCheckerInterval = setInterval(async () => {
+//         try {
+//             const twoMinutesAgo = Math.floor(
+//                 (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
+//             );
+
+//             // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
+
+//             // Find clients that haven't pinged in the last 2 minutes and mark them as offline
+//             const offlineClients = await db
+//                 .update(clients)
+//                 .set({ online: false })
+//                 .where(
+//                     and(
+//                         eq(clients.online, true),
+//                         or(
+//                             lt(clients.lastPing, twoMinutesAgo),
+//                             isNull(clients.lastPing)
+//                         )
+//                     )
+//                 )
+//                 .returning();
+
+//             for (const offlineClient of offlineClients) {
+//                 logger.info(
+//                     `Kicking offline newt client ${offlineClient.clientId} due to inactivity`
+//                 );
+
+//                 if (!offlineClient.newtId) {
+//                     logger.warn(
+//                         `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect`
+//                     );
+//                     continue;
+//                 }
+
+//                 // Send a disconnect message to the client if connected
+//                 try {
+//                     await sendTerminateClient(
+//                         offlineClient.clientId,
+//                         offlineClient.newtId
+//                     ); // terminate first
+//                     // wait a moment to ensure the message is sent
+//                     await new Promise((resolve) => setTimeout(resolve, 1000));
+//                     await disconnectClient(offlineClient.newtId);
+//                 } catch (error) {
+//                     logger.error(
+//                         `Error sending disconnect to offline newt ${offlineClient.clientId}`,
+//                         { error }
+//                     );
+//                 }
+//             }
+//         } catch (error) {
+//             logger.error("Error in offline checker interval", { error });
+//         }
+//     }, OFFLINE_CHECK_INTERVAL);
+
+//     logger.debug("Started offline checker interval");
+// };
+
+/**
+ * Stops the background interval that checks for offline clients
+ */
+// export const stopNewtOfflineChecker = (): void => {
+//     if (offlineCheckerInterval) {
+//         clearInterval(offlineCheckerInterval);
+//         offlineCheckerInterval = null;
+//         logger.info("Stopped offline checker interval");
+//     }
+// };
+
+/**
+ * Handles ping messages from clients and responds with pong
+ */
+export const handleNewtPingMessage: MessageHandler = async (context) => {
+    const { message, client: c, sendToClient } = context;
+    const newt = c as Newt;
+
+    if (!newt) {
+        logger.warn("Newt ping message: Newt not found");
+        return;
+    }
+
+    if (!newt.siteId) {
+        logger.warn("Newt ping message: has no site ID");
+        return;
+    }
+
+    // get the version
+    const configVersion = await getClientConfigVersion(newt.newtId);
+
+    if (message.configVersion && configVersion != null && configVersion != message.configVersion) {
+        logger.warn(
+            `Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
+        );
+
+        // get the site
+        const [site] = await db
+            .select()
+            .from(sites)
+            .where(eq(sites.siteId, newt.siteId))
+            .limit(1);
+
+        if (!site) {
+            logger.warn(
+                `Newt ping message: site with ID ${newt.siteId} not found`
+            );
+            return;
+        }
+
+        await sendNewtSyncMessage(newt, site);
+    }
+
+    // try {
+    //     // Update the client's last ping timestamp
+    //     await db
+    //         .update(clients)
+    //         .set({
+    //             lastPing: Math.floor(Date.now() / 1000),
+    //             online: true
+    //         })
+    //         .where(eq(clients.clientId, newt.clientId));
+    // } catch (error) {
+    //     logger.error("Error handling ping message", { error });
+    // }
+
+    return {
+        message: {
+            type: "pong",
+            data: {
+                timestamp: new Date().toISOString()
+            }
+        },
+        broadcast: false,
+        excludeSender: false
+    };
+};
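This handler closes the loop: the newt echoes the configVersion of the last config it applied in its ping, and a mismatch with the server-side counter triggers a full resync instead of waiting for the next explicit update. The Go client is not part of this compare, so the exact inbound frame is an assumption:

    // Assumed ping frame from the newt (message type and field placement are assumptions):
    // { "type": "newt/ping", "configVersion": 6 }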
@@ -18,6 +18,7 @@ import {
} from "#dynamic/lib/exitNodes";
import { fetchContainers } from "./dockerSocket";
import { lockManager } from "#dynamic/lib/lock";
+import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";

export type ExitNodePingResult = {
    exitNodeId: number;
@@ -233,109 +234,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
            .where(eq(newts.newtId, newt.newtId));
    }

-    // Get all enabled targets with their resource protocol information
-    const allTargets = await db
-        .select({
-            resourceId: targets.resourceId,
-            targetId: targets.targetId,
-            ip: targets.ip,
-            method: targets.method,
-            port: targets.port,
-            internalPort: targets.internalPort,
-            enabled: targets.enabled,
-            protocol: resources.protocol,
-            hcEnabled: targetHealthCheck.hcEnabled,
-            hcPath: targetHealthCheck.hcPath,
-            hcScheme: targetHealthCheck.hcScheme,
-            hcMode: targetHealthCheck.hcMode,
-            hcHostname: targetHealthCheck.hcHostname,
-            hcPort: targetHealthCheck.hcPort,
-            hcInterval: targetHealthCheck.hcInterval,
-            hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
-            hcTimeout: targetHealthCheck.hcTimeout,
-            hcHeaders: targetHealthCheck.hcHeaders,
-            hcMethod: targetHealthCheck.hcMethod,
-            hcTlsServerName: targetHealthCheck.hcTlsServerName
-        })
-        .from(targets)
-        .innerJoin(resources, eq(targets.resourceId, resources.resourceId))
-        .leftJoin(
-            targetHealthCheck,
-            eq(targets.targetId, targetHealthCheck.targetId)
-        )
-        .where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
-
-    const { tcpTargets, udpTargets } = allTargets.reduce(
-        (acc, target) => {
-            // Filter out invalid targets
-            if (!target.internalPort || !target.ip || !target.port) {
-                return acc;
-            }
-
-            // Format target into string
-            const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
-
-            // Add to the appropriate protocol array
-            if (target.protocol === "tcp") {
-                acc.tcpTargets.push(formattedTarget);
-            } else {
-                acc.udpTargets.push(formattedTarget);
-            }
-
-            return acc;
-        },
-        { tcpTargets: [] as string[], udpTargets: [] as string[] }
-    );
-
-    const healthCheckTargets = allTargets.map((target) => {
-        // make sure the stuff is defined
-        if (
-            !target.hcPath ||
-            !target.hcHostname ||
-            !target.hcPort ||
-            !target.hcInterval ||
-            !target.hcMethod
-        ) {
-            logger.debug(
-                `Skipping target ${target.targetId} due to missing health check fields`
-            );
-            return null; // Skip targets with missing health check fields
-        }
-
-        // parse headers
-        const hcHeadersParse = target.hcHeaders
-            ? JSON.parse(target.hcHeaders)
-            : null;
-        const hcHeadersSend: { [key: string]: string } = {};
-        if (hcHeadersParse) {
-            hcHeadersParse.forEach(
-                (header: { name: string; value: string }) => {
-                    hcHeadersSend[header.name] = header.value;
-                }
-            );
-        }
-
-        return {
-            id: target.targetId,
-            hcEnabled: target.hcEnabled,
-            hcPath: target.hcPath,
-            hcScheme: target.hcScheme,
-            hcMode: target.hcMode,
-            hcHostname: target.hcHostname,
-            hcPort: target.hcPort,
-            hcInterval: target.hcInterval, // in seconds
-            hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
-            hcTimeout: target.hcTimeout, // in seconds
-            hcHeaders: hcHeadersSend,
-            hcMethod: target.hcMethod,
-            hcTlsServerName: target.hcTlsServerName
-        };
-    });
-
-    // Filter out any null values from health check targets
-    const validHealthCheckTargets = healthCheckTargets.filter(
-        (target) => target !== null
-    );
+    const { tcpTargets, udpTargets, validHealthCheckTargets } =
+        await buildTargetConfigurationForNewtClient(siteId);

    logger.debug(
        `Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}`
@@ -6,3 +6,4 @@ export * from "./handleGetConfigMessage";
export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage";
export * from "./handleApplyBlueprintMessage";
+export * from "./handleNewtPingMessage";
@@ -39,7 +39,7 @@ export async function addPeer(
    await sendToClient(newtId, {
        type: "newt/wg/peer/add",
        data: peer
-    }).catch((error) => {
+    }, { incrementConfigVersion: true }).catch((error) => {
        logger.warn(`Error sending message:`, error);
    });

@@ -81,7 +81,7 @@ export async function deletePeer(
        data: {
            publicKey
        }
-    }).catch((error) => {
+    }, { incrementConfigVersion: true }).catch((error) => {
        logger.warn(`Error sending message:`, error);
    });

@@ -128,7 +128,7 @@ export async function updatePeer(
            publicKey,
            ...peer
        }
-    }).catch((error) => {
+    }, { incrementConfigVersion: true }).catch((error) => {
        logger.warn(`Error sending message:`, error);
    });

server/routers/newt/sync.ts (new file, 41 lines)
@@ -0,0 +1,41 @@
+import { ExitNode, exitNodes, Newt, Site, db } from "@server/db";
+import { eq } from "drizzle-orm";
+import { sendToClient } from "#dynamic/routers/ws";
+import logger from "@server/logger";
+import {
+    buildClientConfigurationForNewtClient,
+    buildTargetConfigurationForNewtClient
+} from "./buildConfiguration";
+
+export async function sendNewtSyncMessage(newt: Newt, site: Site) {
+    const { tcpTargets, udpTargets, validHealthCheckTargets } =
+        await buildTargetConfigurationForNewtClient(site.siteId);
+
+    let exitNode: ExitNode | undefined;
+    if (site.exitNodeId) {
+        [exitNode] = await db
+            .select()
+            .from(exitNodes)
+            .where(eq(exitNodes.exitNodeId, site.exitNodeId))
+            .limit(1);
+    }
+    const { peers, targets } = await buildClientConfigurationForNewtClient(
+        site,
+        exitNode
+    );
+
+    await sendToClient(newt.newtId, {
+        type: "newt/sync",
+        data: {
+            proxyTargets: {
+                udp: udpTargets,
+                tcp: tcpTargets
+            },
+            healthCheckTargets: validHealthCheckTargets,
+            peers: peers,
+            clientTargets: targets
+        }
+    }).catch((error) => {
+        logger.warn(`Error sending newt sync message:`, error);
+    });
+}
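A single newt/sync frame carries the complete desired state, so a client that missed several incremental updates converges in one round trip instead of replaying each one. An illustrative payload (all values made up; the configVersion field is appended by sendToClient):

    {
        "type": "newt/sync",
        "data": {
            "proxyTargets": { "tcp": ["8080:10.0.0.5:80"], "udp": [] },
            "healthCheckTargets": [],
            "peers": [{ "publicKey": "base64-key", "allowedIps": ["100.89.0.2/32"] }],
            "clientTargets": []
        },
        "configVersion": 9
    }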
@@ -22,7 +22,7 @@ export async function addTargets(
        data: {
            targets: payloadTargets
        }
-    });
+    }, { incrementConfigVersion: true });

    // Create a map for quick lookup
    const healthCheckMap = new Map<number, TargetHealthCheck>();
@@ -103,7 +103,7 @@ export async function addTargets(
            data: {
                targets: validHealthCheckTargets
            }
-        });
+        }, { incrementConfigVersion: true });
    }

export async function removeTargets(
@@ -124,7 +124,7 @@ export async function removeTargets(
        data: {
            targets: payloadTargets
        }
-    });
+    }, { incrementConfigVersion: true });

    const healthCheckTargets = targets.map((target) => {
        return target.targetId;
@@ -135,5 +135,5 @@ export async function removeTargets(
        data: {
            ids: healthCheckTargets
        }
-    });
+    }, { incrementConfigVersion: true });
}
145
server/routers/olm/buildConfiguration.ts
Normal file
145
server/routers/olm/buildConfiguration.ts
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db";
|
||||||
|
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip";
|
||||||
|
import logger from "@server/logger";
|
||||||
|
import { and, eq } from "drizzle-orm";
|
||||||
|
import { addPeer, deletePeer } from "../newt/peers";
|
||||||
|
import config from "@server/lib/config";
|
||||||
|
|
||||||
|
export async function buildSiteConfigurationForOlmClient(
|
||||||
|
client: Client,
|
||||||
|
publicKey: string | null,
|
||||||
|
relay: boolean
|
||||||
|
) {
|
||||||
|
const siteConfigurations = [];
|
||||||
|
|
||||||
|
// Get all sites data
|
||||||
|
const sitesData = await db
|
||||||
|
.select()
|
||||||
|
.from(sites)
|
||||||
|
.innerJoin(
|
||||||
|
clientSitesAssociationsCache,
|
||||||
|
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||||
|
)
|
||||||
|
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||||
|
|
||||||
|
// Process each site
|
||||||
|
for (const {
|
||||||
|
sites: site,
|
||||||
|
clientSitesAssociationsCache: association
|
||||||
|
} of sitesData) {
|
||||||
|
if (!site.exitNodeId) {
|
            logger.warn(
                `Site ${site.siteId} does not have exit node, skipping`
            );
            continue;
        }

        // Validate endpoint and hole punch status
        if (!site.endpoint) {
            logger.warn(
                `In olm register: site ${site.siteId} has no endpoint, skipping`
            );
            continue;
        }

        // if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
        //     logger.warn(
        //         `Site ${site.siteId} last hole punch is too old, skipping`
        //     );
        //     continue;
        // }

        // If public key changed, delete old peer from this site
        if (client.pubKey && client.pubKey != publicKey) {
            logger.info(
                `Public key mismatch. Deleting old peer from site ${site.siteId}...`
            );
            await deletePeer(site.siteId, client.pubKey!);
        }

        if (!site.subnet) {
            logger.warn(`Site ${site.siteId} has no subnet, skipping`);
            continue;
        }

        const [clientSite] = await db
            .select()
            .from(clientSitesAssociationsCache)
            .where(
                and(
                    eq(clientSitesAssociationsCache.clientId, client.clientId),
                    eq(clientSitesAssociationsCache.siteId, site.siteId)
                )
            )
            .limit(1);

        // Add the peer to the exit node for this site
        if (clientSite.endpoint && publicKey) {
            logger.info(
                `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
            );
            await addPeer(site.siteId, {
                publicKey: publicKey,
                allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
                endpoint: relay ? "" : clientSite.endpoint
            });
        } else {
            logger.warn(
                `Client ${client.clientId} has no endpoint, skipping peer addition`
            );
        }

        let relayEndpoint: string | undefined = undefined;
        if (relay) {
            const [exitNode] = await db
                .select()
                .from(exitNodes)
                .where(eq(exitNodes.exitNodeId, site.exitNodeId))
                .limit(1);
            if (!exitNode) {
                logger.warn(`Exit node not found for site ${site.siteId}`);
                continue;
            }
            relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
        }

        const allSiteResources = await db // only get the site resources that this client has access to
            .select()
            .from(siteResources)
            .innerJoin(
                clientSiteResourcesAssociationsCache,
                eq(
                    siteResources.siteResourceId,
                    clientSiteResourcesAssociationsCache.siteResourceId
                )
            )
            .where(
                and(
                    eq(siteResources.siteId, site.siteId),
                    eq(
                        clientSiteResourcesAssociationsCache.clientId,
                        client.clientId
                    )
                )
            );

        // Add site configuration to the array
        siteConfigurations.push({
            siteId: site.siteId,
            name: site.name,
            // relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
            endpoint: site.endpoint,
            publicKey: site.publicKey,
            serverIP: site.address,
            serverPort: site.listenPort,
            remoteSubnets: generateRemoteSubnets(
                allSiteResources.map(({ siteResources }) => siteResources)
            ),
            aliases: generateAliasConfig(
                allSiteResources.map(({ siteResources }) => siteResources)
            )
        });
    }

    return siteConfigurations;
}
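One detail worth pausing on in the addPeer call above: the client's subnet is stored in CIDR notation, but the peer is registered with a single /32 so the site accepts traffic from that one client address only. Reduced to a standalone sketch (the function name is illustrative, not from the codebase):

    // Narrow a CIDR-notation subnet to its single-host /32 form,
    // e.g. "100.89.128.5/24" -> "100.89.128.5/32".
    function toHostAllowedIp(subnet: string): string {
        return `${subnet.split("/")[0]}/32`;
    }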
@@ -1,5 +1,5 @@
+import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
 import { clientPostureSnapshots, db, fingerprints } from "@server/db";
-import { disconnectClient } from "#dynamic/routers/ws";
 import { MessageHandler } from "@server/routers/ws";
 import { clients, olms, Olm } from "@server/db";
 import { eq, lt, isNull, and, or } from "drizzle-orm";
@@ -9,6 +9,7 @@ import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
 import { sendTerminateClient } from "../client/terminate";
 import { encodeHexLowerCase } from "@oslojs/encoding";
 import { sha256 } from "@oslojs/crypto/sha2";
+import { sendOlmSyncMessage } from "./sync";

 // Track if the offline checker interval is running
 let offlineCheckerInterval: NodeJS.Timeout | null = null;
@@ -128,7 +129,9 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {

     if (client.blocked) {
         // NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them
-        logger.debug(`Blocked client ${client.clientId} attempted olm ping`);
+        logger.debug(
+            `Blocked client ${client.clientId} attempted olm ping`
+        );
         return;
     }

@@ -167,6 +170,23 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
         }
     }

+    // get the version
+    logger.debug(`handleOlmPingMessage: About to get config version for olmId: ${olm.olmId}`);
+    const configVersion = await getClientConfigVersion(olm.olmId);
+    logger.debug(`handleOlmPingMessage: Got config version: ${configVersion} (type: ${typeof configVersion})`);
+
+    if (configVersion == null || configVersion === undefined) {
+        logger.debug(`handleOlmPingMessage: could not get config version from server for olmId: ${olm.olmId}`)
+    }
+
+    if (message.configVersion != null && configVersion != null && configVersion != message.configVersion) {
+        logger.debug(
+            `handleOlmPingMessage: Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
+        );
+        await sendOlmSyncMessage(olm, client);
+    }
+
+    // Update the client's last ping timestamp
     await db
         .update(clients)
         .set({
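The ping handler now doubles as a staleness probe: the olm echoes the configVersion it last applied, the server compares it against its own counter, and a mismatch triggers a full sendOlmSyncMessage. The comparison, reduced to a sketch (the function name is illustrative):

    // Sketch of the staleness check above: `reported` is the version the
    // olm echoed in its ping, `current` is the server-side counter.
    function pingNeedsSync(reported?: number | null, current?: number | null): boolean {
        if (reported == null || current == null) return false; // cannot tell yet; wait for the next ping
        return reported !== current; // any drift triggers a full olm/sync
    }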
@@ -1,4 +1,5 @@
 import {
+    Client,
     clientPostureSnapshots,
     clientSiteResourcesAssociationsCache,
     db,
@@ -15,7 +16,7 @@ import {
     olms,
     sites
 } from "@server/db";
-import { and, eq, inArray, isNull } from "drizzle-orm";
+import { and, count, eq, inArray, isNull } from "drizzle-orm";
 import { addPeer, deletePeer } from "../newt/peers";
 import logger from "@server/logger";
 import { generateAliasConfig } from "@server/lib/ip";
@@ -25,6 +26,7 @@ import { validateSessionToken } from "@server/auth/sessions/app";
 import config from "@server/lib/config";
 import { encodeHexLowerCase } from "@oslojs/encoding";
 import { sha256 } from "@oslojs/crypto/sha2";
+import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";

 export const handleOlmRegisterMessage: MessageHandler = async (context) => {
     logger.info("Handling register olm message!");
@@ -163,8 +165,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
     }

     // Get all sites data
-    const sitesData = await db
-        .select()
+    const sitesCountResult = await db
+        .select({ count: count() })
         .from(sites)
         .innerJoin(
             clientSitesAssociationsCache,
@@ -172,140 +174,29 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
         )
         .where(eq(clientSitesAssociationsCache.clientId, client.clientId));

+    // Extract the count value from the result array
+    const sitesCount =
+        sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
+
     // Prepare an array to store site configurations
-    const siteConfigurations = [];
-    logger.debug(
-        `Found ${sitesData.length} sites for client ${client.clientId}`
-    );
+    logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);

     // this prevents us from accepting a register from an olm that has not hole punched yet.
     // the olm will pump the register so we can keep checking
     // TODO: I still think there is a better way to do this rather than locking it out here but ???
-    if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) {
+    if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
         logger.warn(
             "Client last hole punch is too old and we have sites to send; skipping this register"
         );
         return;
     }

-    // Process each site
-    for (const {
-        sites: site,
-        clientSitesAssociationsCache: association
-    } of sitesData) {
-        if (!site.exitNodeId) {
-            logger.warn(
-                `Site ${site.siteId} does not have exit node, skipping`
-            );
-            continue;
-        }
-
-        // Validate endpoint and hole punch status
-        if (!site.endpoint) {
-            logger.warn(
-                `In olm register: site ${site.siteId} has no endpoint, skipping`
-            );
-            continue;
-        }
-
-        // if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
-        //     logger.warn(
-        //         `Site ${site.siteId} last hole punch is too old, skipping`
-        //     );
-        //     continue;
-        // }
-
-        // If public key changed, delete old peer from this site
-        if (client.pubKey && client.pubKey != publicKey) {
-            logger.info(
-                `Public key mismatch. Deleting old peer from site ${site.siteId}...`
-            );
-            await deletePeer(site.siteId, client.pubKey!);
-        }
-
-        if (!site.subnet) {
-            logger.warn(`Site ${site.siteId} has no subnet, skipping`);
-            continue;
-        }
-
-        const [clientSite] = await db
-            .select()
-            .from(clientSitesAssociationsCache)
-            .where(
-                and(
-                    eq(clientSitesAssociationsCache.clientId, client.clientId),
-                    eq(clientSitesAssociationsCache.siteId, site.siteId)
-                )
-            )
-            .limit(1);
-
-        // Add the peer to the exit node for this site
-        if (clientSite.endpoint) {
-            logger.info(
-                `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
-            );
-            await addPeer(site.siteId, {
-                publicKey: publicKey,
-                allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
-                endpoint: relay ? "" : clientSite.endpoint
-            });
-        } else {
-            logger.warn(
-                `Client ${client.clientId} has no endpoint, skipping peer addition`
-            );
-        }
-
-        let relayEndpoint: string | undefined = undefined;
-        if (relay) {
-            const [exitNode] = await db
-                .select()
-                .from(exitNodes)
-                .where(eq(exitNodes.exitNodeId, site.exitNodeId))
-                .limit(1);
-            if (!exitNode) {
-                logger.warn(`Exit node not found for site ${site.siteId}`);
-                continue;
-            }
-            relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
-        }
-
-        const allSiteResources = await db // only get the site resources that this client has access to
-            .select()
-            .from(siteResources)
-            .innerJoin(
-                clientSiteResourcesAssociationsCache,
-                eq(
-                    siteResources.siteResourceId,
-                    clientSiteResourcesAssociationsCache.siteResourceId
-                )
-            )
-            .where(
-                and(
-                    eq(siteResources.siteId, site.siteId),
-                    eq(
-                        clientSiteResourcesAssociationsCache.clientId,
-                        client.clientId
-                    )
-                )
-            );
-
-        // Add site configuration to the array
-        siteConfigurations.push({
-            siteId: site.siteId,
-            name: site.name,
-            // relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
-            endpoint: site.endpoint,
-            publicKey: site.publicKey,
-            serverIP: site.address,
-            serverPort: site.listenPort,
-            remoteSubnets: generateRemoteSubnets(
-                allSiteResources.map(({ siteResources }) => siteResources)
-            ),
-            aliases: generateAliasConfig(
-                allSiteResources.map(({ siteResources }) => siteResources)
-            )
-        });
-    }
+    // NOTE: its important that the client here is the old client and the public key is the new key
+    const siteConfigurations = await buildSiteConfigurationForOlmClient(
+        client,
+        publicKey,
+        relay
+    );

     if (fingerprint) {
         const [existingFingerprint] = await db
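The register handler now needs only the number of associated sites, not the rows themselves, so it switches to drizzle-orm's count() aggregate, which comes back as a one-row array. A minimal standalone sketch of that pattern (the surrounding variables are illustrative):

    import { count, eq } from "drizzle-orm";

    // count() yields rows shaped like [{ count: number }], hence the [0] index.
    const rows = await db
        .select({ count: count() })
        .from(clientSitesAssociationsCache)
        .where(eq(clientSitesAssociationsCache.clientId, clientId));
    const total = rows.length > 0 ? rows[0].count : 0;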
@@ -32,20 +32,24 @@ export async function addPeer(
         olmId = olm.olmId;
     }

-    await sendToClient(olmId, {
-        type: "olm/wg/peer/add",
-        data: {
-            siteId: peer.siteId,
-            name: peer.name,
-            publicKey: peer.publicKey,
-            endpoint: peer.endpoint,
-            relayEndpoint: peer.relayEndpoint,
-            serverIP: peer.serverIP,
-            serverPort: peer.serverPort,
-            remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
-            aliases: peer.aliases
-        }
-    }).catch((error) => {
+    await sendToClient(
+        olmId,
+        {
+            type: "olm/wg/peer/add",
+            data: {
+                siteId: peer.siteId,
+                name: peer.name,
+                publicKey: peer.publicKey,
+                endpoint: peer.endpoint,
+                relayEndpoint: peer.relayEndpoint,
+                serverIP: peer.serverIP,
+                serverPort: peer.serverPort,
+                remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
+                aliases: peer.aliases
+            }
+        },
+        { incrementConfigVersion: true }
+    ).catch((error) => {
         logger.warn(`Error sending message:`, error);
     });

@@ -70,13 +74,17 @@ export async function deletePeer(
         olmId = olm.olmId;
     }

-    await sendToClient(olmId, {
-        type: "olm/wg/peer/remove",
-        data: {
-            publicKey,
-            siteId: siteId
-        }
-    }).catch((error) => {
+    await sendToClient(
+        olmId,
+        {
+            type: "olm/wg/peer/remove",
+            data: {
+                publicKey,
+                siteId: siteId
+            }
+        },
+        { incrementConfigVersion: true }
+    ).catch((error) => {
         logger.warn(`Error sending message:`, error);
     });

@@ -109,19 +117,23 @@ export async function updatePeer(
         olmId = olm.olmId;
     }

-    await sendToClient(olmId, {
-        type: "olm/wg/peer/update",
-        data: {
-            siteId: peer.siteId,
-            publicKey: peer.publicKey,
-            endpoint: peer.endpoint,
-            relayEndpoint: peer.relayEndpoint,
-            serverIP: peer.serverIP,
-            serverPort: peer.serverPort,
-            remoteSubnets: peer.remoteSubnets,
-            aliases: peer.aliases
-        }
-    }).catch((error) => {
+    await sendToClient(
+        olmId,
+        {
+            type: "olm/wg/peer/update",
+            data: {
+                siteId: peer.siteId,
+                publicKey: peer.publicKey,
+                endpoint: peer.endpoint,
+                relayEndpoint: peer.relayEndpoint,
+                serverIP: peer.serverIP,
+                serverPort: peer.serverPort,
+                remoteSubnets: peer.remoteSubnets,
+                aliases: peer.aliases
+            }
+        },
+        { incrementConfigVersion: true }
+    ).catch((error) => {
         logger.warn(`Error sending message:`, error);
     });

@@ -151,17 +163,21 @@ export async function initPeerAddHandshake(
         olmId = olm.olmId;
     }

-    await sendToClient(olmId, {
-        type: "olm/wg/peer/holepunch/site/add",
-        data: {
-            siteId: peer.siteId,
-            exitNode: {
-                publicKey: peer.exitNode.publicKey,
-                relayPort: config.getRawConfig().gerbil.clients_start_port,
-                endpoint: peer.exitNode.endpoint
-            }
-        }
-    }).catch((error) => {
+    await sendToClient(
+        olmId,
+        {
+            type: "olm/wg/peer/holepunch/site/add",
+            data: {
+                siteId: peer.siteId,
+                exitNode: {
+                    publicKey: peer.exitNode.publicKey,
+                    relayPort: config.getRawConfig().gerbil.clients_start_port,
+                    endpoint: peer.exitNode.endpoint
+                }
+            }
+        },
+        { incrementConfigVersion: true }
+    ).catch((error) => {
         logger.warn(`Error sending message:`, error);
     });
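All four senders above now pass { incrementConfigVersion: true }, so every message that mutates client-visible WireGuard state bumps the tracked version; an olm that misses one will fail the version comparison on its next ping and receive a full resync. The shared call shape, sketched with the payload elided:

    // Common pattern in this file: config-changing messages bump the version
    // so a dropped message is self-healing via the ping/sync path.
    await sendToClient(
        olmId,
        { type: "olm/wg/peer/update", data: { /* peer fields as above */ } },
        { incrementConfigVersion: true }
    ).catch((error) => {
        logger.warn(`Error sending message:`, error);
    });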
server/routers/olm/sync.ts (new file, 80 lines)
@@ -0,0 +1,80 @@
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";

export async function sendOlmSyncMessage(olm: Olm, client: Client) {
    // NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
    const siteConfigurations = await buildSiteConfigurationForOlmClient(
        client,
        client.pubKey,
        false
    );

    // Get all exit nodes from sites where the client has peers
    const clientSites = await db
        .select()
        .from(clientSitesAssociationsCache)
        .innerJoin(
            sites,
            eq(sites.siteId, clientSitesAssociationsCache.siteId)
        )
        .where(eq(clientSitesAssociationsCache.clientId, client.clientId));

    // Extract unique exit node IDs
    const exitNodeIds = Array.from(
        new Set(
            clientSites
                .map(({ sites: site }) => site.exitNodeId)
                .filter((id): id is number => id !== null)
        )
    );

    let exitNodesData: {
        publicKey: string;
        relayPort: number;
        endpoint: string;
        siteIds: number[];
    }[] = [];

    if (exitNodeIds.length > 0) {
        const allExitNodes = await db
            .select()
            .from(exitNodes)
            .where(inArray(exitNodes.exitNodeId, exitNodeIds));

        // Map exitNodeId to siteIds
        const exitNodeIdToSiteIds: Record<number, number[]> = {};
        for (const { sites: site } of clientSites) {
            if (site.exitNodeId !== null) {
                if (!exitNodeIdToSiteIds[site.exitNodeId]) {
                    exitNodeIdToSiteIds[site.exitNodeId] = [];
                }
                exitNodeIdToSiteIds[site.exitNodeId].push(site.siteId);
            }
        }

        exitNodesData = allExitNodes.map((exitNode) => {
            return {
                publicKey: exitNode.publicKey,
                relayPort: config.getRawConfig().gerbil.clients_start_port,
                endpoint: exitNode.endpoint,
                siteIds: exitNodeIdToSiteIds[exitNode.exitNodeId] ?? []
            };
        });
    }

    logger.debug("sendOlmSyncMessage: sending sync message")

    await sendToClient(olm.olmId, {
        type: "olm/sync",
        data: {
            sites: siteConfigurations,
            exitNodes: exitNodesData
        }
    }).catch((error) => {
        logger.warn(`Error sending olm sync message:`, error);
    });
}
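The sync message carries everything the olm needs to rebuild its peer table in one shot. Its payload shape, written out as a hypothetical type for clarity (the codebase does not necessarily define one):

    // Hypothetical type for the "olm/sync" payload, inferred from the code above.
    interface OlmSyncData {
        sites: Awaited<ReturnType<typeof buildSiteConfigurationForOlmClient>>;
        exitNodes: {
            publicKey: string;
            relayPort: number;
            endpoint: string;
            siteIds: number[]; // which of the client's sites sit behind this exit node
        }[];
    }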
@@ -5,7 +5,8 @@
     handleDockerStatusMessage,
     handleDockerContainersMessage,
     handleNewtPingRequestMessage,
-    handleApplyBlueprintMessage
+    handleApplyBlueprintMessage,
+    handleNewtPingMessage
 } from "../newt";
 import {
     handleOlmRegisterMessage,
@@ -24,6 +25,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
     "olm/wg/relay": handleOlmRelayMessage,
     "olm/wg/unrelay": handleOlmUnRelayMessage,
     "olm/ping": handleOlmPingMessage,
+    "newt/ping": handleNewtPingMessage,
     "newt/wg/register": handleNewtRegisterMessage,
     "newt/wg/get-config": handleGetConfigMessage,
     "newt/receive-bandwidth": handleReceiveBandwidthMessage,
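Dispatch over this registry is a plain lookup on the message's type string; newt pings now resolve to their own handler rather than falling through. A sketch of the lookup step (the surrounding wrapper is illustrative):

    // Illustrative dispatch: resolve the handler from the registry by type.
    const handler = messageHandlers[message.type];
    if (!handler) {
        logger.warn(`Unknown message type: ${message.type}`);
    } else {
        const response = await handler(context); // may return a HandlerResponse
    }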
@@ -25,6 +25,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
     connectionId?: string;
     isFullyConnected?: boolean;
     pendingMessages?: Buffer[];
+    configVersion?: number;
 }

 export interface TokenPayload {
@@ -36,6 +37,7 @@ export interface TokenPayload {
 export interface WSMessage {
     type: string;
     data: any;
+    configVersion?: number;
 }

 export interface HandlerResponse {
@@ -43,6 +45,7 @@ export interface HandlerResponse {
     broadcast?: boolean;
     excludeSender?: boolean;
     targetClientId?: string;
+    options?: SendMessageOptions;
 }

 export interface HandlerContext {
@@ -50,10 +53,15 @@ export interface HandlerContext {
     senderWs: WebSocket;
     client: Newt | Olm | RemoteExitNode | undefined;
     clientType: ClientType;
-    sendToClient: (clientId: string, message: WSMessage) => Promise<boolean>;
+    sendToClient: (
+        clientId: string,
+        message: WSMessage,
+        options?: SendMessageOptions
+    ) => Promise<boolean>;
     broadcastToAllExcept: (
         message: WSMessage,
-        excludeClientId?: string
+        excludeClientId?: string,
+        options?: SendMessageOptions
     ) => Promise<void>;
     connectedClients: Map<string, WebSocket[]>;
 }
@@ -62,6 +70,11 @@ export type MessageHandler = (
     context: HandlerContext
 ) => Promise<HandlerResponse | void>;

+// Options for sending messages with config version tracking
+export interface SendMessageOptions {
+    incrementConfigVersion?: boolean;
+}
+
 // Redis message type for cross-node communication
 export interface RedisMessage {
     type: "direct" | "broadcast";
@@ -69,4 +82,5 @@ export interface RedisMessage {
     excludeClientId?: string;
     message: WSMessage;
     fromNodeId: string;
+    options?: SendMessageOptions;
 }
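With these types in place, a handler can delegate version bumping to the dispatcher instead of calling sendToClient itself. A hedged sketch of a handler using the new options field (the message type is invented for illustration):

    // Hypothetical handler: return an ack and ask the dispatcher to bump the
    // sender's config version when it delivers the reply.
    const handleExampleMessage: MessageHandler = async (context) => {
        return {
            message: { type: "olm/example/ack", data: { ok: true } },
            options: { incrementConfigVersion: true }
        };
    };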
@@ -15,7 +15,8 @@
     TokenPayload,
     WebSocketRequest,
     WSMessage,
-    AuthenticatedWebSocket
+    AuthenticatedWebSocket,
+    SendMessageOptions
 } from "./types";
 import { validateSessionToken } from "@server/auth/sessions/app";
@@ -34,6 +35,8 @@ const NODE_ID = uuidv4();

 // Client tracking map (local to this node)
 const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
+// Config version tracking map (clientId -> version)
+const clientConfigVersions: Map<string, number> = new Map();
 // Helper to get map key
 const getClientMapKey = (clientId: string) => clientId;

@@ -53,6 +56,13 @@ const addClient = (
     existingClients.push(ws);
     connectedClients.set(mapKey, existingClients);

+    // Initialize config version to 0 if not already set, otherwise use existing
+    if (!clientConfigVersions.has(clientId)) {
+        clientConfigVersions.set(clientId, 0);
+    }
+    // Set the current config version on the websocket
+    ws.configVersion = clientConfigVersions.get(clientId) || 0;
+
     logger.info(
         `Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
     );
@@ -84,14 +94,28 @@ const removeClient = (
 // Local message sending (within this node)
 const sendToClientLocal = async (
     clientId: string,
-    message: WSMessage
+    message: WSMessage,
+    options: SendMessageOptions = {}
 ): Promise<boolean> => {
     const mapKey = getClientMapKey(clientId);
     const clients = connectedClients.get(mapKey);
     if (!clients || clients.length === 0) {
         return false;
     }
-    const messageString = JSON.stringify(message);
+
+    // Include config version in message
+    const configVersion = clientConfigVersions.get(clientId) || 0;
+    // Update version on all client connections
+    clients.forEach((client) => {
+        client.configVersion = configVersion;
+    });
+
+    const messageWithVersion = {
+        ...message,
+        configVersion
+    };
+
+    const messageString = JSON.stringify(messageWithVersion);
     clients.forEach((client) => {
         if (client.readyState === WebSocket.OPEN) {
             client.send(messageString);
@@ -102,14 +126,30 @@ const sendToClientLocal = async (

 const broadcastToAllExceptLocal = async (
     message: WSMessage,
-    excludeClientId?: string
+    excludeClientId?: string,
+    options: SendMessageOptions = {}
 ): Promise<void> => {
     connectedClients.forEach((clients, mapKey) => {
-        const [type, id] = mapKey.split(":");
-        if (!(excludeClientId && id === excludeClientId)) {
+        const clientId = mapKey; // mapKey is the clientId
+        if (!(excludeClientId && clientId === excludeClientId)) {
+            // Handle config version per client
+            if (options.incrementConfigVersion) {
+                const currentVersion = clientConfigVersions.get(clientId) || 0;
+                const newVersion = currentVersion + 1;
+                clientConfigVersions.set(clientId, newVersion);
+                clients.forEach((client) => {
+                    client.configVersion = newVersion;
+                });
+            }
+            // Include config version in message for this client
+            const configVersion = clientConfigVersions.get(clientId) || 0;
+            const messageWithVersion = {
+                ...message,
+                configVersion
+            };
             clients.forEach((client) => {
                 if (client.readyState === WebSocket.OPEN) {
-                    client.send(JSON.stringify(message));
+                    client.send(JSON.stringify(messageWithVersion));
                 }
             });
         }
@@ -119,10 +159,18 @@ const broadcastToAllExceptLocal = async (
 // Cross-node message sending
 const sendToClient = async (
     clientId: string,
-    message: WSMessage
+    message: WSMessage,
+    options: SendMessageOptions = {}
 ): Promise<boolean> => {
+    // Increment config version if requested
+    if (options.incrementConfigVersion) {
+        const currentVersion = clientConfigVersions.get(clientId) || 0;
+        const newVersion = currentVersion + 1;
+        clientConfigVersions.set(clientId, newVersion);
+    }
+
     // Try to send locally first
-    const localSent = await sendToClientLocal(clientId, message);
+    const localSent = await sendToClientLocal(clientId, message, options);

     logger.debug(
         `sendToClient: Message type ${message.type} sent to clientId ${clientId}`
@@ -133,10 +181,11 @@ const sendToClient = async (

 const broadcastToAllExcept = async (
     message: WSMessage,
-    excludeClientId?: string
+    excludeClientId?: string,
+    options: SendMessageOptions = {}
 ): Promise<void> => {
     // Broadcast locally
-    await broadcastToAllExceptLocal(message, excludeClientId);
+    await broadcastToAllExceptLocal(message, excludeClientId, options);
 };

 // Check if a client has active connections across all nodes
@@ -146,6 +195,13 @@ const hasActiveConnections = async (clientId: string): Promise<boolean> => {
     return !!(clients && clients.length > 0);
 };

+// Get the current config version for a client
+const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
+    const version = clientConfigVersions.get(clientId);
+    logger.debug(`getClientConfigVersion called for clientId: ${clientId}, returning: ${version} (type: ${typeof version})`);
+    return version;
+};
+
 // Get all active nodes for a client
 const getActiveNodes = async (
     clientType: ClientType,
@@ -259,15 +315,21 @@ const setupConnection = async (
             if (response.broadcast) {
                 await broadcastToAllExcept(
                     response.message,
-                    response.excludeSender ? clientId : undefined
+                    response.excludeSender ? clientId : undefined,
+                    response.options
                 );
             } else if (response.targetClientId) {
                 await sendToClient(
                     response.targetClientId,
-                    response.message
+                    response.message,
+                    response.options
                 );
             } else {
-                ws.send(JSON.stringify(response.message));
+                await sendToClient(
+                    clientId,
+                    response.message,
+                    response.options
+                );
             }
         }
     } catch (error) {
@@ -434,5 +496,6 @@ export {
     getActiveNodes,
     disconnectClient,
     NODE_ID,
-    cleanup
+    cleanup,
+    getClientConfigVersion
 };
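End to end, every frame leaving this node is now stamped with the per-client counter, and the final else branch routes even sender-directed replies through sendToClient so they are stamped too. Note that clientConfigVersions is an in-memory Map, so the counter is local to whichever node the olm is connected to. An illustrative stamped frame as the olm would receive it (values made up):

    // Illustrative frame after sendToClientLocal attaches the version.
    const frame: WSMessage = {
        type: "olm/wg/peer/update",
        data: { siteId: 42, endpoint: "203.0.113.7:51820" },
        configVersion: 17
    };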