import { Redis } from "ioredis";
import { logger } from "../lib/logger.js";

export interface RateLimitResult {
  allowed: boolean;
  remaining: number;
  resetAt: Date;
}

type RateLimitEntry = {
  count: number;
  resetAt: number;
};

type RateLimitBackendMode = "auto" | "memory" | "redis";

type CreateRateLimiterOptions = {
  /** Identifier used in Redis keys and log messages; defaults to one derived from window and limit. */
  name?: string;
  /** Backend selection; falls back to the `RATE_LIMIT_BACKEND` env var, then `"auto"`. */
  backend?: RateLimitBackendMode;
  /** Redis connection URL; falls back to the `REDIS_URL` env var. */
  redisUrl?: string;
  /** Prefix for every Redis key written by the limiter. */
  keyPrefix?: string;
};

export interface RateLimiter {
  (key: string): Promise<RateLimitResult>;
  reset(): Promise<void>;
}

type RateLimiterBackend = {
  check: (key: string) => Promise<RateLimitResult>;
  reset: () => Promise<void>;
};

type RedisBackendOptions = Required<Pick<CreateRateLimiterOptions, "name" | "redisUrl" | "keyPrefix">>;

/**
 * Atomically increments the counter and makes sure it carries an expiry.
 * Returns `{count, ttlMs}`. PTTL is negative when the key has no expiry yet
 * (first hit in a window), in which case the window length is applied.
 */
const RATE_LIMIT_SCRIPT = `
  local current = redis.call("INCR", KEYS[1])
  local ttl = redis.call("PTTL", KEYS[1])
  if ttl < 0 then
    redis.call("PEXPIRE", KEYS[1], ARGV[1])
    ttl = tonumber(ARGV[1])
  end
  return {current, ttl}
`;

const DEFAULT_REDIS_KEY_PREFIX = "capakraken:ratelimit";
const DEFAULT_BACKEND_MODE = parseBackendMode(process.env["RATE_LIMIT_BACKEND"]);
const DEFAULT_REDIS_URL = process.env["REDIS_URL"]?.trim();

/** `${limiterName}:${redisUrl}` pairs whose current Redis outage has already been logged. */
const warnedRedisFailures = new Set<string>();
/** One Redis client per distinct URL, shared by every limiter targeting that URL. */
const redisClients = new Map<string, Redis>();

function isBackendMode(value: unknown): value is RateLimitBackendMode {
  return value === "auto" || value === "memory" || value === "redis";
}

/** Parses a backend mode from configuration text (case- and whitespace-insensitive). */
function parseBackendMode(value: string | undefined): RateLimitBackendMode | undefined {
  const normalized = value?.trim().toLowerCase();
  return isBackendMode(normalized) ? normalized : undefined;
}

/**
 * Resolves the effective backend mode: a valid explicit request wins, then the
 * `RATE_LIMIT_BACKEND` env var, then `"auto"`.
 */
function getBackendMode(
  requestedBackend: RateLimitBackendMode | undefined,
): RateLimitBackendMode {
  // Runtime check: untyped callers may pass arbitrary values.
  if (isBackendMode(requestedBackend)) {
    return requestedBackend;
  }
  return DEFAULT_BACKEND_MODE ?? "auto";
}

/** Replaces every character outside `[a-zA-Z0-9:_-]` with `_` so the value is safe inside a Redis key. */
function sanitizeKeySegment(value: string): string {
  return value.replace(/[^a-zA-Z0-9:_-]/g, "_");
}

/** Escapes Redis glob metacharacters so the value matches literally in a SCAN `MATCH` pattern. */
function escapeRedisGlob(value: string): string {
  return value.replace(/[*?[\]\\]/g, "\\$&");
}

/**
 * Returns a loggable form of a Redis URL with any password masked.
 * Values that do not parse as URLs (e.g. socket paths) are logged verbatim,
 * unless they contain `@` and therefore might embed credentials.
 */
function redactRedisUrl(redisUrl: string): string {
  try {
    const parsed = new URL(redisUrl);
    if (parsed.password) {
      parsed.password = "***";
    }
    return parsed.toString();
  } catch {
    return redisUrl.includes("@") ? "[redacted]" : redisUrl;
  }
}

