refactor(api): add redis-backed rate limiting fallback

This commit is contained in:
2026-03-30 23:23:56 +02:00
parent bcfb18393e
commit ef5e8016a4
9 changed files with 357 additions and 61 deletions
@@ -204,9 +204,9 @@ function createMissingApprovalTableError() {
describe("assistant router tool gating", () => {
let approvalStore = createApprovalStoreMock();
beforeEach(() => {
beforeEach(async () => {
approvalStore = createApprovalStoreMock();
apiRateLimiter.reset();
await apiRateLimiter.reset();
resetAssistantApprovalStorageWarningStateForTests();
});
@@ -148,10 +148,10 @@ function createToolContext(
}
describe("assistant import/export and dispo tools", () => {
beforeEach(() => {
beforeEach(async () => {
vi.clearAllMocks();
vi.unstubAllEnvs();
apiRateLimiter.reset();
await apiRateLimiter.reset();
totpValidateMock.mockReset();
vi.mocked(approveEstimateVersion).mockReset();
vi.mocked(cloneEstimate).mockReset();
@@ -0,0 +1,119 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
describe("rate limiter", () => {
  beforeEach(() => {
    // Freeze the clock so window-expiry arithmetic is deterministic.
    vi.useFakeTimers();
    vi.setSystemTime(new Date("2026-03-30T10:00:00.000Z"));
  });
  afterEach(() => {
    vi.useRealTimers();
    // The limiter module caches a shared Redis client at module scope, so each
    // test dynamically imports a fresh module instance; drop the module cache
    // and any ioredis mocks / env stubs between tests.
    vi.resetModules();
    vi.unmock("ioredis");
    vi.unstubAllEnvs();
  });
  it("enforces limits and reset in memory mode", async () => {
    // Import after fake timers are installed so the module sees the fake clock.
    const { createRateLimiter } = await import("../middleware/rate-limit.js");
    const limiter = createRateLimiter(60_000, 2, { backend: "memory", name: "memory-test" });
    const first = await limiter("user-1");
    const second = await limiter("user-1");
    const third = await limiter("user-1");
    expect(first.allowed).toBe(true);
    expect(first.remaining).toBe(1);
    expect(second.allowed).toBe(true);
    expect(second.remaining).toBe(0);
    // Third call exceeds maxRequests=2 within the same window.
    expect(third.allowed).toBe(false);
    expect(third.remaining).toBe(0);
    await limiter.reset();
    // After reset the counter starts over for the same key.
    const afterReset = await limiter("user-1");
    expect(afterReset.allowed).toBe(true);
    expect(afterReset.remaining).toBe(1);
  });
  it("uses a shared Redis counter when Redis mode is enabled", async () => {
    // In-process stand-in for Redis state: key -> { count, resetAt }.
    const store = new Map<string, { count: number; resetAt: number }>();
    const delMock = vi.fn(async (...keys: string[]) => {
      for (const key of keys) {
        store.delete(key);
      }
      return keys.length;
    });
    // Mock ioredis BEFORE the dynamic import below so the limiter binds to it.
    vi.doMock("ioredis", () => ({
      Redis: vi.fn().mockImplementation(() => ({
        on: vi.fn(),
        // Emulates the limiter's Lua script: INCR the key, start the window on
        // the first hit, and return [count, remaining window ms].
        eval: vi.fn(async (_script: string, _numKeys: number, key: string, windowMsValue: string) => {
          const now = Date.now();
          const windowMs = Number(windowMsValue);
          const existing = store.get(key);
          if (!existing || existing.resetAt <= now) {
            store.set(key, { count: 1, resetAt: now + windowMs });
          } else {
            existing.count += 1;
          }
          const current = store.get(key)!;
          return [current.count, current.resetAt - now];
        }),
        // Single-page SCAN result: terminal cursor "0" plus every stored key.
        scan: vi.fn(async () => ["0", [...store.keys()]]),
        del: delMock,
      })),
    }));
    const { createRateLimiter } = await import("../middleware/rate-limit.js");
    const limiter = createRateLimiter(60_000, 2, {
      backend: "redis",
      redisUrl: "redis://test",
      name: "redis-test",
    });
    const first = await limiter("user-1");
    const second = await limiter("user-1");
    const third = await limiter("user-1");
    expect(first.allowed).toBe(true);
    expect(second.allowed).toBe(true);
    expect(third.allowed).toBe(false);
    expect(third.remaining).toBe(0);
    // reset() must clear the shared counters via SCAN + DEL.
    await limiter.reset();
    expect(delMock).toHaveBeenCalled();
    const afterReset = await limiter("user-1");
    expect(afterReset.allowed).toBe(true);
    expect(afterReset.remaining).toBe(1);
  });
  it("falls back to in-memory counters when Redis is unavailable", async () => {
    // Every Redis command fails, forcing the limiter onto its memory backend.
    vi.doMock("ioredis", () => ({
      Redis: vi.fn().mockImplementation(() => ({
        on: vi.fn(),
        eval: vi.fn(async () => {
          throw new Error("redis down");
        }),
        scan: vi.fn(async () => ["0", []]),
        del: vi.fn(async () => 0),
      })),
    }));
    const { createRateLimiter } = await import("../middleware/rate-limit.js");
    const limiter = createRateLimiter(60_000, 2, {
      backend: "redis",
      redisUrl: "redis://test",
      name: "redis-fallback-test",
    });
    const first = await limiter("user-1");
    const second = await limiter("user-1");
    const third = await limiter("user-1");
    // The in-memory fallback must still enforce the same limit.
    expect(first.allowed).toBe(true);
    expect(second.allowed).toBe(true);
    expect(third.allowed).toBe(false);
    expect(third.remaining).toBe(0);
  });
});
+222 -46
View File
@@ -1,34 +1,79 @@
/**
* Simple in-memory rate limiter (Map-based).
* Good enough for single-instance deployments.
* For multi-instance, swap to Redis-backed implementation.
*/
import { Redis } from "ioredis";
import { logger } from "../lib/logger.js";
interface RateLimitEntry {
count: number;
resetAt: number;
}
interface RateLimitResult {
export interface RateLimitResult {
allowed: boolean;
remaining: number;
resetAt: Date;
}
// Per-key counter for the in-memory backend.
type RateLimitEntry = {
  count: number; // requests seen in the current window
  resetAt: number; // epoch ms at which the window expires
};
// "auto" selects Redis when a Redis URL is available, otherwise memory.
type RateLimitBackendMode = "auto" | "memory" | "redis";
type CreateRateLimiterOptions = {
  /** Stable limiter name; embedded in Redis keys and log messages. */
  name?: string;
  /** Backend selection; falls back to RATE_LIMIT_BACKEND env, then "auto". */
  backend?: RateLimitBackendMode;
  /** Overrides the REDIS_URL env value for this limiter. */
  redisUrl?: string;
  /** Redis key namespace; defaults to DEFAULT_REDIS_KEY_PREFIX. */
  keyPrefix?: string;
};
export interface RateLimiter {
(key: string): RateLimitResult;
reset(): void;
(key: string): Promise<RateLimitResult>;
reset(): Promise<void>;
}
/**
* Creates a sliding-window rate limiter.
* @param windowMs - Time window in milliseconds
* @param maxRequests - Maximum requests allowed within the window
*/
export function createRateLimiter(windowMs: number, maxRequests: number): RateLimiter {
const store = new Map<string, RateLimitEntry>();
// Minimal contract shared by the in-memory and Redis implementations.
type RateLimiterBackend = {
  check: (key: string) => Promise<RateLimitResult>;
  reset: () => Promise<void>;
};
// Periodically clean up expired entries to prevent memory leaks
const DEFAULT_REDIS_KEY_PREFIX = "capakraken:ratelimit";
// NOTE(review): the env value is cast without validation here; getBackendMode
// only honors it when it matches a known mode — confirm that is the intent.
const DEFAULT_REDIS_BACKEND = process.env["RATE_LIMIT_BACKEND"] as RateLimitBackendMode | undefined;
const DEFAULT_REDIS_URL = process.env["REDIS_URL"]?.trim();
// Dedup set so each limiter/URL pair logs the "Redis unavailable" warning once.
const warnedRedisFailures = new Set<string>();
// Single client shared by every Redis-backed limiter in this process.
let sharedRedisClient: Redis | null = null;
/**
 * Resolves the effective backend mode: the explicit option wins, then the
 * RATE_LIMIT_BACKEND environment default, and finally "auto".
 */
function getBackendMode(
  requestedBackend: RateLimitBackendMode | undefined,
): RateLimitBackendMode {
  const isKnownMode = (value: unknown): value is RateLimitBackendMode =>
    value === "memory" || value === "redis" || value === "auto";
  if (isKnownMode(requestedBackend)) {
    return requestedBackend;
  }
  return isKnownMode(DEFAULT_REDIS_BACKEND) ? DEFAULT_REDIS_BACKEND : "auto";
}
/** Replaces characters that are unsafe inside a Redis key segment with "_". */
function sanitizeKeySegment(value: string): string {
  const unsafeChars = /[^a-zA-Z0-9:_-]/g;
  return value.replace(unsafeChars, "_");
}
/**
 * Lazily creates, then reuses, one Redis client for all limiters in this
 * process. Options are tuned to fail fast (no offline queue, 1 retry, 1s
 * command timeout) so callers can quickly fall back to in-memory counters.
 *
 * NOTE(review): the client is created from the first redisUrl seen; later
 * calls with a different URL silently reuse the existing client — confirm
 * this single-URL assumption is intended.
 */
function getRedisClient(redisUrl: string): Redis {
  if (!sharedRedisClient) {
    sharedRedisClient = new Redis(redisUrl, {
      lazyConnect: false,
      enableReadyCheck: false,
      // Reject commands immediately while disconnected instead of queueing.
      enableOfflineQueue: false,
      maxRetriesPerRequest: 1,
      commandTimeout: 1000,
    });
    // Attach a handler so connection errors are logged rather than surfacing
    // as an unhandled "error" event that would crash the process.
    sharedRedisClient.on("error", (error: unknown) => {
      logger.warn({ err: error, redisUrl }, "Rate limiter Redis connection emitted an error");
    });
  }
  return sharedRedisClient;
}
function createMemoryBackend(
windowMs: number,
maxRequests: number,
): RateLimiterBackend {
const store = new Map<string, RateLimitEntry>();
const cleanupInterval = setInterval(() => {
const now = Date.now();
for (const [key, entry] of store) {
@@ -38,45 +83,176 @@ export function createRateLimiter(windowMs: number, maxRequests: number): RateLi
}
}, windowMs);
// Allow garbage collection if the process holds no other references
if (cleanupInterval.unref) {
cleanupInterval.unref();
}
const check = function check(key: string): RateLimitResult {
const now = Date.now();
const existing = store.get(key);
return {
async check(key: string) {
const now = Date.now();
const existing = store.get(key);
// Window expired or first request — start fresh
if (!existing || existing.resetAt <= now) {
const resetAt = now + windowMs;
store.set(key, { count: 1, resetAt });
if (!existing || existing.resetAt <= now) {
const resetAt = now + windowMs;
store.set(key, { count: 1, resetAt });
return {
allowed: true,
remaining: maxRequests - 1,
resetAt: new Date(resetAt),
};
}
existing.count += 1;
return {
allowed: existing.count <= maxRequests,
remaining: Math.max(0, maxRequests - existing.count),
resetAt: new Date(existing.resetAt),
};
},
async reset() {
store.clear();
},
};
}
/**
 * Redis-backed fixed-window limiter shared across instances. A Lua script
 * increments the per-key counter and sets its expiry in one atomic step.
 * Failed checks are logged once per limiter/URL and rethrown so the caller
 * can fall back to the in-memory backend.
 */
function createRedisBackend(
  windowMs: number,
  maxRequests: number,
  options: Required<Pick<CreateRateLimiterOptions, "name" | "redisUrl" | "keyPrefix">>,
): RateLimiterBackend {
  // e.g. "capakraken:ratelimit:api" — the limiter name is sanitized so it is
  // safe to embed in a Redis key.
  const redisKeyPrefix = `${options.keyPrefix}:${sanitizeKeySegment(options.name)}`;
  // Identity for the once-per-pair unavailability warning.
  const warningKey = `${options.name}:${options.redisUrl}`;
  // Atomically INCR the key; if it has no TTL yet (PTTL < 0, i.e. a fresh
  // window), start the window. Returns [count, remaining window ms].
  async function runRedisCheck(key: string): Promise<RateLimitResult> {
    const client = getRedisClient(options.redisUrl);
    const redisKey = `${redisKeyPrefix}:${key}`;
    const result = await client.eval(
      `
      local current = redis.call("INCR", KEYS[1])
      local ttl = redis.call("PTTL", KEYS[1])
      if ttl < 0 then
        redis.call("PEXPIRE", KEYS[1], ARGV[1])
        ttl = tonumber(ARGV[1])
      end
      return {current, ttl}
      `,
      1,
      redisKey,
      String(windowMs),
    ) as [number | string, number | string];
    const count = Number(result[0]);
    // Clamp negative TTL replies so resetAt never lands in the past.
    const ttlMs = Math.max(0, Number(result[1]));
    return {
      allowed: count <= maxRequests,
      remaining: Math.max(0, maxRequests - count),
      resetAt: new Date(Date.now() + ttlMs),
    };
  }
  // Deletes every counter under this limiter's prefix via cursor-based SCAN
  // (avoids KEYS, which blocks Redis on large keyspaces).
  async function resetRedisKeys(): Promise<void> {
    const client = getRedisClient(options.redisUrl);
    const matchPattern = `${redisKeyPrefix}:*`;
    let cursor = "0";
    do {
      const [nextCursor, keys] = await client.scan(
        cursor,
        "MATCH",
        matchPattern,
        "COUNT",
        100,
      );
      cursor = nextCursor;
      if (keys.length > 0) {
        await client.del(...keys);
      }
    } while (cursor !== "0");
  }
  return {
    async check(key: string) {
      try {
        return await runRedisCheck(key);
      } catch (error) {
        // Warn once per limiter/URL, then rethrow so the wrapper in
        // createRateLimiter can serve the request from memory.
        if (!warnedRedisFailures.has(warningKey)) {
          warnedRedisFailures.add(warningKey);
          logger.warn(
            { err: error, redisUrl: options.redisUrl, limiter: options.name },
            "Rate limiter Redis backend unavailable, falling back to in-memory counters",
          );
        }
        throw error;
      }
    },
    async reset() {
      await resetRedisKeys();
    },
  };
}
/**
* Creates a rate limiter.
* Uses a Redis-backed shared counter when `REDIS_URL` is configured (or `backend: "redis"` is selected),
* and falls back to in-memory counters when Redis is unavailable or intentionally disabled.
*/
export function createRateLimiter(
windowMs: number,
maxRequests: number,
options: CreateRateLimiterOptions = {},
): RateLimiter {
const name = options.name ?? `window-${windowMs}-max-${maxRequests}`;
const memoryBackend = createMemoryBackend(windowMs, maxRequests);
const backendMode = getBackendMode(options.backend);
const redisUrl = options.redisUrl?.trim() || DEFAULT_REDIS_URL;
const keyPrefix = options.keyPrefix ?? DEFAULT_REDIS_KEY_PREFIX;
const shouldUseRedis = backendMode === "redis" || (backendMode === "auto" && Boolean(redisUrl));
const redisBackend = shouldUseRedis && redisUrl
? createRedisBackend(windowMs, maxRequests, { name, redisUrl, keyPrefix })
: null;
const check = (async (key: string) => {
const normalizedKey = key.trim().toLowerCase();
if (!normalizedKey) {
return {
allowed: true,
remaining: maxRequests - 1,
resetAt: new Date(resetAt),
remaining: maxRequests,
resetAt: new Date(Date.now() + windowMs),
};
}
// Within the current window
existing.count += 1;
const allowed = existing.count <= maxRequests;
return {
allowed,
remaining: Math.max(0, maxRequests - existing.count),
resetAt: new Date(existing.resetAt),
};
} as RateLimiter;
if (!redisBackend) {
return memoryBackend.check(normalizedKey);
}
check.reset = () => {
store.clear();
try {
return await redisBackend.check(normalizedKey);
} catch {
return memoryBackend.check(normalizedKey);
}
}) as RateLimiter;
check.reset = async () => {
await memoryBackend.reset();
if (redisBackend) {
try {
await redisBackend.reset();
} catch {
// Ignore Redis reset errors; tests and local fallback must remain usable.
}
}
};
return check;
}
/** General API rate limiter: 100 requests per 15 minutes per key */
export const apiRateLimiter = createRateLimiter(15 * 60 * 1000, 100);
/** General API rate limiter: 100 requests per 15 minutes per user. */
export const apiRateLimiter = createRateLimiter(15 * 60 * 1000, 100, {
name: "api",
});
/** Auth rate limiter: 5 attempts per 15 minutes per key */
export const authRateLimiter = createRateLimiter(15 * 60 * 1000, 5);
/** Auth rate limiter: 5 attempts per 15 minutes per login identifier. */
export const authRateLimiter = createRateLimiter(15 * 60 * 1000, 5, {
name: "auth",
});
+2 -2
View File
@@ -95,7 +95,7 @@ const isE2eTestMode = process.env["E2E_TEST_MODE"] === "true";
* Protected procedure — requires authenticated session AND a valid DB user record.
* This prevents stale sessions from accessing data after the DB user is deleted.
*/
export const protectedProcedure = t.procedure.use(withLogging).use(({ ctx, next }) => {
export const protectedProcedure = t.procedure.use(withLogging).use(async ({ ctx, next }) => {
if (!ctx.session?.user) {
throw new TRPCError({ code: "UNAUTHORIZED", message: "Authentication required" });
}
@@ -105,7 +105,7 @@ export const protectedProcedure = t.procedure.use(withLogging).use(({ ctx, next
// Rate limit by user ID
if (!isE2eTestMode) {
const rateLimitResult = apiRateLimiter(ctx.dbUser.id);
const rateLimitResult = await apiRateLimiter(ctx.dbUser.id);
if (!rateLimitResult.allowed) {
throw new TRPCError({
code: "TOO_MANY_REQUESTS",