mirror of
https://github.com/fosrl/pangolin.git
synced 2026-01-28 22:00:51 +00:00
275 lines
8.4 KiB
TypeScript
275 lines
8.4 KiB
TypeScript
import { db, orgs } from "@server/db";
|
|
import { MessageHandler } from "@server/routers/ws";
|
|
import {
|
|
clients,
|
|
clientSitesAssociationsCache,
|
|
Olm,
|
|
olms,
|
|
sites
|
|
} from "@server/db";
|
|
import { count, eq } from "drizzle-orm";
|
|
import logger from "@server/logger";
|
|
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
|
import { validateSessionToken } from "@server/auth/sessions/app";
|
|
import { encodeHexLowerCase } from "@oslojs/encoding";
|
|
import { sha256 } from "@oslojs/crypto/sha2";
|
|
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
|
import { OlmErrorCodes, sendOlmError } from "./error";
|
|
import { handleFingerprintInsertion } from "./fingerprintingUtils";
|
|
|
|
/**
 * WebSocket handler for the register message sent by an olm.
 *
 * Steps, in order:
 *  1. Pass the fingerprint/posture data from the message to
 *     `handleFingerprintInsertion`.
 *  2. Persist a changed olm version/agent and clear the olm's archived flag.
 *  3. Load the olm's client and that client's org; stop for blocked clients
 *     and clients whose approval is still pending.
 *  4. When the message carries an `orgId`, validate the user session token and
 *     evaluate the org access policies (password age, max session length,
 *     two-factor, overall `allowed`).
 *  5. When the public key changed (or the client was archived), store the new
 *     key, un-archive the client and rewrite the relay flag on the client's
 *     site associations.
 *  6. Skip the register when the client has associated sites but its last
 *     hole punch is more than 5 seconds old.
 *  7. Build the site configurations and reply with an `olm/wg/connect`
 *     message.
 *
 * Most rejections are reported through `sendOlmError`; a missing olm, a
 * missing public key and a stale hole punch are only logged. In every
 * rejection case the handler resolves to `undefined`.
 *
 * @param context - Handler context; `client` is treated as the connected
 *   `Olm` and `message.data` as the register payload.
 * @returns The `olm/wg/connect` response (not broadcast, sender not
 *   excluded), or `undefined` when the register is rejected or skipped.
 */
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
    logger.info("Handling register olm message!");
    const { message, client: c } = context;
    const olm = c as Olm;

    // Captured before any awaited work; compared against client.lastHolePunch
    // (seconds) further down.
    const now = Math.floor(Date.now() / 1000);

    if (!olm) {
        logger.warn("Olm not found");
        return;
    }

    const {
        publicKey,
        relay,
        olmVersion,
        olmAgent,
        orgId,
        userToken,
        fingerprint,
        postures
    } = message.data;

    if (!olm.clientId) {
        logger.warn("Olm client ID not found");
        sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId);
        return;
    }

    logger.debug("Handling fingerprint insertion for olm register...", {
        olmId: olm.olmId,
        fingerprint,
        postures
    });

    await handleFingerprintInsertion(olm, fingerprint, postures);

    // Keep the stored olm metadata in sync with what the olm reports, and
    // un-archive an olm that registers again.
    if (
        (olmVersion && olm.version !== olmVersion) ||
        (olmAgent && olm.agent !== olmAgent) ||
        olm.archived
    ) {
        await db
            .update(olms)
            .set({
                version: olmVersion,
                agent: olmAgent,
                archived: false
            })
            .where(eq(olms.olmId, olm.olmId));
    }

    const [client] = await db
        .select()
        .from(clients)
        .where(eq(clients.clientId, olm.clientId))
        .limit(1);

    if (!client) {
        logger.warn("Client ID not found");
        sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId);
        return;
    }

    if (client.blocked) {
        logger.debug(
            `Client ${client.clientId} is blocked. Ignoring register.`
        );
        sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
        return;
    }

    if (client.approvalState === "pending") {
        logger.debug(
            `Client ${client.clientId} approval is pending. Ignoring register.`
        );
        sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
        return;
    }

    const [org] = await db
        .select()
        .from(orgs)
        .where(eq(orgs.orgId, client.orgId))
        .limit(1);

    if (!org) {
        logger.warn("Org not found");
        sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
        return;
    }

    // User-bound registration: only performed when the message names an org.
    // NOTE(review): the policy check below uses the orgId supplied in the
    // message, while the client/org rows above are resolved via client.orgId —
    // confirm the two are guaranteed to match.
    if (orgId) {
        if (!olm.userId) {
            logger.warn("Olm has no user ID");
            sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
            return;
        }

        const { session: userSession, user } =
            await validateSessionToken(userToken);
        if (!userSession || !user) {
            logger.warn("Invalid user session for olm register");
            sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
            return;
        }
        if (user.userId !== olm.userId) {
            logger.warn("User ID mismatch for olm register");
            sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
            return;
        }

        // The session ID handed to the policy check is the lowercase hex
        // SHA-256 digest of the user token from the message.
        const sessionId = encodeHexLowerCase(
            sha256(new TextEncoder().encode(userToken))
        );

        const policyCheck = await checkOrgAccessPolicy({
            orgId: orgId,
            userId: olm.userId,
            sessionId
        });

        logger.debug("Policy check result:", policyCheck);

        if (policyCheck?.error) {
            logger.error(
                `Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`
            );
            sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
            return;
        }

        // Each policy is rejected only when it is explicitly reported as
        // non-compliant (`=== false`); a policy that is absent from the result
        // is skipped. The comparison must NOT be prefixed with `!`: unary `!`
        // binds tighter than `===`, which would invert the check and reject
        // compliant users instead.
        if (policyCheck.policies?.passwordAge?.compliant === false) {
            logger.warn(
                `Olm user ${olm.userId} has non-compliant password age for org ${orgId}`
            );
            sendOlmError(
                OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
                olm.olmId
            );
            return;
        }

        if (policyCheck.policies?.maxSessionLength?.compliant === false) {
            logger.warn(
                `Olm user ${olm.userId} has non-compliant session length for org ${orgId}`
            );
            sendOlmError(
                OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
                olm.olmId
            );
            return;
        }

        if (policyCheck.policies?.requiredTwoFactor === false) {
            logger.warn(
                `Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`
            );
            sendOlmError(
                OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
                olm.olmId
            );
            return;
        }

        // Catch-all for any other reason the policy check did not allow access.
        if (!policyCheck.allowed) {
            logger.warn(
                `Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
            );
            sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
            return;
        }
    }

    logger.debug(
        `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
    );

    if (!publicKey) {
        logger.warn("Public key not provided");
        return;
    }

    // Runs for a changed key and also for an archived client that registers
    // again (the log line below is emitted in both cases).
    if (client.pubKey !== publicKey || client.archived) {
        logger.info(
            "Public key mismatch. Updating public key and clearing session info..."
        );
        // Update the client's public key and un-archive the client
        await db
            .update(clients)
            .set({
                pubKey: publicKey,
                archived: false
            })
            .where(eq(clients.clientId, client.clientId));

        // Reset the connection metadata: isRelayed on all of the client's site
        // associations becomes true only when the olm asked for relaying.
        await db
            .update(clientSitesAssociationsCache)
            .set({
                isRelayed: relay == true
            })
            .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
    }

    // Count the sites associated with this client
    const sitesCountResult = await db
        .select({ count: count() })
        .from(sites)
        .innerJoin(
            clientSitesAssociationsCache,
            eq(sites.siteId, clientSitesAssociationsCache.siteId)
        )
        .where(eq(clientSitesAssociationsCache.clientId, client.clientId));

    // Extract the count value from the result array (0 when there is no row)
    const sitesCount = sitesCountResult[0]?.count ?? 0;

    logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);

    // this prevents us from accepting a register from an olm that has not hole punched yet.
    // the olm will pump the register so we can keep checking
    // TODO: I still think there is a better way to do this rather than locking it out here but ???
    if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
        logger.warn(
            "Client last hole punch is too old and we have sites to send; skipping this register"
        );
        return;
    }

    // NOTE: its important that the client here is the old client and the public key is the new key
    const siteConfigurations = await buildSiteConfigurationForOlmClient(
        client,
        publicKey,
        relay
    );

    // REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
    // if (siteConfigurations.length === 0) {
    //     logger.warn("No valid site configurations found");
    //     return;
    // }

    // Return connect message with all site configurations
    return {
        message: {
            type: "olm/wg/connect",
            data: {
                sites: siteConfigurations,
                tunnelIP: client.subnet,
                utilitySubnet: org.utilitySubnet
            }
        },
        broadcast: false,
        excludeSender: false
    };
};
|