/**
 * Returns the shared client for `redisUrl`, creating it on first use.
 * Clients are cached per URL so limiters configured with different servers do not
 * silently share one connection.
 */
function getRedisClient(redisUrl: string): Redis {
  const cached = redisClients.get(redisUrl);
  if (cached) {
    return cached;
  }

  // Fail fast rather than queueing: a slow or absent Redis must not stall requests,
  // since callers fall back to in-memory counters on any error.
  // NOTE(review): with the offline queue disabled, commands sent before the connection
  // is established are rejected, so the first checks after startup may use the
  // in-memory fallback.
  const client = new Redis(redisUrl, {
    lazyConnect: false,
    enableReadyCheck: false,
    enableOfflineQueue: false,
    maxRetriesPerRequest: 1,
    commandTimeout: 1000,
  });
  const loggableUrl = redactRedisUrl(redisUrl);
  client.on("error", (error: unknown) => {
    logger.warn({ err: error, redisUrl: loggableUrl }, "Rate limiter Redis connection emitted an error");
  });
  redisClients.set(redisUrl, client);
  return client;
}

/**
 * Fixed-window counters held in a per-process Map.
 * Expired entries are swept every `windowMs`. The sweep timer is unref'd so it
 * does not keep the process alive.
 */
function createMemoryBackend(
  windowMs: number,
  maxRequests: number,
): RateLimiterBackend {
  const store = new Map<string, RateLimitEntry>();

  const cleanupInterval = setInterval(() => {
    const now = Date.now();
    for (const [key, entry] of store) {
      if (entry.resetAt <= now) {
        store.delete(key);
      }
    }
  }, windowMs);
  if (cleanupInterval.unref) {
    cleanupInterval.unref();
  }

  return {
    async check(key: string): Promise<RateLimitResult> {
      const now = Date.now();
      const existing = store.get(key);

      // Start a fresh window when there is no entry or the previous window has elapsed.
      if (!existing || existing.resetAt <= now) {
        const resetAt = now + windowMs;
        store.set(key, { count: 1, resetAt });
        return {
          allowed: true,
          remaining: maxRequests - 1,
          resetAt: new Date(resetAt),
        };
      }

      existing.count += 1;
      return {
        allowed: existing.count <= maxRequests,
        remaining: Math.max(0, maxRequests - existing.count),
        resetAt: new Date(existing.resetAt),
      };
    },
    async reset(): Promise<void> {
      store.clear();
    },
  };
}

/**
 * Fixed-window counters shared through Redis under `${keyPrefix}:${name}:${key}`.
 * `check` rethrows Redis errors so the caller can fall back to memory. The first
 * failure of each outage is logged, and the warning re-arms after a successful check.
 */
function createRedisBackend(
  windowMs: number,
  maxRequests: number,
  options: RedisBackendOptions,
): RateLimiterBackend {
  const redisKeyPrefix = `${options.keyPrefix}:${sanitizeKeySegment(options.name)}`;
  const warningKey = `${options.name}:${options.redisUrl}`;
  const loggableUrl = redactRedisUrl(options.redisUrl);

  async function runRedisCheck(key: string): Promise<RateLimitResult> {
    const client = getRedisClient(options.redisUrl);
    const redisKey = `${redisKeyPrefix}:${key}`;
    const reply: unknown = await client.eval(RATE_LIMIT_SCRIPT, 1, redisKey, String(windowMs));

    if (!Array.isArray(reply) || reply.length < 2) {
      throw new Error("Rate limiter Redis script returned an unexpected reply");
    }
    const count = Number(reply[0]);
    const ttl = Number(reply[1]);
    if (!Number.isFinite(count) || !Number.isFinite(ttl)) {
      throw new Error("Rate limiter Redis script returned non-numeric values");
    }
    const ttlMs = Math.max(0, ttl);

    return {
      allowed: count <= maxRequests,
      remaining: Math.max(0, maxRequests - count),
      resetAt: new Date(Date.now() + ttlMs),
    };
  }

  async function resetRedisKeys(): Promise<void> {
    const client = getRedisClient(options.redisUrl);
    // Escape the prefix so a user-supplied keyPrefix containing glob characters
    // cannot match keys belonging to other limiters.
    const matchPattern = `${escapeRedisGlob(redisKeyPrefix)}:*`;
    let cursor = "0";
    do {
      const [nextCursor, keys] = await client.scan(cursor, "MATCH", matchPattern, "COUNT", 100);
      cursor = nextCursor;
      if (keys.length > 0) {
        await client.del(...keys);
      }
    } while (cursor !== "0");
  }

  return {
    async check(key: string): Promise<RateLimitResult> {
      try {
        const result = await runRedisCheck(key);
        // Redis is healthy again: re-arm the warning so the next outage is logged.
        warnedRedisFailures.delete(warningKey);
        return result;
      } catch (error) {
        if (!warnedRedisFailures.has(warningKey)) {
          warnedRedisFailures.add(warningKey);
          logger.warn(
            { err: error, redisUrl: loggableUrl, limiter: options.name },
            "Rate limiter Redis backend unavailable, falling back to in-memory counters",
          );
        }
        throw error;
      }
    },
    async reset(): Promise<void> {
      await resetRedisKeys();
    },
  };
}

