mirror of
https://github.com/fosrl/pangolin.git
synced 2026-04-26 07:52:24 +00:00
Compare commits
18 Commits
06aaa7c680
...
79ba804c88
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79ba804c88 | ||
|
|
05748bf8ff | ||
|
|
f8c98bf6bf | ||
|
|
f4496bb23a | ||
|
|
c93766bb48 | ||
|
|
a1ea3f74b3 | ||
|
|
65e8bfc93e | ||
|
|
1065004fa3 | ||
|
|
d52bd65d21 | ||
|
|
eb0cdda0f9 | ||
|
|
eba25fcc4d | ||
|
|
0ccd5714f9 | ||
|
|
e2dfc3eb20 | ||
|
|
446eba8bc9 | ||
|
|
18579c0647 | ||
|
|
0d37e08638 | ||
|
|
75b9703793 | ||
|
|
322f3bfb1d |
@@ -1308,6 +1308,7 @@
|
||||
"setupErrorCreateAdmin": "An error occurred while creating the server admin account.",
|
||||
"certificateStatus": "Certificate Status",
|
||||
"loading": "Loading",
|
||||
"loadingAnalytics": "Loading Analytics",
|
||||
"restart": "Restart",
|
||||
"domains": "Domains",
|
||||
"domainsDescription": "Create and manage domains available in the organization",
|
||||
|
||||
@@ -26,6 +26,7 @@ export function initLogCleanupInterval() {
|
||||
)
|
||||
);
|
||||
|
||||
// TODO: handle when there are multiple nodes doing this clearing using redis
|
||||
for (const org of orgsToClean) {
|
||||
const {
|
||||
orgId,
|
||||
|
||||
@@ -50,10 +50,14 @@ export async function sendToExitNode(
|
||||
);
|
||||
}
|
||||
|
||||
return sendToClient(remoteExitNode.remoteExitNodeId, {
|
||||
type: request.remoteType,
|
||||
data: request.data
|
||||
});
|
||||
return sendToClient(
|
||||
remoteExitNode.remoteExitNodeId,
|
||||
{
|
||||
type: request.remoteType,
|
||||
data: request.data
|
||||
},
|
||||
{ incrementConfigVersion: true }
|
||||
);
|
||||
} else {
|
||||
let hostname = exitNode.reachableAt;
|
||||
|
||||
|
||||
@@ -573,6 +573,20 @@ class RedisManager {
|
||||
}
|
||||
}
|
||||
|
||||
public async incr(key: string): Promise<number> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return 0;
|
||||
|
||||
try {
|
||||
return await this.executeWithRetry(
|
||||
() => this.writeClient!.incr(key),
|
||||
"Redis INCR"
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Redis INCR error:", error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
public async sadd(key: string, member: string): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ import {
|
||||
WSMessage,
|
||||
TokenPayload,
|
||||
WebSocketRequest,
|
||||
RedisMessage
|
||||
RedisMessage,
|
||||
SendMessageOptions
|
||||
} from "@server/routers/ws";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
|
||||
@@ -118,12 +119,21 @@ const processMessage = async (
|
||||
if (response.broadcast) {
|
||||
await broadcastToAllExcept(
|
||||
response.message,
|
||||
response.excludeSender ? clientId : undefined
|
||||
response.excludeSender ? clientId : undefined,
|
||||
response.options
|
||||
);
|
||||
} else if (response.targetClientId) {
|
||||
await sendToClient(response.targetClientId, response.message);
|
||||
await sendToClient(
|
||||
response.targetClientId,
|
||||
response.message,
|
||||
response.options
|
||||
);
|
||||
} else {
|
||||
ws.send(JSON.stringify(response.message));
|
||||
await sendToClient(
|
||||
clientId,
|
||||
response.message,
|
||||
response.options
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -172,6 +182,9 @@ const REDIS_CHANNEL = "websocket_messages";
|
||||
// Client tracking map (local to this node)
|
||||
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
|
||||
|
||||
// Config version tracking map (local to this node, resets on server restart)
|
||||
const clientConfigVersions: Map<string, number> = new Map();
|
||||
|
||||
// Recovery tracking
|
||||
let isRedisRecoveryInProgress = false;
|
||||
|
||||
@@ -182,6 +195,8 @@ const getClientMapKey = (clientId: string) => clientId;
|
||||
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
|
||||
const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
|
||||
`ws:node:${nodeId}:${clientId}`;
|
||||
const getConfigVersionKey = (clientId: string) =>
|
||||
`ws:configVersion:${clientId}`;
|
||||
|
||||
// Initialize Redis subscription for cross-node messaging
|
||||
const initializeRedisSubscription = async (): Promise<void> => {
|
||||
@@ -304,6 +319,45 @@ const addClient = async (
|
||||
existingClients.push(ws);
|
||||
connectedClients.set(mapKey, existingClients);
|
||||
|
||||
// Get or initialize config version
|
||||
let configVersion = 0;
|
||||
|
||||
// Check Redis first if enabled
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
|
||||
if (redisVersion !== null) {
|
||||
configVersion = parseInt(redisVersion, 10);
|
||||
// Sync to local cache
|
||||
clientConfigVersions.set(clientId, configVersion);
|
||||
} else if (!clientConfigVersions.has(clientId)) {
|
||||
// No version in Redis or local cache, initialize to 0
|
||||
await redisManager.set(getConfigVersionKey(clientId), "0");
|
||||
clientConfigVersions.set(clientId, 0);
|
||||
} else {
|
||||
// Use local cache version and sync to Redis
|
||||
configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Failed to get/set config version in Redis:", error);
|
||||
// Fall back to local cache
|
||||
if (!clientConfigVersions.has(clientId)) {
|
||||
clientConfigVersions.set(clientId, 0);
|
||||
}
|
||||
configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
}
|
||||
} else {
|
||||
// Redis not enabled, use local cache only
|
||||
if (!clientConfigVersions.has(clientId)) {
|
||||
clientConfigVersions.set(clientId, 0);
|
||||
}
|
||||
configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
}
|
||||
|
||||
// Set config version on websocket
|
||||
ws.configVersion = configVersion;
|
||||
|
||||
// Add to Redis tracking if enabled
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
@@ -322,7 +376,7 @@ const addClient = async (
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
|
||||
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}, Config version: ${configVersion}`
|
||||
);
|
||||
};
|
||||
|
||||
@@ -377,53 +431,133 @@ const removeClient = async (
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to get the current config version for a client
|
||||
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
|
||||
// Try Redis first if available
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
const redisVersion = await redisManager.get(
|
||||
getConfigVersionKey(clientId)
|
||||
);
|
||||
if (redisVersion !== null) {
|
||||
const version = parseInt(redisVersion, 10);
|
||||
// Sync local cache with Redis
|
||||
clientConfigVersions.set(clientId, version);
|
||||
return version;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Failed to get config version from Redis:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local cache
|
||||
return clientConfigVersions.get(clientId);
|
||||
};
|
||||
|
||||
// Helper to increment and get the new config version for a client
|
||||
const incrementClientConfigVersion = async (
|
||||
clientId: string
|
||||
): Promise<number> => {
|
||||
let newVersion: number;
|
||||
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
// Use Redis INCR for atomic increment across nodes
|
||||
newVersion = await redisManager.incr(getConfigVersionKey(clientId));
|
||||
// Sync local cache
|
||||
clientConfigVersions.set(clientId, newVersion);
|
||||
return newVersion;
|
||||
} catch (error) {
|
||||
logger.error("Failed to increment config version in Redis:", error);
|
||||
// Fall through to local increment
|
||||
}
|
||||
}
|
||||
|
||||
// Local increment
|
||||
const currentVersion = clientConfigVersions.get(clientId) || 0;
|
||||
newVersion = currentVersion + 1;
|
||||
clientConfigVersions.set(clientId, newVersion);
|
||||
return newVersion;
|
||||
};
|
||||
|
||||
// Local message sending (within this node)
|
||||
const sendToClientLocal = async (
|
||||
clientId: string,
|
||||
message: WSMessage
|
||||
message: WSMessage,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<boolean> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
if (!clients || clients.length === 0) {
|
||||
return false;
|
||||
}
|
||||
const messageString = JSON.stringify(message);
|
||||
|
||||
// Handle config version
|
||||
let configVersion = await getClientConfigVersion(clientId);
|
||||
|
||||
// Add config version to message
|
||||
const messageWithVersion = {
|
||||
...message,
|
||||
configVersion
|
||||
};
|
||||
|
||||
const messageString = JSON.stringify(messageWithVersion);
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(messageString);
|
||||
}
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
|
||||
);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
const broadcastToAllExceptLocal = async (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
excludeClientId?: string,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<void> => {
|
||||
connectedClients.forEach((clients, mapKey) => {
|
||||
for (const [mapKey, clients] of connectedClients.entries()) {
|
||||
const [type, id] = mapKey.split(":");
|
||||
if (!(excludeClientId && id === excludeClientId)) {
|
||||
const clientId = mapKey; // mapKey is the clientId
|
||||
if (!(excludeClientId && clientId === excludeClientId)) {
|
||||
// Handle config version per client
|
||||
let configVersion = await getClientConfigVersion(clientId);
|
||||
if (options.incrementConfigVersion) {
|
||||
configVersion = await incrementClientConfigVersion(clientId);
|
||||
}
|
||||
|
||||
// Add config version to message
|
||||
const messageWithVersion = {
|
||||
...message,
|
||||
configVersion
|
||||
};
|
||||
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(JSON.stringify(message));
|
||||
client.send(JSON.stringify(messageWithVersion));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Cross-node message sending (via Redis)
|
||||
const sendToClient = async (
|
||||
clientId: string,
|
||||
message: WSMessage
|
||||
message: WSMessage,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<boolean> => {
|
||||
let configVersion = await getClientConfigVersion(clientId);
|
||||
if (options.incrementConfigVersion) {
|
||||
configVersion = await incrementClientConfigVersion(clientId);
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`sendToClient: Message type ${message.type} sent to clientId ${clientId} (new configVersion: ${configVersion})`
|
||||
);
|
||||
|
||||
// Try to send locally first
|
||||
const localSent = await sendToClientLocal(clientId, message);
|
||||
const localSent = await sendToClientLocal(clientId, message, options);
|
||||
|
||||
// Only send via Redis if the client is not connected locally and Redis is enabled
|
||||
if (!localSent && redisManager.isRedisEnabled()) {
|
||||
@@ -431,7 +565,10 @@ const sendToClient = async (
|
||||
const redisMessage: RedisMessage = {
|
||||
type: "direct",
|
||||
targetClientId: clientId,
|
||||
message,
|
||||
message: {
|
||||
...message,
|
||||
configVersion
|
||||
},
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
@@ -458,19 +595,22 @@ const sendToClient = async (
|
||||
|
||||
const broadcastToAllExcept = async (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
excludeClientId?: string,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<void> => {
|
||||
// Broadcast locally
|
||||
await broadcastToAllExceptLocal(message, excludeClientId);
|
||||
await broadcastToAllExceptLocal(message, excludeClientId, options);
|
||||
|
||||
// If Redis is enabled, also broadcast via Redis pub/sub to other nodes
|
||||
// Note: For broadcasts, we include the options so remote nodes can handle versioning
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
const redisMessage: RedisMessage = {
|
||||
type: "broadcast",
|
||||
excludeClientId,
|
||||
message,
|
||||
fromNodeId: NODE_ID
|
||||
fromNodeId: NODE_ID,
|
||||
options
|
||||
};
|
||||
|
||||
await redisManager.publish(
|
||||
@@ -936,5 +1076,6 @@ export {
|
||||
getActiveNodes,
|
||||
disconnectClient,
|
||||
NODE_ID,
|
||||
cleanup
|
||||
cleanup,
|
||||
getClientConfigVersion
|
||||
};
|
||||
|
||||
@@ -28,7 +28,7 @@ export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/targets/add`,
|
||||
data: batches[i]
|
||||
});
|
||||
}, { incrementConfigVersion: true });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ export async function removeTargets(
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/targets/remove`,
|
||||
data: batches[i]
|
||||
});
|
||||
},{ incrementConfigVersion: true });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ export async function updateTargets(
|
||||
oldTargets: oldBatches[i] || [],
|
||||
newTargets: newBatches[i] || []
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -101,7 +101,7 @@ export async function addPeerData(
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -132,7 +132,7 @@ export async function removePeerData(
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -173,7 +173,7 @@ export async function updatePeerData(
|
||||
...remoteSubnets,
|
||||
...aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
|
||||
278
server/routers/newt/buildConfiguration.ts
Normal file
278
server/routers/newt/buildConfiguration.ts
Normal file
@@ -0,0 +1,278 @@
|
||||
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import config from "@server/lib/config";
|
||||
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
|
||||
|
||||
export async function buildClientConfigurationForNewtClient(
|
||||
site: Site,
|
||||
exitNode?: ExitNode
|
||||
) {
|
||||
const siteId = site.siteId;
|
||||
|
||||
// Get all clients connected to this site
|
||||
const clientsRes = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(clients.clientId, clientSitesAssociationsCache.clientId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.siteId, siteId));
|
||||
|
||||
let peers: Array<{
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
}> = [];
|
||||
|
||||
if (site.publicKey && site.endpoint && exitNode) {
|
||||
// Prepare peers data for the response
|
||||
peers = await Promise.all(
|
||||
clientsRes
|
||||
.filter((client) => {
|
||||
if (!client.clients.pubKey) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no public key, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
if (!client.clients.subnet) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no subnet, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map(async (client) => {
|
||||
// Add or update this peer on the olm if it is connected
|
||||
|
||||
// const allSiteResources = await db // only get the site resources that this client has access to
|
||||
// .select()
|
||||
// .from(siteResources)
|
||||
// .innerJoin(
|
||||
// clientSiteResourcesAssociationsCache,
|
||||
// eq(
|
||||
// siteResources.siteResourceId,
|
||||
// clientSiteResourcesAssociationsCache.siteResourceId
|
||||
// )
|
||||
// )
|
||||
// .where(
|
||||
// and(
|
||||
// eq(siteResources.siteId, site.siteId),
|
||||
// eq(
|
||||
// clientSiteResourcesAssociationsCache.clientId,
|
||||
// client.clients.clientId
|
||||
// )
|
||||
// )
|
||||
// );
|
||||
|
||||
// update the peer info on the olm
|
||||
// if the peer has not been added yet this will be a no-op
|
||||
await updatePeer(client.clients.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: site.endpoint!,
|
||||
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
|
||||
publicKey: site.publicKey!,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort
|
||||
// remoteSubnets: generateRemoteSubnets(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// ),
|
||||
// aliases: generateAliasConfig(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// )
|
||||
});
|
||||
|
||||
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
|
||||
// if it has already been added this will be a no-op
|
||||
await initPeerAddHandshake(
|
||||
// this will kick off the add peer process for the client
|
||||
client.clients.clientId,
|
||||
{
|
||||
siteId,
|
||||
exitNode: {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
publicKey: client.clients.pubKey!,
|
||||
allowedIps: [
|
||||
`${client.clients.subnet.split("/")[0]}/32`
|
||||
], // we want to only allow from that client
|
||||
endpoint: client.clientSitesAssociationsCache.isRelayed
|
||||
? ""
|
||||
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
|
||||
};
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Filter out any null values from peers that didn't have an olm
|
||||
const validPeers = peers.filter((peer) => peer !== null);
|
||||
|
||||
// Get all enabled site resources for this site
|
||||
const allSiteResources = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(eq(siteResources.siteId, siteId));
|
||||
|
||||
const targetsToSend: SubnetProxyTarget[] = [];
|
||||
|
||||
for (const resource of allSiteResources) {
|
||||
// Get clients associated with this specific resource
|
||||
const resourceClients = await db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
pubKey: clients.pubKey,
|
||||
subnet: clients.subnet
|
||||
})
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
clients.clientId,
|
||||
clientSiteResourcesAssociationsCache.clientId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||
resource.siteResourceId
|
||||
)
|
||||
);
|
||||
|
||||
const resourceTargets = generateSubnetProxyTargets(
|
||||
resource,
|
||||
resourceClients
|
||||
);
|
||||
|
||||
targetsToSend.push(...resourceTargets);
|
||||
}
|
||||
|
||||
return {
|
||||
peers: validPeers,
|
||||
targets: targetsToSend
|
||||
};
|
||||
}
|
||||
|
||||
export async function buildTargetConfigurationForNewtClient(siteId: number) {
|
||||
// Get all enabled targets with their resource protocol information
|
||||
const allTargets = await db
|
||||
.select({
|
||||
resourceId: targets.resourceId,
|
||||
targetId: targets.targetId,
|
||||
ip: targets.ip,
|
||||
method: targets.method,
|
||||
port: targets.port,
|
||||
internalPort: targets.internalPort,
|
||||
enabled: targets.enabled,
|
||||
protocol: resources.protocol,
|
||||
hcEnabled: targetHealthCheck.hcEnabled,
|
||||
hcPath: targetHealthCheck.hcPath,
|
||||
hcScheme: targetHealthCheck.hcScheme,
|
||||
hcMode: targetHealthCheck.hcMode,
|
||||
hcHostname: targetHealthCheck.hcHostname,
|
||||
hcPort: targetHealthCheck.hcPort,
|
||||
hcInterval: targetHealthCheck.hcInterval,
|
||||
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
|
||||
hcTimeout: targetHealthCheck.hcTimeout,
|
||||
hcHeaders: targetHealthCheck.hcHeaders,
|
||||
hcMethod: targetHealthCheck.hcMethod,
|
||||
hcTlsServerName: targetHealthCheck.hcTlsServerName
|
||||
})
|
||||
.from(targets)
|
||||
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
|
||||
.leftJoin(
|
||||
targetHealthCheck,
|
||||
eq(targets.targetId, targetHealthCheck.targetId)
|
||||
)
|
||||
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
|
||||
|
||||
const { tcpTargets, udpTargets } = allTargets.reduce(
|
||||
(acc, target) => {
|
||||
// Filter out invalid targets
|
||||
if (!target.internalPort || !target.ip || !target.port) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
// Format target into string
|
||||
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
|
||||
|
||||
// Add to the appropriate protocol array
|
||||
if (target.protocol === "tcp") {
|
||||
acc.tcpTargets.push(formattedTarget);
|
||||
} else {
|
||||
acc.udpTargets.push(formattedTarget);
|
||||
}
|
||||
|
||||
return acc;
|
||||
},
|
||||
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
|
||||
);
|
||||
|
||||
const healthCheckTargets = allTargets.map((target) => {
|
||||
// make sure the stuff is defined
|
||||
if (
|
||||
!target.hcPath ||
|
||||
!target.hcHostname ||
|
||||
!target.hcPort ||
|
||||
!target.hcInterval ||
|
||||
!target.hcMethod
|
||||
) {
|
||||
logger.debug(
|
||||
`Skipping target ${target.targetId} due to missing health check fields`
|
||||
);
|
||||
return null; // Skip targets with missing health check fields
|
||||
}
|
||||
|
||||
// parse headers
|
||||
const hcHeadersParse = target.hcHeaders
|
||||
? JSON.parse(target.hcHeaders)
|
||||
: null;
|
||||
const hcHeadersSend: { [key: string]: string } = {};
|
||||
if (hcHeadersParse) {
|
||||
hcHeadersParse.forEach(
|
||||
(header: { name: string; value: string }) => {
|
||||
hcHeadersSend[header.name] = header.value;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
id: target.targetId,
|
||||
hcEnabled: target.hcEnabled,
|
||||
hcPath: target.hcPath,
|
||||
hcScheme: target.hcScheme,
|
||||
hcMode: target.hcMode,
|
||||
hcHostname: target.hcHostname,
|
||||
hcPort: target.hcPort,
|
||||
hcInterval: target.hcInterval, // in seconds
|
||||
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
|
||||
hcTimeout: target.hcTimeout, // in seconds
|
||||
hcHeaders: hcHeadersSend,
|
||||
hcMethod: target.hcMethod,
|
||||
hcTlsServerName: target.hcTlsServerName
|
||||
};
|
||||
});
|
||||
|
||||
// Filter out any null values from health check targets
|
||||
const validHealthCheckTargets = healthCheckTargets.filter(
|
||||
(target) => target !== null
|
||||
);
|
||||
|
||||
return {
|
||||
validHealthCheckTargets,
|
||||
tcpTargets,
|
||||
udpTargets
|
||||
};
|
||||
}
|
||||
@@ -2,19 +2,10 @@ import { z } from "zod";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import {
|
||||
db,
|
||||
ExitNode,
|
||||
exitNodes,
|
||||
siteResources,
|
||||
clientSiteResourcesAssociationsCache
|
||||
} from "@server/db";
|
||||
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
|
||||
import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
|
||||
import { sendToExitNode } from "#dynamic/lib/exitNodes";
|
||||
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
|
||||
import config from "@server/lib/config";
|
||||
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
|
||||
|
||||
const inputSchema = z.object({
|
||||
publicKey: z.string(),
|
||||
@@ -130,167 +121,18 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Get all clients connected to this site
|
||||
const clientsRes = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(clients.clientId, clientSitesAssociationsCache.clientId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.siteId, siteId));
|
||||
const { peers, targets } = await buildClientConfigurationForNewtClient(
|
||||
site,
|
||||
exitNode
|
||||
);
|
||||
|
||||
let peers: Array<{
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
}> = [];
|
||||
|
||||
if (site.publicKey && site.endpoint && exitNode) {
|
||||
// Prepare peers data for the response
|
||||
peers = await Promise.all(
|
||||
clientsRes
|
||||
.filter((client) => {
|
||||
if (!client.clients.pubKey) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no public key, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
if (!client.clients.subnet) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no subnet, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map(async (client) => {
|
||||
// Add or update this peer on the olm if it is connected
|
||||
|
||||
// const allSiteResources = await db // only get the site resources that this client has access to
|
||||
// .select()
|
||||
// .from(siteResources)
|
||||
// .innerJoin(
|
||||
// clientSiteResourcesAssociationsCache,
|
||||
// eq(
|
||||
// siteResources.siteResourceId,
|
||||
// clientSiteResourcesAssociationsCache.siteResourceId
|
||||
// )
|
||||
// )
|
||||
// .where(
|
||||
// and(
|
||||
// eq(siteResources.siteId, site.siteId),
|
||||
// eq(
|
||||
// clientSiteResourcesAssociationsCache.clientId,
|
||||
// client.clients.clientId
|
||||
// )
|
||||
// )
|
||||
// );
|
||||
|
||||
// update the peer info on the olm
|
||||
// if the peer has not been added yet this will be a no-op
|
||||
await updatePeer(client.clients.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: site.endpoint!,
|
||||
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
|
||||
publicKey: site.publicKey!,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort
|
||||
// remoteSubnets: generateRemoteSubnets(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// ),
|
||||
// aliases: generateAliasConfig(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// )
|
||||
});
|
||||
|
||||
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
|
||||
// if it has already been added this will be a no-op
|
||||
await initPeerAddHandshake(
|
||||
// this will kick off the add peer process for the client
|
||||
client.clients.clientId,
|
||||
{
|
||||
siteId,
|
||||
exitNode: {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
publicKey: client.clients.pubKey!,
|
||||
allowedIps: [
|
||||
`${client.clients.subnet.split("/")[0]}/32`
|
||||
], // we want to only allow from that client
|
||||
endpoint: client.clientSitesAssociationsCache.isRelayed
|
||||
? ""
|
||||
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
|
||||
};
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Filter out any null values from peers that didn't have an olm
|
||||
const validPeers = peers.filter((peer) => peer !== null);
|
||||
|
||||
// Get all enabled site resources for this site
|
||||
const allSiteResources = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(eq(siteResources.siteId, siteId));
|
||||
|
||||
const targetsToSend: SubnetProxyTarget[] = [];
|
||||
|
||||
for (const resource of allSiteResources) {
|
||||
// Get clients associated with this specific resource
|
||||
const resourceClients = await db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
pubKey: clients.pubKey,
|
||||
subnet: clients.subnet
|
||||
})
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
clients.clientId,
|
||||
clientSiteResourcesAssociationsCache.clientId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||
resource.siteResourceId
|
||||
)
|
||||
);
|
||||
|
||||
const resourceTargets = generateSubnetProxyTargets(
|
||||
resource,
|
||||
resourceClients
|
||||
);
|
||||
|
||||
targetsToSend.push(...resourceTargets);
|
||||
}
|
||||
|
||||
// Build the configuration response
|
||||
const configResponse = {
|
||||
ipAddress: site.address,
|
||||
peers: validPeers,
|
||||
targets: targetsToSend
|
||||
};
|
||||
|
||||
logger.debug("Sending config: ", configResponse);
|
||||
return {
|
||||
message: {
|
||||
type: "newt/wg/receive-config",
|
||||
data: {
|
||||
...configResponse
|
||||
ipAddress: site.address,
|
||||
peers,
|
||||
targets
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
|
||||
163
server/routers/newt/handleNewtPingMessage.ts
Normal file
163
server/routers/newt/handleNewtPingMessage.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
import { db, sites } from "@server/db";
|
||||
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, Newt } from "@server/db";
|
||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { sendNewtSyncMessage } from "./sync";
|
||||
|
||||
// Track if the offline checker interval is running
|
||||
// let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||
// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
|
||||
// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
|
||||
|
||||
/**
|
||||
* Starts the background interval that checks for clients that haven't pinged recently
|
||||
* and marks them as offline
|
||||
*/
|
||||
// export const startNewtOfflineChecker = (): void => {
|
||||
// if (offlineCheckerInterval) {
|
||||
// return; // Already running
|
||||
// }
|
||||
|
||||
// offlineCheckerInterval = setInterval(async () => {
|
||||
// try {
|
||||
// const twoMinutesAgo = Math.floor(
|
||||
// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
||||
// );
|
||||
|
||||
// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
|
||||
|
||||
// // Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
||||
// const offlineClients = await db
|
||||
// .update(clients)
|
||||
// .set({ online: false })
|
||||
// .where(
|
||||
// and(
|
||||
// eq(clients.online, true),
|
||||
// or(
|
||||
// lt(clients.lastPing, twoMinutesAgo),
|
||||
// isNull(clients.lastPing)
|
||||
// )
|
||||
// )
|
||||
// )
|
||||
// .returning();
|
||||
|
||||
// for (const offlineClient of offlineClients) {
|
||||
// logger.info(
|
||||
// `Kicking offline newt client ${offlineClient.clientId} due to inactivity`
|
||||
// );
|
||||
|
||||
// if (!offlineClient.newtId) {
|
||||
// logger.warn(
|
||||
// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect`
|
||||
// );
|
||||
// continue;
|
||||
// }
|
||||
|
||||
// // Send a disconnect message to the client if connected
|
||||
// try {
|
||||
// await sendTerminateClient(
|
||||
// offlineClient.clientId,
|
||||
// offlineClient.newtId
|
||||
// ); // terminate first
|
||||
// // wait a moment to ensure the message is sent
|
||||
// await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
// await disconnectClient(offlineClient.newtId);
|
||||
// } catch (error) {
|
||||
// logger.error(
|
||||
// `Error sending disconnect to offline newt ${offlineClient.clientId}`,
|
||||
// { error }
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// } catch (error) {
|
||||
// logger.error("Error in offline checker interval", { error });
|
||||
// }
|
||||
// }, OFFLINE_CHECK_INTERVAL);
|
||||
|
||||
// logger.debug("Started offline checker interval");
|
||||
// };
|
||||
|
||||
/**
|
||||
* Stops the background interval that checks for offline clients
|
||||
*/
|
||||
// export const stopNewtOfflineChecker = (): void => {
|
||||
// if (offlineCheckerInterval) {
|
||||
// clearInterval(offlineCheckerInterval);
|
||||
// offlineCheckerInterval = null;
|
||||
// logger.info("Stopped offline checker interval");
|
||||
// }
|
||||
// };
|
||||
|
||||
/**
|
||||
* Handles ping messages from clients and responds with pong
|
||||
*/
|
||||
export const handleNewtPingMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const newt = c as Newt;
|
||||
|
||||
if (!newt) {
|
||||
logger.warn("Newt ping message: Newt not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!newt.siteId) {
|
||||
logger.warn("Newt ping message: has no site ID");
|
||||
return;
|
||||
}
|
||||
|
||||
// get the version
|
||||
const configVersion = await getClientConfigVersion(newt.newtId);
|
||||
|
||||
if (message.configVersion && configVersion != null && configVersion != message.configVersion) {
|
||||
logger.warn(
|
||||
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
|
||||
);
|
||||
|
||||
// get the site
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, newt.siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site) {
|
||||
logger.warn(
|
||||
`Newt ping message: site with ID ${newt.siteId} not found`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await sendNewtSyncMessage(newt, site);
|
||||
}
|
||||
|
||||
// try {
|
||||
// // Update the client's last ping timestamp
|
||||
// await db
|
||||
// .update(clients)
|
||||
// .set({
|
||||
// lastPing: Math.floor(Date.now() / 1000),
|
||||
// online: true
|
||||
// })
|
||||
// .where(eq(clients.clientId, newt.clientId));
|
||||
// } catch (error) {
|
||||
// logger.error("Error handling ping message", { error });
|
||||
// }
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "pong",
|
||||
data: {
|
||||
timestamp: new Date().toISOString()
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
} from "#dynamic/lib/exitNodes";
|
||||
import { fetchContainers } from "./dockerSocket";
|
||||
import { lockManager } from "#dynamic/lib/lock";
|
||||
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
|
||||
|
||||
export type ExitNodePingResult = {
|
||||
exitNodeId: number;
|
||||
@@ -233,109 +234,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
.where(eq(newts.newtId, newt.newtId));
|
||||
}
|
||||
|
||||
// Get all enabled targets with their resource protocol information
|
||||
const allTargets = await db
|
||||
.select({
|
||||
resourceId: targets.resourceId,
|
||||
targetId: targets.targetId,
|
||||
ip: targets.ip,
|
||||
method: targets.method,
|
||||
port: targets.port,
|
||||
internalPort: targets.internalPort,
|
||||
enabled: targets.enabled,
|
||||
protocol: resources.protocol,
|
||||
hcEnabled: targetHealthCheck.hcEnabled,
|
||||
hcPath: targetHealthCheck.hcPath,
|
||||
hcScheme: targetHealthCheck.hcScheme,
|
||||
hcMode: targetHealthCheck.hcMode,
|
||||
hcHostname: targetHealthCheck.hcHostname,
|
||||
hcPort: targetHealthCheck.hcPort,
|
||||
hcInterval: targetHealthCheck.hcInterval,
|
||||
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
|
||||
hcTimeout: targetHealthCheck.hcTimeout,
|
||||
hcHeaders: targetHealthCheck.hcHeaders,
|
||||
hcMethod: targetHealthCheck.hcMethod,
|
||||
hcTlsServerName: targetHealthCheck.hcTlsServerName
|
||||
})
|
||||
.from(targets)
|
||||
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
|
||||
.leftJoin(
|
||||
targetHealthCheck,
|
||||
eq(targets.targetId, targetHealthCheck.targetId)
|
||||
)
|
||||
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
|
||||
|
||||
const { tcpTargets, udpTargets } = allTargets.reduce(
|
||||
(acc, target) => {
|
||||
// Filter out invalid targets
|
||||
if (!target.internalPort || !target.ip || !target.port) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
// Format target into string
|
||||
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
|
||||
|
||||
// Add to the appropriate protocol array
|
||||
if (target.protocol === "tcp") {
|
||||
acc.tcpTargets.push(formattedTarget);
|
||||
} else {
|
||||
acc.udpTargets.push(formattedTarget);
|
||||
}
|
||||
|
||||
return acc;
|
||||
},
|
||||
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
|
||||
);
|
||||
|
||||
const healthCheckTargets = allTargets.map((target) => {
|
||||
// make sure the stuff is defined
|
||||
if (
|
||||
!target.hcPath ||
|
||||
!target.hcHostname ||
|
||||
!target.hcPort ||
|
||||
!target.hcInterval ||
|
||||
!target.hcMethod
|
||||
) {
|
||||
logger.debug(
|
||||
`Skipping target ${target.targetId} due to missing health check fields`
|
||||
);
|
||||
return null; // Skip targets with missing health check fields
|
||||
}
|
||||
|
||||
// parse headers
|
||||
const hcHeadersParse = target.hcHeaders
|
||||
? JSON.parse(target.hcHeaders)
|
||||
: null;
|
||||
const hcHeadersSend: { [key: string]: string } = {};
|
||||
if (hcHeadersParse) {
|
||||
hcHeadersParse.forEach(
|
||||
(header: { name: string; value: string }) => {
|
||||
hcHeadersSend[header.name] = header.value;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
id: target.targetId,
|
||||
hcEnabled: target.hcEnabled,
|
||||
hcPath: target.hcPath,
|
||||
hcScheme: target.hcScheme,
|
||||
hcMode: target.hcMode,
|
||||
hcHostname: target.hcHostname,
|
||||
hcPort: target.hcPort,
|
||||
hcInterval: target.hcInterval, // in seconds
|
||||
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
|
||||
hcTimeout: target.hcTimeout, // in seconds
|
||||
hcHeaders: hcHeadersSend,
|
||||
hcMethod: target.hcMethod,
|
||||
hcTlsServerName: target.hcTlsServerName
|
||||
};
|
||||
});
|
||||
|
||||
// Filter out any null values from health check targets
|
||||
const validHealthCheckTargets = healthCheckTargets.filter(
|
||||
(target) => target !== null
|
||||
);
|
||||
const { tcpTargets, udpTargets, validHealthCheckTargets } =
|
||||
await buildTargetConfigurationForNewtClient(siteId);
|
||||
|
||||
logger.debug(
|
||||
`Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}`
|
||||
|
||||
@@ -6,3 +6,4 @@ export * from "./handleGetConfigMessage";
|
||||
export * from "./handleSocketMessages";
|
||||
export * from "./handleNewtPingRequestMessage";
|
||||
export * from "./handleApplyBlueprintMessage";
|
||||
export * from "./handleNewtPingMessage";
|
||||
|
||||
@@ -39,7 +39,7 @@ export async function addPeer(
|
||||
await sendToClient(newtId, {
|
||||
type: "newt/wg/peer/add",
|
||||
data: peer
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -81,7 +81,7 @@ export async function deletePeer(
|
||||
data: {
|
||||
publicKey
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -128,7 +128,7 @@ export async function updatePeer(
|
||||
publicKey,
|
||||
...peer
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
|
||||
41
server/routers/newt/sync.ts
Normal file
41
server/routers/newt/sync.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import { ExitNode, exitNodes, Newt, Site, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import {
|
||||
buildClientConfigurationForNewtClient,
|
||||
buildTargetConfigurationForNewtClient
|
||||
} from "./buildConfiguration";
|
||||
|
||||
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
|
||||
const { tcpTargets, udpTargets, validHealthCheckTargets } =
|
||||
await buildTargetConfigurationForNewtClient(site.siteId);
|
||||
|
||||
let exitNode: ExitNode | undefined;
|
||||
if (site.exitNodeId) {
|
||||
[exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
}
|
||||
const { peers, targets } = await buildClientConfigurationForNewtClient(
|
||||
site,
|
||||
exitNode
|
||||
);
|
||||
|
||||
await sendToClient(newt.newtId, {
|
||||
type: "newt/sync",
|
||||
data: {
|
||||
proxyTargets: {
|
||||
udp: udpTargets,
|
||||
tcp: tcpTargets
|
||||
},
|
||||
healthCheckTargets: validHealthCheckTargets,
|
||||
peers: peers,
|
||||
clientTargets: targets
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending newt sync message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -22,7 +22,7 @@ export async function addTargets(
|
||||
data: {
|
||||
targets: payloadTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true });
|
||||
|
||||
// Create a map for quick lookup
|
||||
const healthCheckMap = new Map<number, TargetHealthCheck>();
|
||||
@@ -103,7 +103,7 @@ export async function addTargets(
|
||||
data: {
|
||||
targets: validHealthCheckTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true });
|
||||
}
|
||||
|
||||
export async function removeTargets(
|
||||
@@ -124,7 +124,7 @@ export async function removeTargets(
|
||||
data: {
|
||||
targets: payloadTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true });
|
||||
|
||||
const healthCheckTargets = targets.map((target) => {
|
||||
return target.targetId;
|
||||
@@ -135,5 +135,5 @@ export async function removeTargets(
|
||||
data: {
|
||||
ids: healthCheckTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true });
|
||||
}
|
||||
|
||||
145
server/routers/olm/buildConfiguration.ts
Normal file
145
server/routers/olm/buildConfiguration.ts
Normal file
@@ -0,0 +1,145 @@
|
||||
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db";
|
||||
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../newt/peers";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export async function buildSiteConfigurationForOlmClient(
|
||||
client: Client,
|
||||
publicKey: string | null,
|
||||
relay: boolean
|
||||
) {
|
||||
const siteConfigurations = [];
|
||||
|
||||
// Get all sites data
|
||||
const sitesData = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Process each site
|
||||
for (const {
|
||||
sites: site,
|
||||
clientSitesAssociationsCache: association
|
||||
} of sitesData) {
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} does not have exit node, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate endpoint and hole punch status
|
||||
if (!site.endpoint) {
|
||||
logger.warn(
|
||||
`In olm register: site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
|
||||
// logger.warn(
|
||||
// `Site ${site.siteId} last hole punch is too old, skipping`
|
||||
// );
|
||||
// continue;
|
||||
// }
|
||||
|
||||
// If public key changed, delete old peer from this site
|
||||
if (client.pubKey && client.pubKey != publicKey) {
|
||||
logger.info(
|
||||
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
|
||||
);
|
||||
await deletePeer(site.siteId, client.pubKey!);
|
||||
}
|
||||
|
||||
if (!site.subnet) {
|
||||
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [clientSite] = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
// Add the peer to the exit node for this site
|
||||
if (clientSite.endpoint && publicKey) {
|
||||
logger.info(
|
||||
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
|
||||
);
|
||||
await addPeer(site.siteId, {
|
||||
publicKey: publicKey,
|
||||
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
endpoint: relay ? "" : clientSite.endpoint
|
||||
});
|
||||
} else {
|
||||
logger.warn(
|
||||
`Client ${client.clientId} has no endpoint, skipping peer addition`
|
||||
);
|
||||
}
|
||||
|
||||
let relayEndpoint: string | undefined = undefined;
|
||||
if (relay) {
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
continue;
|
||||
}
|
||||
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
||||
}
|
||||
|
||||
const allSiteResources = await db // only get the site resources that this client has access to
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
siteResources.siteResourceId,
|
||||
clientSiteResourcesAssociationsCache.siteResourceId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(siteResources.siteId, site.siteId),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Add site configuration to the array
|
||||
siteConfigurations.push({
|
||||
siteId: site.siteId,
|
||||
name: site.name,
|
||||
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
|
||||
endpoint: site.endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: generateRemoteSubnets(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
return siteConfigurations;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
|
||||
import { clientPostureSnapshots, db, fingerprints } from "@server/db";
|
||||
import { disconnectClient } from "#dynamic/routers/ws";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, olms, Olm } from "@server/db";
|
||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||
@@ -9,6 +9,7 @@ import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { sendOlmSyncMessage } from "./sync";
|
||||
|
||||
// Track if the offline checker interval is running
|
||||
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||
@@ -128,7 +129,9 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
||||
|
||||
if (client.blocked) {
|
||||
// NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them
|
||||
logger.debug(`Blocked client ${client.clientId} attempted olm ping`);
|
||||
logger.debug(
|
||||
`Blocked client ${client.clientId} attempted olm ping`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -167,6 +170,23 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
}
|
||||
|
||||
// get the version
|
||||
logger.debug(`handleOlmPingMessage: About to get config version for olmId: ${olm.olmId}`);
|
||||
const configVersion = await getClientConfigVersion(olm.olmId);
|
||||
logger.debug(`handleOlmPingMessage: Got config version: ${configVersion} (type: ${typeof configVersion})`);
|
||||
|
||||
if (configVersion == null || configVersion === undefined) {
|
||||
logger.debug(`handleOlmPingMessage: could not get config version from server for olmId: ${olm.olmId}`)
|
||||
}
|
||||
|
||||
if (message.configVersion != null && configVersion != null && configVersion != message.configVersion) {
|
||||
logger.debug(
|
||||
`handleOlmPingMessage: Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
|
||||
);
|
||||
await sendOlmSyncMessage(olm, client);
|
||||
}
|
||||
|
||||
// Update the client's last ping timestamp
|
||||
await db
|
||||
.update(clients)
|
||||
.set({
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
Client,
|
||||
clientPostureSnapshots,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
db,
|
||||
@@ -15,7 +16,7 @@ import {
|
||||
olms,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import { and, eq, inArray, isNull } from "drizzle-orm";
|
||||
import { and, count, eq, inArray, isNull } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../newt/peers";
|
||||
import logger from "@server/logger";
|
||||
import { generateAliasConfig } from "@server/lib/ip";
|
||||
@@ -25,6 +26,7 @@ import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import config from "@server/lib/config";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
||||
|
||||
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
logger.info("Handling register olm message!");
|
||||
@@ -163,8 +165,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
|
||||
// Get all sites data
|
||||
const sitesData = await db
|
||||
.select()
|
||||
const sitesCountResult = await db
|
||||
.select({ count: count() })
|
||||
.from(sites)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
@@ -172,140 +174,29 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Extract the count value from the result array
|
||||
const sitesCount =
|
||||
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
|
||||
|
||||
// Prepare an array to store site configurations
|
||||
const siteConfigurations = [];
|
||||
logger.debug(
|
||||
`Found ${sitesData.length} sites for client ${client.clientId}`
|
||||
);
|
||||
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
|
||||
|
||||
// this prevents us from accepting a register from an olm that has not hole punched yet.
|
||||
// the olm will pump the register so we can keep checking
|
||||
// TODO: I still think there is a better way to do this rather than locking it out here but ???
|
||||
if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) {
|
||||
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
|
||||
logger.warn(
|
||||
"Client last hole punch is too old and we have sites to send; skipping this register"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Process each site
|
||||
for (const {
|
||||
sites: site,
|
||||
clientSitesAssociationsCache: association
|
||||
} of sitesData) {
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} does not have exit node, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate endpoint and hole punch status
|
||||
if (!site.endpoint) {
|
||||
logger.warn(
|
||||
`In olm register: site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
|
||||
// logger.warn(
|
||||
// `Site ${site.siteId} last hole punch is too old, skipping`
|
||||
// );
|
||||
// continue;
|
||||
// }
|
||||
|
||||
// If public key changed, delete old peer from this site
|
||||
if (client.pubKey && client.pubKey != publicKey) {
|
||||
logger.info(
|
||||
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
|
||||
);
|
||||
await deletePeer(site.siteId, client.pubKey!);
|
||||
}
|
||||
|
||||
if (!site.subnet) {
|
||||
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [clientSite] = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
// Add the peer to the exit node for this site
|
||||
if (clientSite.endpoint) {
|
||||
logger.info(
|
||||
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
|
||||
);
|
||||
await addPeer(site.siteId, {
|
||||
publicKey: publicKey,
|
||||
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
endpoint: relay ? "" : clientSite.endpoint
|
||||
});
|
||||
} else {
|
||||
logger.warn(
|
||||
`Client ${client.clientId} has no endpoint, skipping peer addition`
|
||||
);
|
||||
}
|
||||
|
||||
let relayEndpoint: string | undefined = undefined;
|
||||
if (relay) {
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
continue;
|
||||
}
|
||||
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
||||
}
|
||||
|
||||
const allSiteResources = await db // only get the site resources that this client has access to
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
siteResources.siteResourceId,
|
||||
clientSiteResourcesAssociationsCache.siteResourceId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(siteResources.siteId, site.siteId),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Add site configuration to the array
|
||||
siteConfigurations.push({
|
||||
siteId: site.siteId,
|
||||
name: site.name,
|
||||
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
|
||||
endpoint: site.endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: generateRemoteSubnets(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
});
|
||||
}
|
||||
// NOTE: its important that the client here is the old client and the public key is the new key
|
||||
const siteConfigurations = await buildSiteConfigurationForOlmClient(
|
||||
client,
|
||||
publicKey,
|
||||
relay
|
||||
);
|
||||
|
||||
if (fingerprint) {
|
||||
const [existingFingerprint] = await db
|
||||
|
||||
@@ -32,20 +32,24 @@ export async function addPeer(
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
name: peer.name,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
|
||||
aliases: peer.aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
name: peer.name,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
|
||||
aliases: peer.aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -70,13 +74,17 @@ export async function deletePeer(
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/remove",
|
||||
data: {
|
||||
publicKey,
|
||||
siteId: siteId
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/remove",
|
||||
data: {
|
||||
publicKey,
|
||||
siteId: siteId
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -109,19 +117,23 @@ export async function updatePeer(
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/update",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets,
|
||||
aliases: peer.aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/update",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets,
|
||||
aliases: peer.aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -151,17 +163,21 @@ export async function initPeerAddHandshake(
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/holepunch/site/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
exitNode: {
|
||||
publicKey: peer.exitNode.publicKey,
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
endpoint: peer.exitNode.endpoint
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/holepunch/site/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
exitNode: {
|
||||
publicKey: peer.exitNode.publicKey,
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
endpoint: peer.exitNode.endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
}).catch((error) => {
|
||||
},
|
||||
{ incrementConfigVersion: true }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
|
||||
80
server/routers/olm/sync.ts
Normal file
80
server/routers/olm/sync.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
|
||||
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { eq, inArray } from "drizzle-orm";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
|
||||
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
|
||||
const siteConfigurations = await buildSiteConfigurationForOlmClient(
|
||||
client,
|
||||
client.pubKey,
|
||||
false
|
||||
);
|
||||
|
||||
// Get all exit nodes from sites where the client has peers
|
||||
const clientSites = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.innerJoin(
|
||||
sites,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Extract unique exit node IDs
|
||||
const exitNodeIds = Array.from(
|
||||
new Set(
|
||||
clientSites
|
||||
.map(({ sites: site }) => site.exitNodeId)
|
||||
.filter((id): id is number => id !== null)
|
||||
)
|
||||
);
|
||||
|
||||
let exitNodesData: {
|
||||
publicKey: string;
|
||||
relayPort: number;
|
||||
endpoint: string;
|
||||
siteIds: number[];
|
||||
}[] = [];
|
||||
|
||||
if (exitNodeIds.length > 0) {
|
||||
const allExitNodes = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
|
||||
|
||||
// Map exitNodeId to siteIds
|
||||
const exitNodeIdToSiteIds: Record<number, number[]> = {};
|
||||
for (const { sites: site } of clientSites) {
|
||||
if (site.exitNodeId !== null) {
|
||||
if (!exitNodeIdToSiteIds[site.exitNodeId]) {
|
||||
exitNodeIdToSiteIds[site.exitNodeId] = [];
|
||||
}
|
||||
exitNodeIdToSiteIds[site.exitNodeId].push(site.siteId);
|
||||
}
|
||||
}
|
||||
|
||||
exitNodesData = allExitNodes.map((exitNode) => {
|
||||
return {
|
||||
publicKey: exitNode.publicKey,
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
endpoint: exitNode.endpoint,
|
||||
siteIds: exitNodeIdToSiteIds[exitNode.exitNodeId] ?? []
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
logger.debug("sendOlmSyncMessage: sending sync message")
|
||||
|
||||
await sendToClient(olm.olmId, {
|
||||
type: "olm/sync",
|
||||
data: {
|
||||
sites: siteConfigurations,
|
||||
exitNodes: exitNodesData
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending olm sync message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -5,7 +5,8 @@ import {
|
||||
handleDockerStatusMessage,
|
||||
handleDockerContainersMessage,
|
||||
handleNewtPingRequestMessage,
|
||||
handleApplyBlueprintMessage
|
||||
handleApplyBlueprintMessage,
|
||||
handleNewtPingMessage
|
||||
} from "../newt";
|
||||
import {
|
||||
handleOlmRegisterMessage,
|
||||
@@ -24,6 +25,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"olm/wg/relay": handleOlmRelayMessage,
|
||||
"olm/wg/unrelay": handleOlmUnRelayMessage,
|
||||
"olm/ping": handleOlmPingMessage,
|
||||
"newt/ping": handleNewtPingMessage,
|
||||
"newt/wg/register": handleNewtRegisterMessage,
|
||||
"newt/wg/get-config": handleGetConfigMessage,
|
||||
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
|
||||
|
||||
@@ -25,6 +25,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
|
||||
connectionId?: string;
|
||||
isFullyConnected?: boolean;
|
||||
pendingMessages?: Buffer[];
|
||||
configVersion?: number;
|
||||
}
|
||||
|
||||
export interface TokenPayload {
|
||||
@@ -36,6 +37,7 @@ export interface TokenPayload {
|
||||
export interface WSMessage {
|
||||
type: string;
|
||||
data: any;
|
||||
configVersion?: number;
|
||||
}
|
||||
|
||||
export interface HandlerResponse {
|
||||
@@ -43,6 +45,7 @@ export interface HandlerResponse {
|
||||
broadcast?: boolean;
|
||||
excludeSender?: boolean;
|
||||
targetClientId?: string;
|
||||
options?: SendMessageOptions;
|
||||
}
|
||||
|
||||
export interface HandlerContext {
|
||||
@@ -50,10 +53,15 @@ export interface HandlerContext {
|
||||
senderWs: WebSocket;
|
||||
client: Newt | Olm | RemoteExitNode | undefined;
|
||||
clientType: ClientType;
|
||||
sendToClient: (clientId: string, message: WSMessage) => Promise<boolean>;
|
||||
sendToClient: (
|
||||
clientId: string,
|
||||
message: WSMessage,
|
||||
options?: SendMessageOptions
|
||||
) => Promise<boolean>;
|
||||
broadcastToAllExcept: (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
excludeClientId?: string,
|
||||
options?: SendMessageOptions
|
||||
) => Promise<void>;
|
||||
connectedClients: Map<string, WebSocket[]>;
|
||||
}
|
||||
@@ -62,6 +70,11 @@ export type MessageHandler = (
|
||||
context: HandlerContext
|
||||
) => Promise<HandlerResponse | void>;
|
||||
|
||||
// Options for sending messages with config version tracking
|
||||
export interface SendMessageOptions {
|
||||
incrementConfigVersion?: boolean;
|
||||
}
|
||||
|
||||
// Redis message type for cross-node communication
|
||||
export interface RedisMessage {
|
||||
type: "direct" | "broadcast";
|
||||
@@ -69,4 +82,5 @@ export interface RedisMessage {
|
||||
excludeClientId?: string;
|
||||
message: WSMessage;
|
||||
fromNodeId: string;
|
||||
options?: SendMessageOptions;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ import {
|
||||
TokenPayload,
|
||||
WebSocketRequest,
|
||||
WSMessage,
|
||||
AuthenticatedWebSocket
|
||||
AuthenticatedWebSocket,
|
||||
SendMessageOptions
|
||||
} from "./types";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
|
||||
@@ -34,6 +35,8 @@ const NODE_ID = uuidv4();
|
||||
|
||||
// Client tracking map (local to this node)
|
||||
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
|
||||
// Config version tracking map (clientId -> version)
|
||||
const clientConfigVersions: Map<string, number> = new Map();
|
||||
// Helper to get map key
|
||||
const getClientMapKey = (clientId: string) => clientId;
|
||||
|
||||
@@ -53,6 +56,13 @@ const addClient = async (
|
||||
existingClients.push(ws);
|
||||
connectedClients.set(mapKey, existingClients);
|
||||
|
||||
// Initialize config version to 0 if not already set, otherwise use existing
|
||||
if (!clientConfigVersions.has(clientId)) {
|
||||
clientConfigVersions.set(clientId, 0);
|
||||
}
|
||||
// Set the current config version on the websocket
|
||||
ws.configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
|
||||
logger.info(
|
||||
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
|
||||
);
|
||||
@@ -84,14 +94,28 @@ const removeClient = async (
|
||||
// Local message sending (within this node)
|
||||
const sendToClientLocal = async (
|
||||
clientId: string,
|
||||
message: WSMessage
|
||||
message: WSMessage,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<boolean> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
if (!clients || clients.length === 0) {
|
||||
return false;
|
||||
}
|
||||
const messageString = JSON.stringify(message);
|
||||
|
||||
// Include config version in message
|
||||
const configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
// Update version on all client connections
|
||||
clients.forEach((client) => {
|
||||
client.configVersion = configVersion;
|
||||
});
|
||||
|
||||
const messageWithVersion = {
|
||||
...message,
|
||||
configVersion
|
||||
};
|
||||
|
||||
const messageString = JSON.stringify(messageWithVersion);
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(messageString);
|
||||
@@ -102,14 +126,30 @@ const sendToClientLocal = async (
|
||||
|
||||
const broadcastToAllExceptLocal = async (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
excludeClientId?: string,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<void> => {
|
||||
connectedClients.forEach((clients, mapKey) => {
|
||||
const [type, id] = mapKey.split(":");
|
||||
if (!(excludeClientId && id === excludeClientId)) {
|
||||
const clientId = mapKey; // mapKey is the clientId
|
||||
if (!(excludeClientId && clientId === excludeClientId)) {
|
||||
// Handle config version per client
|
||||
if (options.incrementConfigVersion) {
|
||||
const currentVersion = clientConfigVersions.get(clientId) || 0;
|
||||
const newVersion = currentVersion + 1;
|
||||
clientConfigVersions.set(clientId, newVersion);
|
||||
clients.forEach((client) => {
|
||||
client.configVersion = newVersion;
|
||||
});
|
||||
}
|
||||
// Include config version in message for this client
|
||||
const configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
const messageWithVersion = {
|
||||
...message,
|
||||
configVersion
|
||||
};
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(JSON.stringify(message));
|
||||
client.send(JSON.stringify(messageWithVersion));
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -119,10 +159,18 @@ const broadcastToAllExceptLocal = async (
|
||||
// Cross-node message sending
|
||||
const sendToClient = async (
|
||||
clientId: string,
|
||||
message: WSMessage
|
||||
message: WSMessage,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<boolean> => {
|
||||
// Increment config version if requested
|
||||
if (options.incrementConfigVersion) {
|
||||
const currentVersion = clientConfigVersions.get(clientId) || 0;
|
||||
const newVersion = currentVersion + 1;
|
||||
clientConfigVersions.set(clientId, newVersion);
|
||||
}
|
||||
|
||||
// Try to send locally first
|
||||
const localSent = await sendToClientLocal(clientId, message);
|
||||
const localSent = await sendToClientLocal(clientId, message, options);
|
||||
|
||||
logger.debug(
|
||||
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
|
||||
@@ -133,10 +181,11 @@ const sendToClient = async (
|
||||
|
||||
const broadcastToAllExcept = async (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
excludeClientId?: string,
|
||||
options: SendMessageOptions = {}
|
||||
): Promise<void> => {
|
||||
// Broadcast locally
|
||||
await broadcastToAllExceptLocal(message, excludeClientId);
|
||||
await broadcastToAllExceptLocal(message, excludeClientId, options);
|
||||
};
|
||||
|
||||
// Check if a client has active connections across all nodes
|
||||
@@ -146,6 +195,13 @@ const hasActiveConnections = async (clientId: string): Promise<boolean> => {
|
||||
return !!(clients && clients.length > 0);
|
||||
};
|
||||
|
||||
// Get the current config version for a client
|
||||
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
|
||||
const version = clientConfigVersions.get(clientId);
|
||||
logger.debug(`getClientConfigVersion called for clientId: ${clientId}, returning: ${version} (type: ${typeof version})`);
|
||||
return version;
|
||||
};
|
||||
|
||||
// Get all active nodes for a client
|
||||
const getActiveNodes = async (
|
||||
clientType: ClientType,
|
||||
@@ -259,15 +315,21 @@ const setupConnection = async (
|
||||
if (response.broadcast) {
|
||||
await broadcastToAllExcept(
|
||||
response.message,
|
||||
response.excludeSender ? clientId : undefined
|
||||
response.excludeSender ? clientId : undefined,
|
||||
response.options
|
||||
);
|
||||
} else if (response.targetClientId) {
|
||||
await sendToClient(
|
||||
response.targetClientId,
|
||||
response.message
|
||||
response.message,
|
||||
response.options
|
||||
);
|
||||
} else {
|
||||
ws.send(JSON.stringify(response.message));
|
||||
await sendToClient(
|
||||
clientId,
|
||||
response.message,
|
||||
response.options
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -434,5 +496,6 @@ export {
|
||||
getActiveNodes,
|
||||
disconnectClient,
|
||||
NODE_ID,
|
||||
cleanup
|
||||
cleanup,
|
||||
getClientConfigVersion
|
||||
};
|
||||
|
||||
@@ -48,6 +48,7 @@ import {
|
||||
TooltipTrigger
|
||||
} from "./ui/tooltip";
|
||||
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
|
||||
import type { QueryRequestAnalyticsResponse } from "@server/routers/auditLogs";
|
||||
|
||||
export type AnalyticsContentProps = {
|
||||
orgId: string;
|
||||
@@ -276,13 +277,32 @@ export function LogAnalyticsData(props: AnalyticsContentProps) {
|
||||
</CardHeader>
|
||||
</Card>
|
||||
|
||||
<Card className="w-full h-full flex flex-col gap-8">
|
||||
<Card className="w-full h-full flex flex-col gap-8 relative">
|
||||
{isLoadingAnalytics && (
|
||||
<div className="absolute z-20 left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2 border border-border rounded-md bg-muted">
|
||||
<div className="flex items-center gap-2 p-6">
|
||||
<LoaderIcon className="size-4 animate-spin" />
|
||||
{t("loadingAnalytics")}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<CardHeader>
|
||||
<h3 className="font-semibold">{t("requestsByDay")}</h3>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<CardContent className="relative">
|
||||
{isLoadingAnalytics && (
|
||||
<div className="backdrop-blur-[2px] z-10 absolute inset-0"></div>
|
||||
)}
|
||||
<RequestChart
|
||||
data={stats?.requestsPerDay ?? []}
|
||||
className={cn(
|
||||
isLoadingAnalytics &&
|
||||
"opacity-50 pointer-events-none"
|
||||
)}
|
||||
data={
|
||||
stats?.requestsPerDay ??
|
||||
generateSampleDailyRequests()
|
||||
}
|
||||
isLoading={isLoadingAnalytics}
|
||||
/>
|
||||
</CardContent>
|
||||
@@ -323,6 +343,28 @@ export function LogAnalyticsData(props: AnalyticsContentProps) {
|
||||
);
|
||||
}
|
||||
|
||||
function generateSampleDailyRequests(): QueryRequestAnalyticsResponse["requestsPerDay"] {
|
||||
const today = new Date();
|
||||
|
||||
// generate sample data for the last 7 days
|
||||
const requestsPerDay = Array.from({ length: 7 }, (_, i) => {
|
||||
const date = new Date(today);
|
||||
date.setDate(date.getDate() - (6 - i));
|
||||
// generate a random number of requests between 1 and 100
|
||||
const totalCount = Math.floor(Math.random() * 100) + 1;
|
||||
// generate a random number of requests between 1 and totalCount
|
||||
const blockedCount = Math.floor(Math.random() * (totalCount + 1));
|
||||
return {
|
||||
day: date.toISOString().split("T")[0],
|
||||
allowedCount: totalCount - blockedCount,
|
||||
blockedCount,
|
||||
totalCount
|
||||
};
|
||||
});
|
||||
|
||||
return requestsPerDay;
|
||||
}
|
||||
|
||||
type RequestChartProps = {
|
||||
data: {
|
||||
day: string;
|
||||
@@ -331,6 +373,7 @@ type RequestChartProps = {
|
||||
totalCount: number;
|
||||
}[];
|
||||
isLoading: boolean;
|
||||
className?: string;
|
||||
};
|
||||
|
||||
function RequestChart(props: RequestChartProps) {
|
||||
@@ -359,7 +402,7 @@ function RequestChart(props: RequestChartProps) {
|
||||
return (
|
||||
<ChartContainer
|
||||
config={chartConfig}
|
||||
className="min-h-[200px] w-full h-80"
|
||||
className={cn("min-h-50 w-full h-80", props.className)}
|
||||
>
|
||||
<LineChart accessibilityLayer data={props.data}>
|
||||
<ChartLegend content={<ChartLegendContent />} />
|
||||
@@ -467,7 +510,7 @@ function TopCountriesList(props: TopCountriesListProps) {
|
||||
</div>
|
||||
)}
|
||||
{/* `aspect-475/335` is the same aspect ratio as the world map component */}
|
||||
<ol className="w-full overflow-auto grid gap-1 aspect-475/335">
|
||||
<ol className="w-full overflow-auto gap-1 aspect-475/335 flex flex-col">
|
||||
{props.countries.length === 0 && (
|
||||
<div className="flex items-center justify-center size-full text-muted-foreground gap-1">
|
||||
{props.isLoading ? (
|
||||
@@ -485,7 +528,7 @@ function TopCountriesList(props: TopCountriesListProps) {
|
||||
return (
|
||||
<li
|
||||
key={country.code}
|
||||
className="grid grid-cols-7 rounded-xs hover:bg-muted relative items-center text-sm"
|
||||
className="w-full grid grid-cols-7 rounded-xs hover:bg-muted relative items-center text-sm"
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
|
||||
Reference in New Issue
Block a user