Mirror of https://github.com/fosrl/pangolin.git
Synced 2026-01-28 22:00:51 +00:00
493 lines · 15 KiB · TypeScript
import { Router, Request, Response } from "express";
|
|
import { Server as HttpServer } from "http";
|
|
import { WebSocket, WebSocketServer } from "ws";
|
|
import { Socket } from "net";
|
|
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db";
|
|
import { eq } from "drizzle-orm";
|
|
import { db } from "@server/db";
|
|
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
|
|
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
|
|
import { messageHandlers } from "./messageHandlers";
|
|
import logger from "@server/logger";
|
|
import { v4 as uuidv4 } from "uuid";
|
|
import {
|
|
ClientType,
|
|
TokenPayload,
|
|
WebSocketRequest,
|
|
WSMessage,
|
|
AuthenticatedWebSocket,
|
|
SendMessageOptions
|
|
} from "./types";
|
|
import { validateSessionToken } from "@server/auth/sessions/app";
|
|
|
|
/**
 * Authenticated identity produced by `verifyToken` for this public WebSocket
 * endpoint: a narrowed counterpart of the shared `TokenPayload` that only
 * covers the two client kinds accepted here (newt and olm).
 */
type PublicTokenPayload = {
    /** The stored database row of the authenticated newt or olm. */
    client: Newt | Olm;
    /** The session returned by the matching session-token validator. */
    session: NewtSession | OlmSession;
    /** Which kind of client authenticated. */
    clientType: "newt" | "olm";
};
|
|
|
|
// Express router for this module's plain HTTP routes.
const router: Router = Router();

// WebSocket server in `noServer` mode: it never listens on its own and only
// receives sockets handed to it from the HTTP server's "upgrade" event.
const wss: WebSocketServer = new WebSocketServer({ noServer: true });
|
|
|
|
/**
 * Random identifier for this server instance, generated once when the module
 * loads. getActiveNodes reports it for clients connected to this instance.
 */
const NODE_ID = uuidv4();
|
|
|
|
/**
 * Sockets currently tracked on this node, keyed by client id. A single client
 * may hold several concurrent sockets, hence the array value.
 */
const connectedClients = new Map<string, AuthenticatedWebSocket[]>();

/** Last config version handed out per client id (absent means 0). */
const clientConfigVersions = new Map<string, number>();

/**
 * Map key for a client in `connectedClients`. Keys are currently the raw
 * client id; routing every lookup through here keeps the format in one place.
 */
function getClientMapKey(clientId: string): string {
    return clientId;
}
|
|
|
|
/**
 * Register an authenticated socket under its client id on this node.
 *
 * Stamps the socket with a fresh random `connectionId` and appends it to the
 * client's list of sockets (a client may be connected more than once).
 *
 * @param clientType - Kind of client; used only in the log line.
 * @param clientId - newtId or olmId the socket belongs to.
 * @param ws - The socket to track.
 */
const addClient = async (
    clientType: ClientType,
    clientId: string,
    ws: AuthenticatedWebSocket
): Promise<void> => {
    const newConnectionId = uuidv4();
    ws.connectionId = newConnectionId;

    const key = getClientMapKey(clientId);
    const sockets = connectedClients.get(key) ?? [];
    sockets.push(ws);
    connectedClients.set(key, sockets);

    const label = clientType.toUpperCase();
    logger.info(
        `Client added to tracking - ${label} ID: ${clientId}, Connection ID: ${newConnectionId}, Total connections: ${sockets.length}`
    );
};
|
|
|
|
/**
 * Stop tracking one socket of a client on this node.
 *
 * Drops the client's map entry entirely once its last socket is gone.
 *
 * @param clientType - Kind of client; used only in log lines.
 * @param clientId - newtId or olmId the socket belongs to.
 * @param ws - The socket to forget.
 */