/**
 * Creates a fixed-window rate limiter.
 *
 * Redis is used as a shared counter when a Redis URL is available and the backend
 * mode is `"redis"` or `"auto"`. When Redis errors, or when no URL is configured,
 * per-process in-memory counters are used instead.
 *
 * Keys are trimmed and lower-cased before counting. An empty (or whitespace-only)
 * key is always allowed and is not counted.
 *
 * @param windowMs - Window length in milliseconds.
 * @param maxRequests - Requests allowed per key per window.
 * @param options - Name, backend mode, Redis URL and key prefix overrides.
 * @returns A callable limiter with a `reset()` that clears the memory store and
 *   makes a best-effort attempt to clear the Redis keys.
 */
export function createRateLimiter(
  windowMs: number,
  maxRequests: number,
  options: CreateRateLimiterOptions = {},
): RateLimiter {
  const name = options.name ?? `window-${windowMs}-max-${maxRequests}`;
  const memoryBackend = createMemoryBackend(windowMs, maxRequests);
  const backendMode = getBackendMode(options.backend);
  const redisUrl = options.redisUrl?.trim() || DEFAULT_REDIS_URL;
  const keyPrefix = options.keyPrefix ?? DEFAULT_REDIS_KEY_PREFIX;

  if (backendMode === "redis" && !redisUrl) {
    logger.warn(
      { limiter: name },
      "Rate limiter Redis backend requested but no Redis URL is configured; using in-memory counters",
    );
  }

  // Redis is used for "redis" and "auto" modes whenever a non-empty URL is available.
  const redisBackend =
    backendMode !== "memory" && redisUrl
      ? createRedisBackend(windowMs, maxRequests, { name, redisUrl, keyPrefix })
      : null;

  const limiter = async (key: string): Promise<RateLimitResult> => {
    const normalizedKey = key.trim().toLowerCase();
    if (!normalizedKey) {
      return {
        allowed: true,
        remaining: maxRequests,
        resetAt: new Date(Date.now() + windowMs),
      };
    }

    if (!redisBackend) {
      return memoryBackend.check(normalizedKey);
    }

    try {
      return await redisBackend.check(normalizedKey);
    } catch {
      // Already logged by the Redis backend; degrade to per-process counters.
      return memoryBackend.check(normalizedKey);
    }
  };

  const reset = async (): Promise<void> => {
    await memoryBackend.reset();
    if (redisBackend) {
      try {
        await redisBackend.reset();
      } catch {
        // Ignore Redis reset errors; tests and local fallback must remain usable.
      }
    }
  };

  return Object.assign(limiter, { reset });
}

/** General API rate limiter: 100 requests per 15 minutes per user. */
export const apiRateLimiter = createRateLimiter(15 * 60 * 1000, 100, {
  name: "api",
});

/** Auth rate limiter: 5 attempts per 15 minutes per login identifier. */
export const authRateLimiter = createRateLimiter(15 * 60 * 1000, 5, {
  name: "auth",
});