From 552adf320055d2620fff61af7b610662affc7001 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 12 Jan 2026 21:14:18 -0800 Subject: [PATCH] Properly handle blocked devices --- server/routers/olm/handleOlmPingMessage.ts | 113 +++++++++--------- .../routers/olm/handleOlmRegisterMessage.ts | 16 ++- 2 files changed, 71 insertions(+), 58 deletions(-) diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index 36ca0001..9361193d 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -108,65 +108,65 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { return; } - let client: (typeof clients.$inferSelect) | undefined; - - if (olm.userId) { - // we need to check a user token to make sure its still valid - const { session: userSession, user } = - await validateSessionToken(userToken); - if (!userSession || !user) { - logger.warn("Invalid user session for olm ping"); - return; // by returning here we just ignore the ping and the setInterval will force it to disconnect - } - if (user.userId !== olm.userId) { - logger.warn("User ID mismatch for olm ping"); - return; - } - - // get the client - const [userClient] = await db - .select() - .from(clients) - .where( - and( - eq(clients.olmId, olm.olmId), - eq(clients.userId, olm.userId) - ) - ) - .limit(1); - - if (!userClient) { - logger.warn("Client not found for olm ping"); - return; - } - - client = userClient; - - const sessionId = encodeHexLowerCase( - sha256(new TextEncoder().encode(userToken)) - ); - - const policyCheck = await checkOrgAccessPolicy({ - orgId: client.orgId, - userId: olm.userId, - sessionId // this is the user token passed in the message - }); - - if (!policyCheck.allowed) { - logger.warn( - `Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}` - ); - return; - } - } - if (!olm.clientId) { logger.warn("Olm has no client ID!"); return; } try { - // Update the client's last ping timestamp + // get the client + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, olm.clientId)) + .limit(1); + + if (!client) { + logger.warn("Client not found for olm ping"); + return; + } + + if (client.blocked) { + // NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them + logger.debug(`Blocked client ${client.clientId} attempted olm ping`); + return; + } + + if (olm.userId) { + // we need to check a user token to make sure its still valid + const { session: userSession, user } = + await validateSessionToken(userToken); + if (!userSession || !user) { + logger.warn("Invalid user session for olm ping"); + return; // by returning here we just ignore the ping and the setInterval will force it to disconnect + } + if (user.userId !== olm.userId) { + logger.warn("User ID mismatch for olm ping"); + return; + } + if (user.userId !== client.userId) { + logger.warn("Client user ID mismatch for olm ping"); + return; + } + + const sessionId = encodeHexLowerCase( + sha256(new TextEncoder().encode(userToken)) + ); + + const policyCheck = await checkOrgAccessPolicy({ + orgId: client.orgId, + userId: olm.userId, + sessionId // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}` + ); + return; + } + } + await db .update(clients) .set({ @@ -176,7 +176,12 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { }) .where(eq(clients.clientId, olm.clientId)); - await db.update(olms).set({ archived: false }).where(eq(olms.olmId, olm.olmId)); + if (olm.archived) { + await db + .update(olms) + .set({ archived: false }) + .where(eq(olms.olmId, olm.olmId)); + } } catch (error) { logger.error("Error handling ping message", { error }); } diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 0f71ee8b..3334101e 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -55,6 +55,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } + if (client.blocked) { + logger.debug(`Client ${client.clientId} is blocked. Ignoring register.`); + return; + } + const [org] = await db .select() .from(orgs) @@ -112,18 +117,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if ( (olmVersion && olm.version !== olmVersion) || - (olmAgent && olm.agent !== olmAgent) + (olmAgent && olm.agent !== olmAgent) || + olm.archived ) { await db .update(olms) .set({ version: olmVersion, - agent: olmAgent + agent: olmAgent, + archived: false }) .where(eq(olms.olmId, olm.olmId)); } - if (client.pubKey !== publicKey) { + if (client.pubKey !== publicKey || client.archived) { logger.info( "Public key mismatch. Updating public key and clearing session info..." ); @@ -131,7 +138,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { await db .update(clients) .set({ - pubKey: publicKey + pubKey: publicKey, + archived: false, }) .where(eq(clients.clientId, client.clientId));