const removeClient = async (
    clientType: ClientType,
    clientId: string,
    ws: AuthenticatedWebSocket
): Promise<void> => {
    const key = getClientMapKey(clientId);
    const remaining = (connectedClients.get(key) ?? []).filter(
        (socket) => socket !== ws
    );
    const label = clientType.toUpperCase();

    if (remaining.length > 0) {
        connectedClients.set(key, remaining);
        logger.info(
            `Connection removed - ${label} ID: ${clientId}, Remaining connections: ${remaining.length}`
        );
        return;
    }

    connectedClients.delete(key);
    logger.info(`All connections removed for ${label} ID: ${clientId}`);
};
|
|
|
|
/**
 * Deliver `message` to every OPEN socket of `clientId` on this node.
 *
 * The payload is `message` plus a `configVersion` field holding the client's
 * current config version. With `options.incrementConfigVersion` the version
 * is bumped first and copied onto all of the client's sockets.
 *
 * @param clientId - newtId or olmId to address.
 * @param message - Payload to send.
 * @param options - Delivery options; only `incrementConfigVersion` is read.
 * @returns false if no sockets are tracked for the client here; true
 *   otherwise, even when none of them was OPEN and nothing was written.
 */
const sendToClientLocal = async (
    clientId: string,
    message: WSMessage,
    options: SendMessageOptions = {}
): Promise<boolean> => {
    const sockets = connectedClients.get(getClientMapKey(clientId));
    if (sockets === undefined || sockets.length === 0) {
        return false;
    }

    let version = clientConfigVersions.get(clientId) || 0;
    if (options.incrementConfigVersion) {
        version += 1;
        clientConfigVersions.set(clientId, version);
        for (const socket of sockets) {
            socket.configVersion = version;
        }
    }

    const payload = JSON.stringify({ ...message, configVersion: version });
    for (const socket of sockets) {
        if (socket.readyState === WebSocket.OPEN) {
            socket.send(payload);
        }
    }
    return true;
};
|
|
|
|
/**
 * Send `message` to every client connected to this node, optionally skipping one.
 *
 * Each client receives its own copy stamped with that client's current
 * `configVersion`. When `options.incrementConfigVersion` is set, every
 * non-excluded client's version is bumped first and copied onto all of its
 * sockets, even if none of those sockets is currently OPEN.
 *
 * @param message - Payload to broadcast.
 * @param excludeClientId - Client id to skip (e.g. the sender); falsy means no exclusion.
 * @param options - Delivery options; only `incrementConfigVersion` is read.
 */
const broadcastToAllExceptLocal = async (
    message: WSMessage,
    excludeClientId?: string,
    options: SendMessageOptions = {}
): Promise<void> => {
    // Map keys are plain client ids (getClientMapKey is the identity).
    connectedClients.forEach((clients, clientId) => {
        if (excludeClientId && clientId === excludeClientId) {
            return;
        }

        if (options.incrementConfigVersion) {
            const newVersion = (clientConfigVersions.get(clientId) || 0) + 1;
            clientConfigVersions.set(clientId, newVersion);
            clients.forEach((client) => {
                client.configVersion = newVersion;
            });
        }

        // Every socket of one client gets the same payload, so serialize it
        // once per client rather than once per socket.
        const messageString = JSON.stringify({
            ...message,
            configVersion: clientConfigVersions.get(clientId) || 0
        });
        clients.forEach((client) => {
            if (client.readyState === WebSocket.OPEN) {
                client.send(messageString);
            }
        });
    });
};
|
|
|
|
/**
 * Public entry point for sending a message to one client.
 *
 * In this build delivery is local only: it delegates to sendToClientLocal,
 * which writes to the client's OPEN sockets on this node.
 *
 * @param clientId - newtId or olmId to address.
 * @param message - Payload to send.
 * @param options - Passed through to sendToClientLocal.
 * @returns true when the client has tracked sockets on this node, false otherwise.
 */
const sendToClient = async (
    clientId: string,
    message: WSMessage,
    options: SendMessageOptions = {}
): Promise<boolean> => {
    const localSent = await sendToClientLocal(clientId, message, options);

    // Previously this always logged "sent", which hid undeliverable messages.
    if (localSent) {
        logger.debug(
            `sendToClient: Message type ${message.type} sent to clientId ${clientId}`
        );
    } else {
        logger.debug(
            `sendToClient: No local connection for clientId ${clientId}; message type ${message.type} not delivered`
        );
    }

    return localSent;
};
|
|
|
|
/**
 * Public entry point for broadcasting to every connected client except
 * `excludeClientId`. In this build only clients on the current node are
 * reached; the work is delegated to broadcastToAllExceptLocal.
 *
 * @param message - Payload to broadcast.
 * @param excludeClientId - Client id to skip, if any.
 * @param options - Passed through unchanged.
 */
