From bdb564823d3dc984a7197e86c1f5d210acc5251d Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 6 Nov 2025 21:19:21 -0800 Subject: [PATCH] Require valid user token --- server/private/routers/ws/ws.ts | 41 +++++++++++++++- server/routers/olm/handleOlmPingMessage.ts | 54 ++++++++++++++++++++-- server/routers/ws/ws.ts | 40 +++++++++++++++- 3 files changed, 126 insertions(+), 9 deletions(-) diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts index 0122126f..9e307a64 100644 --- a/server/private/routers/ws/ws.ts +++ b/server/private/routers/ws/ws.ts @@ -38,6 +38,7 @@ import { rateLimitService } from "#private/lib/rateLimit"; import { messageHandlers } from "@server/routers/ws/messageHandlers"; import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers"; import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws"; +import { validateSessionToken } from "@server/auth/sessions/app"; // Merge public and private message handlers Object.assign(messageHandlers, privateMessageHandlers); @@ -478,7 +479,8 @@ const getActiveNodes = async ( // Token verification middleware const verifyToken = async ( token: string, - clientType: ClientType + clientType: ClientType, + userToken: string ): Promise => { try { if (clientType === "newt") { @@ -506,6 +508,17 @@ const verifyToken = async ( if (!existingOlm || !existingOlm[0]) { return null; } + + if (olm.userId) { // this is a user device and we need to check the user token + const { session: userSession, user } = await validateSessionToken(userToken); + if (!userSession || !user) { + return null; + } + if (user.userId !== olm.userId) { + return null; + } + } + return { client: existingOlm[0], session, clientType }; } else if (clientType === "remoteExitNode") { const { session, remoteExitNode } = @@ -652,6 +665,7 @@ const handleWSUpgrade = (server: HttpServer): void => { url.searchParams.get("token") || request.headers["sec-websocket-protocol"] || ""; + const userToken = url.searchParams.get('userToken') || ''; let clientType = url.searchParams.get( "clientType" ) as ClientType; @@ -673,7 +687,7 @@ const handleWSUpgrade = (server: HttpServer): void => { return; } - const tokenPayload = await verifyToken(token, clientType); + const tokenPayload = await verifyToken(token, clientType, userToken); if (!tokenPayload) { logger.debug( "Unauthorized connection attempt: invalid token..." @@ -792,6 +806,28 @@ if (redisManager.isRedisEnabled()) { ); } +// Disconnect a specific client and force them to reconnect +const disconnectClient = async (clientId: string): Promise => { + const mapKey = getClientMapKey(clientId); + const clients = connectedClients.get(mapKey); + + if (!clients || clients.length === 0) { + logger.debug(`No connections found for client ID: ${clientId}`); + return false; + } + + logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`); + + // Close all connections for this client + clients.forEach((client) => { + if (client.readyState === WebSocket.OPEN) { + client.close(1000, "Disconnected by server"); + } + }); + + return true; +}; + // Cleanup function for graceful shutdown const cleanup = async (): Promise => { try { @@ -829,6 +865,7 @@ export { connectedClients, hasActiveConnections, getActiveNodes, + disconnectClient, NODE_ID, cleanup }; diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index 6f00640d..a80030f9 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -1,8 +1,9 @@ import { db } from "@server/db"; -import { MessageHandler } from "@server/routers/ws"; +import { disconnectClient, MessageHandler } from "#dynamic/routers/ws"; import { clients, Olm } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; +import { validateSessionToken } from "@server/auth/sessions/app"; // Track if the offline checker interval is running let offlineCheckerInterval: NodeJS.Timeout | null = null; @@ -20,10 +21,14 @@ export const startOlmOfflineChecker = (): void => { offlineCheckerInterval = setInterval(async () => { try { - const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000); + const twoMinutesAgo = Math.floor( + (Date.now() - OFFLINE_THRESHOLD_MS) / 1000 + ); + + // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING // Find clients that haven't pinged in the last 2 minutes and mark them as offline - await db + const offlineClients = await db .update(clients) .set({ online: false }) .where( @@ -34,8 +39,31 @@ export const startOlmOfflineChecker = (): void => { isNull(clients.lastPing) ) ) + ) + .returning(); + + for (const offlineClient of offlineClients) { + logger.info( + `Kicking offline olm client ${offlineClient.clientId} due to inactivity` ); + if (!offlineClient.olmId) { + logger.warn( + `Offline client ${offlineClient.clientId} has no olmId, cannot disconnect` + ); + continue; + } + + // Send a disconnect message to the client if connected + try { + await disconnectClient(offlineClient.olmId); + } catch (error) { + logger.error( + `Error sending disconnect to offline olm ${offlineClient.clientId}`, + { error } + ); + } + } } catch (error) { logger.error("Error in offline checker interval", { error }); } @@ -62,11 +90,27 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { const { message, client: c, sendToClient } = context; const olm = c as Olm; + const { userToken } = message.data; + if (!olm) { logger.warn("Olm not found"); return; } + if (olm.userId) { + // we need to check a user token to make sure its still valid + const { session: userSession, user } = + await validateSessionToken(userToken); + if (!userSession || !user) { + logger.warn("Invalid user session for olm ping"); + return; // by returning here we just ignore the ping and the setInterval will force it to disconnect + } + if (user.userId !== olm.userId) { + logger.warn("User ID mismatch for olm ping"); + return; + } + } + if (!olm.clientId) { logger.warn("Olm has no client ID!"); return; @@ -78,7 +122,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { .update(clients) .set({ lastPing: Math.floor(Date.now() / 1000), - online: true, + online: true }) .where(eq(clients.clientId, olm.clientId)); } catch (error) { @@ -89,7 +133,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { message: { type: "pong", data: { - timestamp: new Date().toISOString(), + timestamp: new Date().toISOString() } }, broadcast: false, diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index 9bba41dc..5ab2c85f 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -11,6 +11,7 @@ import { messageHandlers } from "./messageHandlers"; import logger from "@server/logger"; import { v4 as uuidv4 } from "uuid"; import { ClientType, TokenPayload, WebSocketRequest, WSMessage, AuthenticatedWebSocket } from "./types"; +import { validateSessionToken } from "@server/auth/sessions/app"; // Subset of TokenPayload for public ws.ts (newt and olm only) interface PublicTokenPayload { @@ -117,7 +118,7 @@ const getActiveNodes = async (clientType: ClientType, clientId: string): Promise }; // Token verification middleware -const verifyToken = async (token: string, clientType: ClientType): Promise => { +const verifyToken = async (token: string, clientType: ClientType, userToken: string): Promise => { try { if (clientType === 'newt') { @@ -145,6 +146,17 @@ try { if (!existingOlm || !existingOlm[0]) { return null; } + + if (olm.userId) { // this is a user device and we need to check the user token + const { session: userSession, user } = await validateSessionToken(userToken); + if (!userSession || !user) { + return null; + } + if (user.userId !== olm.userId) { + return null; + } + } + return { client: existingOlm[0], session, clientType }; } @@ -239,6 +251,7 @@ const handleWSUpgrade = (server: HttpServer): void => { try { const url = new URL(request.url || '', `http://${request.headers.host}`); const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || ''; + const userToken = url.searchParams.get('userToken') || ''; let clientType = url.searchParams.get('clientType') as ClientType; if (!clientType) { @@ -252,7 +265,7 @@ const handleWSUpgrade = (server: HttpServer): void => { return; } - const tokenPayload = await verifyToken(token, clientType); + const tokenPayload = await verifyToken(token, clientType, userToken); if (!tokenPayload) { logger.warn("Unauthorized connection attempt: invalid token..."); socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n"); @@ -271,6 +284,28 @@ const handleWSUpgrade = (server: HttpServer): void => { }); }; +// Disconnect a specific client and force them to reconnect +const disconnectClient = async (clientId: string): Promise => { + const mapKey = getClientMapKey(clientId); + const clients = connectedClients.get(mapKey); + + if (!clients || clients.length === 0) { + logger.debug(`No connections found for client ID: ${clientId}`); + return false; + } + + logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`); + + // Close all connections for this client + clients.forEach((client) => { + if (client.readyState === WebSocket.OPEN) { + client.close(1000, "Disconnected by server"); + } + }); + + return true; +}; + // Cleanup function for graceful shutdown const cleanup = async (): Promise => { try { @@ -297,6 +332,7 @@ export { connectedClients, hasActiveConnections, getActiveNodes, + disconnectClient, NODE_ID, cleanup };