diff --git a/packages/api/src/__tests__/assistant-input-bounds.test.ts b/packages/api/src/__tests__/assistant-input-bounds.test.ts new file mode 100644 index 0000000..de92e8b --- /dev/null +++ b/packages/api/src/__tests__/assistant-input-bounds.test.ts @@ -0,0 +1,71 @@ +import { describe, expect, it } from "vitest"; +import { + ASSISTANT_MAX_AGGREGATE_BYTES, + ASSISTANT_MAX_CONTENT_LENGTH, + ASSISTANT_MAX_PAGE_CONTEXT, + assistantChatInputSchema, +} from "../router/assistant-procedure-support.js"; + +describe("assistantChatInputSchema bounds", () => { + it("accepts a normal-sized message", () => { + const result = assistantChatInputSchema.safeParse({ + messages: [{ role: "user", content: "Hello" }], + }); + expect(result.success).toBe(true); + }); + + it("rejects a single message above the per-message length cap", () => { + const huge = "x".repeat(ASSISTANT_MAX_CONTENT_LENGTH + 1); + const result = assistantChatInputSchema.safeParse({ + messages: [{ role: "user", content: huge }], + }); + expect(result.success).toBe(false); + }); + + it("rejects a pageContext above the page-context cap", () => { + const huge = "x".repeat(ASSISTANT_MAX_PAGE_CONTEXT + 1); + const result = assistantChatInputSchema.safeParse({ + messages: [{ role: "user", content: "Hi" }], + pageContext: huge, + }); + expect(result.success).toBe(false); + }); + + it("rejects an aggregate payload above the total-bytes cap", () => { + // Each message is below the per-message cap, but together they exceed + // the aggregate cap. + const oneMessageBytes = 5_000; + const each = "x".repeat(oneMessageBytes); + const count = Math.ceil(ASSISTANT_MAX_AGGREGATE_BYTES / oneMessageBytes) + 2; + const messages = Array.from({ length: count }, () => ({ + role: "user" as const, + content: each, + })); + const result = assistantChatInputSchema.safeParse({ messages }); + expect(result.success).toBe(false); + }); + + it("accepts an aggregate payload right under the cap", () => { + const count = Math.floor(ASSISTANT_MAX_AGGREGATE_BYTES / 1_000) - 1; + const messages = Array.from({ length: count }, () => ({ + role: "user" as const, + content: "x".repeat(1_000), + })); + const result = assistantChatInputSchema.safeParse({ messages }); + expect(result.success).toBe(true); + }); + + it("rejects an empty messages array", () => { + const result = assistantChatInputSchema.safeParse({ messages: [] }); + expect(result.success).toBe(false); + }); + + it("rejects more than 200 messages", () => { + const messages = Array.from({ length: 201 }, () => ({ + role: "user" as const, + content: "x", + })); + const result = assistantChatInputSchema.safeParse({ messages }); + expect(result.success).toBe(false); + }); +}); diff --git a/packages/api/src/router/assistant-procedure-support.ts b/packages/api/src/router/assistant-procedure-support.ts index 95adc9a..a61dbf6 100644 --- a/packages/api/src/router/assistant-procedure-support.ts +++ b/packages/api/src/router/assistant-procedure-support.ts @@ -34,24 +34,47 @@ import { const MAX_TOOL_ITERATIONS = 8; -type AssistantProcedureContext = Pick< - TRPCContext, - "db" | "dbUser" | "roleDefaults" | "session" ->; +type AssistantProcedureContext = Pick; type OpenAiMessage = { role: "system" | "user" | "assistant"; content: string; }; -export const assistantChatInputSchema = z.object({ - messages: z.array(z.object({ - role: z.enum(["user", "assistant"]), - content: z.string(), - })).min(1).max(200), - pageContext: z.string().optional(), - conversationId: z.string().max(120).optional(), -}); +// Per-message and aggregate caps. The per-message cap stops a single 50 MB +// payload from OOM-ing JSON.parse / blowing up the prompt assembly; the +// aggregate cap stops the same with 200 messages × 9 999 chars each. +// 10 000 chars is generous for normal chat, 200 KB total is comfortably under +// any provider's request-budget. +export const ASSISTANT_MAX_CONTENT_LENGTH = 10_000; +export const ASSISTANT_MAX_PAGE_CONTEXT = 2_000; +export const ASSISTANT_MAX_AGGREGATE_BYTES = 200_000; + +export const assistantChatInputSchema = z + .object({ + messages: z + .array( + z.object({ + role: z.enum(["user", "assistant"]), + content: z.string().max(ASSISTANT_MAX_CONTENT_LENGTH), + }), + ) + .min(1) + .max(200), + pageContext: z.string().max(ASSISTANT_MAX_PAGE_CONTEXT).optional(), + conversationId: z.string().max(120).optional(), + }) + .superRefine((val, ctx) => { + let total = 0; + for (const m of val.messages) total += Buffer.byteLength(m.content, "utf8"); + if (val.pageContext) total += Buffer.byteLength(val.pageContext, "utf8"); + if (total > ASSISTANT_MAX_AGGREGATE_BYTES) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: `Aggregate message payload too large (${total} bytes > ${ASSISTANT_MAX_AGGREGATE_BYTES})`, + }); + } + }); type AssistantChatInput = z.infer; @@ -70,14 +93,13 @@ function buildAssistantContextBlock(input: { pageContext?: string | undefined; }) { const permissionList = [...input.permissions]; - let contextBlock = - `\n\nAktueller User: ${input.session?.user?.name ?? "Unknown"} (Rolle: ${input.userRole})`; - contextBlock += - `\nBerechtigungen: ${permissionList.length > 0 ? permissionList.join(", ") : "Nur Lese-Zugriff auf eigene Daten"}`; + let contextBlock = `\n\nAktueller User: ${input.session?.user?.name ?? "Unknown"} (Rolle: ${input.userRole})`; + contextBlock += `\nBerechtigungen: ${permissionList.length > 0 ? permissionList.join(", ") : "Nur Lese-Zugriff auf eigene Daten"}`; if (input.pageContext) { contextBlock += `\nAktuelle Seite: ${input.pageContext}`; - contextBlock += "\nHinweis: Beziehe dich bevorzugt auf den Kontext der aktuellen Seite wenn die Frage des Users dazu passt."; + contextBlock += + "\nHinweis: Beziehe dich bevorzugt auf den Kontext der aktuellen Seite wenn die Frage des Users dazu passt."; } return contextBlock; @@ -94,8 +116,8 @@ function buildOpenAiMessages(input: { { role: "system", content: - ASSISTANT_SYSTEM_PROMPT - + buildAssistantContextBlock({ + ASSISTANT_SYSTEM_PROMPT + + buildAssistantContextBlock({ session: input.session, userRole: input.userRole, permissions: input.permissions, @@ -155,10 +177,7 @@ export async function listPendingApprovalPayloads(ctx: AssistantProcedureContext return approvals.map((approval) => toApprovalPayload(approval, "pending")); } -export async function runAssistantChat( - ctx: AssistantProcedureContext, - input: AssistantChatInput, -) { +export async function runAssistantChat(ctx: AssistantProcedureContext, input: AssistantChatInput) { const dbUser = requireAssistantUser(ctx); const configuredSettings = await ctx.db.systemSettings.findUnique({ where: { id: "singleton" }, diff --git a/packages/api/src/router/project-cover.ts b/packages/api/src/router/project-cover.ts index 6c3943e..0843963 100644 --- a/packages/api/src/router/project-cover.ts +++ b/packages/api/src/router/project-cover.ts @@ -5,6 +5,7 @@ import { createDalleClient, isDalleConfigured, loggedAiCall, parseAiError } from import { findUniqueOrThrow } from "../db/helpers.js"; import { generateGeminiImage, isGeminiConfigured, parseGeminiError } from "../gemini-client.js"; import { validateImageDataUrl } from "../lib/image-validation.js"; +import { checkPromptInjection } from "../lib/prompt-guard.js"; import { resolveSystemSettingsRuntime } from "../lib/system-settings-runtime.js"; import { managerProcedure, protectedProcedure, requirePermission } from "../trpc.js"; @@ -19,9 +20,8 @@ async function readImageGenerationStatus(db: { where: { id: "singleton" }, }); const imageProvider = settings?.["imageProvider"] === "gemini" ? "gemini" : "dalle"; - const configured = imageProvider === "gemini" - ? isGeminiConfigured(settings) - : isDalleConfigured(settings); + const configured = + imageProvider === "gemini" ? isGeminiConfigured(settings) : isDalleConfigured(settings); return { configured, @@ -31,13 +31,30 @@ async function readImageGenerationStatus(db: { export const projectCoverProcedures = { generateCover: managerProcedure - .input(z.object({ - projectId: z.string(), - prompt: z.string().max(500).optional(), - })) + .input( + z.object({ + projectId: z.string(), + prompt: z.string().max(500).optional(), + }), + ) .mutation(async ({ ctx, input }) => { requirePermission(ctx, PermissionKey.MANAGE_PROJECTS); + // The user's free-text "Additional direction" is concatenated into the + // image-generation prompt. Run the same injection guard we apply to + // assistant chat (EGAI 4.6.3.2) so a manager-role user can't pivot the + // image model into "ignore previous instructions" / role-override + // attacks against downstream prompt-aware infra. + if (input.prompt) { + const guard = checkPromptInjection(input.prompt); + if (!guard.safe) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Prompt rejected: contains an injection pattern.", + }); + } + } + const project = await findUniqueOrThrow( ctx.db.project.findUnique({ where: { id: input.projectId }, @@ -85,7 +102,10 @@ export const projectCoverProcedures = { } } else { const dalleClient = createDalleClient(runtimeSettings); - const model = runtimeSettings.aiProvider === "azure" ? runtimeSettings.azureDalleDeployment! : "dall-e-3"; + const model = + runtimeSettings.aiProvider === "azure" + ? runtimeSettings.azureDalleDeployment! + : "dall-e-3"; // eslint-disable-next-line @typescript-eslint/no-explicit-any let response: any; @@ -126,10 +146,12 @@ export const projectCoverProcedures = { }), uploadCover: managerProcedure - .input(z.object({ - projectId: z.string(), - imageDataUrl: z.string(), - })) + .input( + z.object({ + projectId: z.string(), + imageDataUrl: z.string(), + }), + ) .mutation(async ({ ctx, input }) => { requirePermission(ctx, PermissionKey.MANAGE_PROJECTS); @@ -187,10 +209,12 @@ export const projectCoverProcedures = { }), updateCoverFocus: managerProcedure - .input(z.object({ - projectId: z.string(), - coverFocusY: z.number().int().min(0).max(100), - })) + .input( + z.object({ + projectId: z.string(), + coverFocusY: z.number().int().min(0).max(100), + }), + ) .mutation(async ({ ctx, input }) => { requirePermission(ctx, PermissionKey.MANAGE_PROJECTS); await ctx.db.project.update({ @@ -200,13 +224,13 @@ export const projectCoverProcedures = { return { ok: true }; }), - isImageGenConfigured: protectedProcedure - .query(async ({ ctx }) => readImageGenerationStatus(ctx.db)), + isImageGenConfigured: protectedProcedure.query(async ({ ctx }) => + readImageGenerationStatus(ctx.db), + ), /** @deprecated Use isImageGenConfigured instead */ - isDalleConfigured: protectedProcedure - .query(async ({ ctx }) => { - const { configured } = await readImageGenerationStatus(ctx.db); - return { configured }; - }), + isDalleConfigured: protectedProcedure.query(async ({ ctx }) => { + const { configured } = await readImageGenerationStatus(ctx.db); + return { configured }; + }), };