const broadcastToAllExcept = async (
    message: WSMessage,
    excludeClientId?: string,
    options: SendMessageOptions = {}
): Promise<void> => {
    return broadcastToAllExceptLocal(message, excludeClientId, options);
};
|
|
|
|
/**
 * Whether `clientId` has at least one tracked socket.
 *
 * Only this node's `connectedClients` map is consulted, and a tracked socket
 * is not necessarily OPEN.
 */
const hasActiveConnections = async (clientId: string): Promise<boolean> => {
    const sockets = connectedClients.get(getClientMapKey(clientId));
    return sockets !== undefined && sockets.length > 0;
};
|
|
|
|
/** Current config version recorded for `clientId` on this node; 0 if none. */
const getClientConfigVersion = async (clientId: string): Promise<number> =>
    clientConfigVersions.get(clientId) || 0;
|
|
|
|
/**
 * List the node ids that hold a tracked socket for `clientId`.
 *
 * Only the local node is known here, so the result is `[NODE_ID]` or `[]`.
 * `clientType` is accepted for interface compatibility and is not read.
 */
const getActiveNodes = async (
    clientType: ClientType,
    clientId: string
): Promise<string[]> => {
    const socketCount =
        connectedClients.get(getClientMapKey(clientId))?.length ?? 0;
    return socketCount > 0 ? [NODE_ID] : [];
};
|
|
|
|
/**
 * Authenticate a newt or olm connection attempt from its session token.
 *
 * - newt: the session token must validate and the newt must still exist in
 *   the `newts` table.
 * - olm: the same checks against `olms`; additionally, when the olm is bound
 *   to a user (`olm.userId` is set), `userToken` must be a valid app session
 *   belonging to that same user.
 *
 * @param token - Newt or olm session token.
 * @param clientType - Selects the validator; any other value yields null.
 * @param userToken - App session token, consulted only for user-bound olms.
 * @returns The stored client row with its session and type, or null when any
 *   check fails. Thrown errors are logged and also yield null.
 */
const verifyToken = async (
    token: string,
    clientType: ClientType,
    userToken: string
): Promise<PublicTokenPayload | null> => {
    try {
        switch (clientType) {
            case "newt": {
                const { session, newt } = await validateNewtSessionToken(token);
                if (!session || !newt) {
                    return null;
                }
                const rows = await db
                    .select()
                    .from(newts)
                    .where(eq(newts.newtId, newt.newtId));
                const storedNewt = rows?.[0];
                if (!storedNewt) {
                    return null;
                }
                return { client: storedNewt, session, clientType };
            }
            case "olm": {
                const { session, olm } = await validateOlmSessionToken(token);
                if (!session || !olm) {
                    return null;
                }
                const rows = await db
                    .select()
                    .from(olms)
                    .where(eq(olms.olmId, olm.olmId));
                const storedOlm = rows?.[0];
                if (!storedOlm) {
                    return null;
                }
                // User-bound devices must also present that user's app session.
                if (olm.userId) {
                    const { session: userSession, user } =
                        await validateSessionToken(userToken);
                    if (!userSession || !user || user.userId !== olm.userId) {
                        return null;
                    }
                }
                return { client: storedOlm, session, clientType };
            }
            default:
                return null;
        }
    } catch (error) {
        logger.error("Token verification failed:", error);
        return null;
    }
};
|
|
|
|
/**
 * Attach an authenticated client to its socket and install event listeners.
 *
 * Registers the socket in `connectedClients` under the client's newtId/olmId,
 * then:
 * - "message": the payload must be JSON with a string `type` that names an
 *   own entry of `messageHandlers`. A handler's optional response is broadcast
 *   (optionally excluding this client), sent to `targetClientId`, or sent back
 *   to this client. Any failure is logged and, if the socket is still open,
 *   reported to the sender as a `type: "error"` frame.
 * - "close": stops tracking the socket.
 * - "error": logs the error.
 *
 * @param ws - Socket produced by the upgrade.
 * @param client - Authenticated newt or olm row; a falsy value terminates the socket.
 * @param clientType - Which kind of client `client` is.
 */
