Working on updating targets

This commit is contained in:
Owen
2025-11-17 20:44:39 -05:00
parent dbb1e37033
commit 97c707248e
7 changed files with 149 additions and 187 deletions

View File

@@ -6,7 +6,7 @@ import logger from "@server/logger";
import { sites } from "@server/db"; import { sites } from "@server/db";
import { eq, and, isNotNull } from "drizzle-orm"; import { eq, and, isNotNull } from "drizzle-orm";
import { addTargets as addProxyTargets } from "@server/routers/newt/targets"; import { addTargets as addProxyTargets } from "@server/routers/newt/targets";
import { addTargets as addClientTargets } from "@server/routers/client/targets"; import { addTarget as addClientTargets } from "@server/routers/client/targets";
import { import {
ClientResourcesResults, ClientResourcesResults,
updateClientResources updateClientResources

View File

@@ -326,3 +326,35 @@ export function generateRemoteSubnetsStr(allSiteResources: SiteResource[]) {
remoteSubnets.length > 0 ? remoteSubnets.join(",") : null; remoteSubnets.length > 0 ? remoteSubnets.join(",") : null;
return remoteSubnetsStr; return remoteSubnetsStr;
} }
export type SubnetProxyTarget = {
cidr: string;
portRange?: {
min: number;
max: number;
}[];
};
export function generateSubnetProxyTargets(
allSiteResources: SiteResource[]
): SubnetProxyTarget[] {
let targets: SubnetProxyTarget[] = [];
for (const siteResource of allSiteResources) {
if (siteResource.mode == "host") {
// check if this is a valid ip
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
if (ipSchema.safeParse(siteResource.destination).success) {
targets.push({
cidr: `${siteResource.destination}/32`
});
}
} else if (siteResource.mode == "cidr") {
targets.push({
cidr: siteResource.destination
});
}
}
return targets;
}

View File

@@ -1,35 +1,30 @@
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { SubnetProxyTarget } from "@server/lib/ip";
export async function addTargets( export async function addTarget(newtId: string, target: SubnetProxyTarget) {
newtId: string,
destinationIp: string,
destinationPort: number,
protocol: string,
port: number
) {
const target = `${port}:${destinationIp}:${destinationPort}`;
await sendToClient(newtId, { await sendToClient(newtId, {
type: `newt/wg/${protocol}/add`, type: `newt/wg/target/add`,
data: { data: target
targets: [target] // We can only use one target for WireGuard right now
}
}); });
} }
export async function removeTargets( export async function removeTarget(newtId: string, target: SubnetProxyTarget) {
newtId: string,
destinationIp: string,
destinationPort: number,
protocol: string,
port: number
) {
const target = `${port}:${destinationIp}:${destinationPort}`;
await sendToClient(newtId, { await sendToClient(newtId, {
type: `newt/wg/${protocol}/remove`, type: `newt/wg/target/remove`,
data: { data: target
targets: [target] // We can only use one target for WireGuard right now
}
}); });
} }
export async function updateTarget(
newtId: string,
oldTarget: SubnetProxyTarget,
newTarget: SubnetProxyTarget
) {
await sendToClient(newtId, {
type: `newt/wg/target/update`,
data: {
oldTarget,
newTarget
}
});
}

View File

@@ -15,7 +15,7 @@ import { clients, clientSites, Newt, sites } from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import { updatePeer } from "../olm/peers"; import { updatePeer } from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes"; import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { generateRemoteSubnetsStr } from "@server/lib/ip"; import { generateRemoteSubnetsStr, generateSubnetProxyTargets } from "@server/lib/ip";
const inputSchema = z.object({ const inputSchema = z.object({
publicKey: z.string(), publicKey: z.string(),
@@ -222,35 +222,11 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
.from(siteResources) .from(siteResources)
.where(eq(siteResources.siteId, siteId)); .where(eq(siteResources.siteId, siteId));
let targets: {
cidr: string;
portRange?: {
min: number;
max: number;
}[];
}[] = [];
for (const siteResource of allSiteResources) {
if (siteResource.mode == "host") {
// check if this is a valid ip
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
if (ipSchema.safeParse(siteResource.destination).success) {
targets.push({
cidr: `${siteResource.destination}/32`
});
}
} else if (siteResource.mode == "cidr") {
targets.push({
cidr: siteResource.destination
});
}
}
// Build the configuration response // Build the configuration response
const configResponse = { const configResponse = {
ipAddress: site.address, ipAddress: site.address,
peers: validPeers, peers: validPeers,
targets: targets targets: generateSubnetProxyTargets(allSiteResources)
}; };
logger.debug("Sending config: ", configResponse); logger.debug("Sending config: ", configResponse);

View File

@@ -9,16 +9,18 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { addTargets } from "../client/targets"; import { addTarget } from "../client/targets";
import { getUniqueSiteResourceName } from "@server/db/names"; import { getUniqueSiteResourceName } from "@server/db/names";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations"; import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { generateSubnetProxyTargets } from "@server/lib/ip";
const createSiteResourceParamsSchema = z.strictObject({ const createSiteResourceParamsSchema = z.strictObject({
siteId: z.string().transform(Number).pipe(z.int().positive()), siteId: z.string().transform(Number).pipe(z.int().positive()),
orgId: z.string() orgId: z.string()
}); });
const createSiteResourceSchema = z.strictObject({ const createSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]), mode: z.enum(["host", "cidr", "port"]),
protocol: z.enum(["tcp", "udp"]).optional(), protocol: z.enum(["tcp", "udp"]).optional(),
@@ -49,12 +51,15 @@ const createSiteResourceSchema = z.strictObject({
(data) => { (data) => {
if (data.mode === "host") { if (data.mode === "host") {
// Check if it's a valid IP address using zod (v4 or v6) // Check if it's a valid IP address using zod (v4 or v6)
const isValidIP = z.union([z.ipv4(), z.ipv6()]).safeParse(data.destination).success; const isValidIP = z
.union([z.ipv4(), z.ipv6()])
.safeParse(data.destination).success;
// Check if it's a valid domain (hostname pattern, TLD not required) // Check if it's a valid domain (hostname pattern, TLD not required)
const domainRegex = /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/; const domainRegex =
/^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/;
const isValidDomain = domainRegex.test(data.destination); const isValidDomain = domainRegex.test(data.destination);
return isValidIP || isValidDomain; return isValidIP || isValidDomain;
} }
return true; return true;
@@ -68,7 +73,9 @@ const createSiteResourceSchema = z.strictObject({
(data) => { (data) => {
if (data.mode === "cidr") { if (data.mode === "cidr") {
// Check if it's a valid CIDR (v4 or v6) // Check if it's a valid CIDR (v4 or v6)
const isValidCIDR = z.union([z.cidrv4(), z.cidrv6()]).safeParse(data.destination).success; const isValidCIDR = z
.union([z.cidrv4(), z.cidrv6()])
.safeParse(data.destination).success;
return isValidCIDR; return isValidCIDR;
} }
return true; return true;
@@ -76,7 +83,7 @@ const createSiteResourceSchema = z.strictObject({
{ {
message: "Destination must be a valid CIDR notation for cidr mode" message: "Destination must be a valid CIDR notation for cidr mode"
} }
); );
export type CreateSiteResourceBody = z.infer<typeof createSiteResourceSchema>; export type CreateSiteResourceBody = z.infer<typeof createSiteResourceSchema>;
export type CreateSiteResourceResponse = SiteResource; export type CreateSiteResourceResponse = SiteResource;
@@ -213,29 +220,22 @@ export async function createSiteResource(
siteResourceId: newSiteResource.siteResourceId siteResourceId: newSiteResource.siteResourceId
}); });
// Only add targets for port mode const [newt] = await trx
if (mode === "port" && protocol && proxyPort && destinationPort) { .select()
const [newt] = await trx .from(newts)
.select() .where(eq(newts.siteId, site.siteId))
.from(newts) .limit(1);
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) { if (!newt) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
await addTargets(
newt.newtId,
destination,
destinationPort,
protocol,
proxyPort
); );
} }
const [target] = generateSubnetProxyTargets([newSiteResource]);
await addTarget(newt.newtId, target);
await rebuildSiteClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role await rebuildSiteClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role
}); });

View File

@@ -9,14 +9,15 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { removeTargets } from "../client/targets"; import { removeTarget } from "../client/targets";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations"; import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { generateSubnetProxyTargets } from "@server/lib/ip";
const deleteSiteResourceParamsSchema = z.strictObject({ const deleteSiteResourceParamsSchema = z.strictObject({
siteResourceId: z.string().transform(Number).pipe(z.int().positive()), siteResourceId: z.string().transform(Number).pipe(z.int().positive()),
siteId: z.string().transform(Number).pipe(z.int().positive()), siteId: z.string().transform(Number).pipe(z.int().positive()),
orgId: z.string() orgId: z.string()
}); });
export type DeleteSiteResourceResponse = { export type DeleteSiteResourceResponse = {
message: string; message: string;
@@ -84,7 +85,7 @@ export async function deleteSiteResource(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Delete the site resource // Delete the site resource
await trx const [removedSiteResource] = await trx
.delete(siteResources) .delete(siteResources)
.where( .where(
and( and(
@@ -92,36 +93,24 @@ export async function deleteSiteResource(
eq(siteResources.siteId, siteId), eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
); )
.returning();
// Only remove targets for port mode const [newt] = await trx
if ( .select()
existingSiteResource.mode === "port" && .from(newts)
existingSiteResource.protocol && .where(eq(newts.siteId, site.siteId))
existingSiteResource.proxyPort && .limit(1);
existingSiteResource.destinationPort
) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) { if (!newt) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
await removeTargets(
newt.newtId,
existingSiteResource.destination,
existingSiteResource.destinationPort,
existingSiteResource.protocol,
existingSiteResource.proxyPort
); );
} }
const [target] = generateSubnetProxyTargets([removedSiteResource]);
await removeTarget(newt.newtId, target);
await rebuildSiteClientAssociations(existingSiteResource, trx); await rebuildSiteClientAssociations(existingSiteResource, trx);
}); });

View File

@@ -9,18 +9,17 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { addTargets } from "../client/targets"; import { updateTarget } from "@server/routers/client/targets";
import { generateSubnetProxyTargets } from "@server/lib/ip";
const updateSiteResourceParamsSchema = z.strictObject({ const updateSiteResourceParamsSchema = z.strictObject({
siteResourceId: z siteResourceId: z.string().transform(Number).pipe(z.int().positive()),
.string() siteId: z.string().transform(Number).pipe(z.int().positive()),
.transform(Number) orgId: z.string()
.pipe(z.int().positive()), });
siteId: z.string().transform(Number).pipe(z.int().positive()),
orgId: z.string()
});
const updateSiteResourceSchema = z.strictObject({ const updateSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
mode: z.enum(["host", "cidr", "port"]).optional(), mode: z.enum(["host", "cidr", "port"]).optional(),
protocol: z.enum(["tcp", "udp"]).nullish(), protocol: z.enum(["tcp", "udp"]).nullish(),
@@ -119,65 +118,42 @@ export async function updateSiteResource(
const finalProxyPort = updateData.proxyPort !== undefined ? updateData.proxyPort : existingSiteResource.proxyPort; const finalProxyPort = updateData.proxyPort !== undefined ? updateData.proxyPort : existingSiteResource.proxyPort;
const finalDestinationPort = updateData.destinationPort !== undefined ? updateData.destinationPort : existingSiteResource.destinationPort; const finalDestinationPort = updateData.destinationPort !== undefined ? updateData.destinationPort : existingSiteResource.destinationPort;
if (finalMode === "port") {
if (!finalProtocol || !finalProxyPort || !finalDestinationPort) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Protocol, proxy port, and destination port are required for port mode"
)
);
}
// check if resource with same protocol and proxy port already exists
const [existingResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId),
eq(siteResources.protocol, finalProtocol),
eq(siteResources.proxyPort, finalProxyPort)
)
)
.limit(1);
if (
existingResource &&
existingResource.siteResourceId !== siteResourceId
) {
return next(
createHttpError(
HttpCode.CONFLICT,
"A resource with the same protocol and proxy port already exists"
)
);
}
}
// Prepare update data // Prepare update data
const updateValues: any = {}; const updateValues: any = {};
if (updateData.name !== undefined) updateValues.name = updateData.name; if (updateData.name !== undefined) updateValues.name = updateData.name;
if (updateData.mode !== undefined) updateValues.mode = updateData.mode; if (updateData.mode !== undefined) updateValues.mode = updateData.mode;
if (updateData.destination !== undefined) updateValues.destination = updateData.destination; if (updateData.destination !== undefined)
if (updateData.enabled !== undefined) updateValues.enabled = updateData.enabled; updateValues.destination = updateData.destination;
if (updateData.enabled !== undefined)
updateValues.enabled = updateData.enabled;
// Handle nullish fields (can be undefined, null, or a value) // Handle nullish fields (can be undefined, null, or a value)
if (updateData.alias !== undefined) { if (updateData.alias !== undefined) {
updateValues.alias = updateData.alias && updateData.alias.trim() ? updateData.alias : null; updateValues.alias =
updateData.alias && updateData.alias.trim()
? updateData.alias
: null;
} }
// Handle port mode fields - include in update if explicitly provided (null or value) or if mode changed // Handle port mode fields - include in update if explicitly provided (null or value) or if mode changed
const isModeChangingFromPort = existingSiteResource.mode === "port" && updateData.mode && updateData.mode !== "port"; const isModeChangingFromPort =
existingSiteResource.mode === "port" &&
updateData.mode &&
updateData.mode !== "port";
if (updateData.protocol !== undefined || isModeChangingFromPort) { if (updateData.protocol !== undefined || isModeChangingFromPort) {
updateValues.protocol = finalMode === "port" ? finalProtocol : null; updateValues.protocol = finalMode === "port" ? finalProtocol : null;
} }
if (updateData.proxyPort !== undefined || isModeChangingFromPort) { if (updateData.proxyPort !== undefined || isModeChangingFromPort) {
updateValues.proxyPort = finalMode === "port" ? finalProxyPort : null; updateValues.proxyPort =
finalMode === "port" ? finalProxyPort : null;
} }
if (updateData.destinationPort !== undefined || isModeChangingFromPort) { if (
updateValues.destinationPort = finalMode === "port" ? finalDestinationPort : null; updateData.destinationPort !== undefined ||
isModeChangingFromPort
) {
updateValues.destinationPort =
finalMode === "port" ? finalDestinationPort : null;
} }
// Update the site resource // Update the site resource
@@ -193,27 +169,21 @@ export async function updateSiteResource(
) )
.returning(); .returning();
// Only add targets for port mode const [newt] = await db
if (updatedSiteResource.mode === "port" && updatedSiteResource.protocol && updatedSiteResource.proxyPort && updatedSiteResource.destinationPort) { .select()
const [newt] = await db .from(newts)
.select() .where(eq(newts.siteId, site.siteId))
.from(newts) .limit(1);
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) { if (!newt) {
return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found")); return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found"));
}
await addTargets(
newt.newtId,
updatedSiteResource.destination,
updatedSiteResource.destinationPort,
updatedSiteResource.protocol,
updatedSiteResource.proxyPort
);
} }
const [oldTarget] = generateSubnetProxyTargets([existingSiteResource]);
const [newTarget] = generateSubnetProxyTargets([updatedSiteResource]);
await updateTarget(newt.newtId, oldTarget, newTarget);
logger.info( logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}` `Updated site resource ${siteResourceId} for site ${siteId}`
); );