const setupConnection = async (
    ws: AuthenticatedWebSocket,
    client: Newt | Olm,
    clientType: "newt" | "olm"
): Promise<void> => {
    logger.info("Establishing websocket connection");
    if (!client) {
        logger.error("Connection attempt without client");
        return ws.terminate();
    }

    ws.client = client;
    ws.clientType = clientType;

    // Track the socket under the client's own id (newtId or olmId).
    const clientId =
        clientType === "newt" ? (client as Newt).newtId : (client as Olm).olmId;
    await addClient(clientType, clientId, ws);

    ws.on("message", async (data) => {
        try {
            const message: WSMessage = JSON.parse(data.toString());

            // The payload is untrusted: reject non-objects such as `null` as
            // well as a missing or non-string `type`.
            if (!message || !message.type || typeof message.type !== "string") {
                throw new Error(
                    "Invalid message format: missing or invalid type"
                );
            }

            // Only own properties count as handlers, so a remote `type` like
            // "constructor" or "toString" cannot resolve to an inherited
            // Object.prototype member and be invoked.
            const handler = Object.prototype.hasOwnProperty.call(
                messageHandlers,
                message.type
            )
                ? messageHandlers[message.type]
                : undefined;
            if (!handler) {
                throw new Error(`Unsupported message type: ${message.type}`);
            }

            const response = await handler({
                message,
                senderWs: ws,
                client: ws.client,
                clientType: ws.clientType!,
                sendToClient,
                broadcastToAllExcept,
                connectedClients
            });

            if (response) {
                if (response.broadcast) {
                    await broadcastToAllExcept(
                        response.message,
                        response.excludeSender ? clientId : undefined,
                        response.options
                    );
                } else if (response.targetClientId) {
                    await sendToClient(
                        response.targetClientId,
                        response.message,
                        response.options
                    );
                } else {
                    await sendToClient(
                        clientId,
                        response.message,
                        response.options
                    );
                }
            }
        } catch (error) {
            logger.error("Message handling error:", error);
            // Skip the error reply once the socket is no longer open,
            // matching how every other send path here guards on OPEN.
            if (ws.readyState === WebSocket.OPEN) {
                ws.send(
                    JSON.stringify({
                        type: "error",
                        data: {
                            message:
                                error instanceof Error
                                    ? error.message
                                    : "Unknown error occurred",
                            originalMessage: data.toString()
                        }
                    })
                );
            }
        }
    });

    ws.on("close", () => {
        // removeClient only updates in-memory state; `void` marks the
        // un-awaited promise as intentional.
        void removeClient(clientType, clientId, ws);
        logger.info(
            `Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`
        );
    });

    ws.on("error", (error: Error) => {
        logger.error(
            `WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
            error
        );
    });

    logger.info(
        `WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`
    );
};
|
|
|
|
// A plain (non-upgrade) HTTP GET on /ws just confirms the endpoint exists.
router.get("/ws", (_req: Request, res: Response) => {
    const body = "WebSocket endpoint";
    res.status(200).send(body);
});
|
|
|
|
/**
 * Answer a not-yet-upgraded socket with a bare HTTP status line and close it.
 *
 * @param socket - Raw socket from the HTTP server's "upgrade" event.
 * @param status - Status code and reason phrase, e.g. "401 Unauthorized".
 */
const rejectUpgrade = (socket: Socket, status: string): void => {
    socket.write(`HTTP/1.1 ${status}\r\n\r\n`);
    socket.destroy();
};

/**
 * Route HTTP "upgrade" requests on `server` to the WebSocket server once the
 * caller has been authenticated.
 *
 * Credentials are read from the request:
 * - `token` query param: newt/olm session token; if absent, the
 *   `Sec-WebSocket-Protocol` header value is used instead.
 * - `userToken` query param: app session token (checked by verifyToken for
 *   user-bound olms).
 * - `clientType` query param: "newt" or "olm"; missing/empty means "newt".
 *
 * A missing token, unsupported client type, or failed verification is
 * answered with 401; unexpected errors with 500.
 *
 * @param server - HTTP server whose "upgrade" events are handled.
 */
const handleWSUpgrade = (server: HttpServer): void => {
    server.on(
        "upgrade",
        async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
            try {
                const url = new URL(
                    request.url || "",
                    `http://${request.headers.host}`
                );
                const token =
                    url.searchParams.get("token") ||
                    request.headers["sec-websocket-protocol"] ||
                    "";
                const userToken = url.searchParams.get("userToken") || "";
                // Missing or empty clientType defaults to "newt"; any other
                // value is validated against the supported types below.
                const clientType = (url.searchParams.get("clientType") ||
                    "newt") as ClientType;

                if (!token || !["newt", "olm"].includes(clientType)) {
                    logger.warn(
                        "Unauthorized connection attempt: invalid token or client type..."
                    );
                    rejectUpgrade(socket, "401 Unauthorized");
                    return;
                }

                const tokenPayload = await verifyToken(
                    token,
                    clientType,
                    userToken
                );
                if (!tokenPayload) {
                    logger.warn(
                        "Unauthorized connection attempt: invalid token..."
                    );
                    rejectUpgrade(socket, "401 Unauthorized");
                    return;
                }

                wss.handleUpgrade(
                    request,
                    socket,
                    head,
                    (ws: AuthenticatedWebSocket) => {
                        // Nothing awaits setupConnection here; catch its
                        // failures so they are logged and the socket dropped
                        // instead of surfacing as an unhandled rejection.
                        setupConnection(
                            ws,
                            tokenPayload.client,
                            tokenPayload.clientType
                        ).catch((error) => {
                            logger.error(
                                "WebSocket connection setup error:",
                                error
                            );
                            ws.terminate();
                        });
                    }
                );
            } catch (error) {
                logger.error("WebSocket upgrade error:", error);
                rejectUpgrade(socket, "500 Internal Server Error");
            }
        }
    );
};
|
|
|
|
/**
 * Close every OPEN socket of `clientId` on this node with code 1000 and
 * reason "Disconnected by server". Per the original intent this is meant to
 * make the client reconnect; whether it does is up to the client.
 *
 * @returns false when no sockets are tracked for the client, true otherwise.
 */
const disconnectClient = async (clientId: string): Promise<boolean> => {
    const sockets = connectedClients.get(getClientMapKey(clientId)) ?? [];

    if (sockets.length === 0) {
        logger.debug(`No connections found for client ID: ${clientId}`);
        return false;
    }

    logger.info(
        `Disconnecting client ID: ${clientId} (${sockets.length} connection(s))`
    );

    for (const socket of sockets) {
        if (socket.readyState !== WebSocket.OPEN) {
            continue;
        }
        socket.close(1000, "Disconnected by server");
    }

    return true;
};
|
|
|
|
/**
 * Best-effort shutdown hook: forcibly terminate every socket tracked on this node.
 *
 * Previously only OPEN sockets were terminated, so sockets already in CLOSING
 * (e.g. after disconnectClient started a close handshake) were left to finish
 * on their own during shutdown. Now every socket that is not yet CLOSED is
 * terminated. Errors are logged, not rethrown.
 */
const cleanup = async (): Promise<void> => {
    try {
        connectedClients.forEach((clients) => {
            clients.forEach((client) => {
                if (client.readyState !== WebSocket.CLOSED) {
                    client.terminate();
                }
            });
        });

        logger.info("WebSocket cleanup completed");
    } catch (error) {
        logger.error("Error during WebSocket cleanup:", error);
    }
};
|
|
|
|
export {
    // HTTP route and upgrade wiring
    router,
    handleWSUpgrade,
    // Messaging
    sendToClient,
    broadcastToAllExcept,
    // Connection state and lookups
    connectedClients,
    hasActiveConnections,
    getActiveNodes,
    getClientConfigVersion,
    disconnectClient,
    NODE_ID,
    // Lifecycle
    cleanup
};
|