diff --git a/src/config/env.test.ts b/src/config/env.test.ts index bd85745..0baa714 100644 --- a/src/config/env.test.ts +++ b/src/config/env.test.ts @@ -8,9 +8,7 @@ const baseEnv = { METRICS_API_KEY: "test-metrics-key", }; -describe("env schema — BCRYPT_COST_FACTOR", () => { - // ── Unit Tests (Task 1.4) ────────────────────────────────────────────────── - +describe("env schema - BCRYPT_COST_FACTOR", () => { describe("unit tests", () => { it("defaults to 12 when BCRYPT_COST_FACTOR is omitted", () => { const result = envSchema.safeParse({ ...baseEnv }); @@ -67,10 +65,6 @@ describe("env schema — BCRYPT_COST_FACTOR", () => { }); }); - // ── Property-Based Tests ─────────────────────────────────────────────────── - - // Feature: bcrypt-cost-config, Property 1: valid cost factor parses to the correct integer - // Validates: Requirements 1.1, 2.1, 4.1 it("Property 1: valid cost factor parses to the correct integer", () => { fc.assert( fc.property(fc.integer({ min: 10, max: 31 }), (n) => { @@ -84,8 +78,6 @@ describe("env schema — BCRYPT_COST_FACTOR", () => { ); }); - // Feature: bcrypt-cost-config, Property 2: out-of-range values are rejected - // Validates: Requirements 1.2, 1.3, 5.1, 5.2 it("Property 2: out-of-range values are rejected", () => { fc.assert( fc.property( @@ -102,8 +94,6 @@ describe("env schema — BCRYPT_COST_FACTOR", () => { ); }); - // Feature: bcrypt-cost-config, Property 3: non-numeric strings are rejected - // Validates: Requirements 1.4, 5.3 it("Property 3: non-numeric strings are rejected", () => { fc.assert( fc.property( diff --git a/src/config/index.ts b/src/config/index.ts index bff79a2..142c2aa 100644 --- a/src/config/index.ts +++ b/src/config/index.ts @@ -133,6 +133,13 @@ export const config = { maxRequests: env.REST_RATE_LIMIT_MAX_REQUESTS, }, + rateLimiter: { + maxRequests: env.RATE_LIMIT_MAX_REQUESTS, + windowMs: env.RATE_LIMIT_WINDOW_MS, + store: env.RATE_LIMIT_STORE, + postgresTable: env.RATE_LIMIT_PG_TABLE, + }, + sorobanRpc: env.SOROBAN_RPC_ENABLED && env.SOROBAN_RPC_URL ? { diff --git a/src/index.ts b/src/index.ts index 08309f1..edf4199 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,7 +2,7 @@ import './config/env.js' import express from 'express'; import helmet from 'helmet'; import { initializeDb, closeDb } from './db/index.js'; -import { closePgPool } from './db.js'; +import { closePgPool, pool } from './db.js'; import { closeDbPool } from './config/health.js'; import { disconnectPrisma } from './lib/prisma.js'; import { errorHandler } from './middleware/errorHandler.js'; diff --git a/src/routes/gatewayRoutes.test.ts b/src/routes/gatewayRoutes.test.ts index 31eb472..155332b 100644 --- a/src/routes/gatewayRoutes.test.ts +++ b/src/routes/gatewayRoutes.test.ts @@ -6,13 +6,15 @@ import { errorHandler } from "../middleware/errorHandler.js"; import { requestIdMiddleware } from "../middleware/requestId.js"; describe("gateway route - rate limiting", () => { + let now = 0; + beforeEach(() => { - jest.useFakeTimers("modern" as unknown as any); - jest.setSystemTime(new Date("2026-03-30T00:00:00.000Z").getTime()); + now = new Date("2026-03-30T00:00:00.000Z").getTime(); + jest.spyOn(Date, "now").mockImplementation(() => now); }); afterEach(() => { - jest.useRealTimers(); + jest.restoreAllMocks(); }); test("returns 429 with Retry-After when rate limited", async () => { diff --git a/src/routes/gatewayRoutes.ts b/src/routes/gatewayRoutes.ts index ff2fc31..663e217 100644 --- a/src/routes/gatewayRoutes.ts +++ b/src/routes/gatewayRoutes.ts @@ -58,7 +58,7 @@ export function createGatewayRouter(deps: GatewayDeps): Router { return; } - const rateResult = rateLimiter.check(apiKeyHeader); + const rateResult = await rateLimiter.check(apiKeyHeader); if (!rateResult.allowed) { const retryAfterSec = Math.ceil((rateResult.retryAfterMs ?? 1000) / 1000); res.set('Retry-After', String(retryAfterSec)); diff --git a/src/routes/index.ts b/src/routes/index.ts index 0be483b..5aaa44b 100644 --- a/src/routes/index.ts +++ b/src/routes/index.ts @@ -1,4 +1,4 @@ -import { Router } from 'express'; +import { Router, type RequestHandler, type Router as ExpressRouter } from 'express'; import healthRouter from './health.js'; import usageRouter from './usage.js'; import billingRouter from './billing.js'; diff --git a/src/routes/proxyRoutes.ts b/src/routes/proxyRoutes.ts index b1c780f..6351e28 100644 --- a/src/routes/proxyRoutes.ts +++ b/src/routes/proxyRoutes.ts @@ -117,7 +117,7 @@ export function createProxyRouter(deps: ProxyDeps): Router { } // 3. Rate-limit check - const rateResult = rateLimiter.check(apiKeyHeader); + const rateResult = await rateLimiter.check(apiKeyHeader); if (!rateResult.allowed) { const retryAfterSec = Math.ceil((rateResult.retryAfterMs ?? 1000) / 1000); res.set('Retry-After', String(retryAfterSec)); diff --git a/src/services/rateLimiter.test.ts b/src/services/rateLimiter.test.ts index a8d74b7..f95a447 100644 --- a/src/services/rateLimiter.test.ts +++ b/src/services/rateLimiter.test.ts @@ -1,42 +1,601 @@ -import { createRateLimiter } from "./rateLimiter.js"; +import assert from 'node:assert/strict'; + +import { + createConfiguredRateLimiter, + createRateLimiter, + InMemoryRateLimiter, + InMemoryRateLimiterStore, + isPersistentRateLimiterStore, + type PersistentRateLimiterPool, + PostgresRateLimiterStore, +} from './rateLimiter.js'; + +type StoredBucket = { + bucket_key: string; + last_refill_ms: number; + tokens: number; + updated_at: number; +}; + +class MockPersistentPool implements PersistentRateLimiterPool { + readonly createdTables = new Set(); + readonly rows = new Map(); + + private readonly lockOwners = new Map(); + private readonly lockQueues = new Map void>>(); + private nextClientId = 1; + + async connect() { + const clientId = this.nextClientId++; + let lockedKey: string | null = null; + + return { + query: async (text: string, params: unknown[] = []) => { + const normalized = text.replace(/\s+/g, ' ').trim(); + + if (normalized.startsWith('BEGIN')) { + return { rows: [] as T[] }; + } + + if (normalized.startsWith('COMMIT') || normalized.startsWith('ROLLBACK')) { + this.releaseLock(clientId, lockedKey); + lockedKey = null; + return { rows: [] as T[] }; + } + + if (normalized.startsWith('CREATE TABLE IF NOT EXISTS')) { + const match = normalized.match(/CREATE TABLE IF NOT EXISTS ([a-z0-9_]+)/i); + if (match?.[1]) { + this.createdTables.add(match[1]); + } + return { rows: [] as T[] }; + } + + if (normalized.startsWith('CREATE INDEX IF NOT EXISTS')) { + return { rows: [] as T[] }; + } + + if (normalized.startsWith('INSERT INTO')) { + const bucketKey = String(params[0]); + if (!this.rows.has(bucketKey)) { + this.rows.set(bucketKey, { + bucket_key: bucketKey, + last_refill_ms: Number(params[2]), + tokens: Number(params[1]), + updated_at: Date.now(), + }); + } + return { rows: [] as T[] }; + } + + if (normalized.includes('FOR UPDATE')) { + const bucketKey = String(params[0]); + await this.acquireLock(bucketKey, clientId); + lockedKey = bucketKey; + + const row = this.rows.get(bucketKey); + return { + rows: row ? [row as unknown as T] : [], + }; + } + + if (normalized.startsWith('UPDATE')) { + const bucketKey = String(params[0]); + const row = this.rows.get(bucketKey); + + if (!row) { + throw new Error(`Missing row for ${bucketKey}`); + } + + row.tokens = Number(params[1]); + row.last_refill_ms = Number(params[2]); + row.updated_at = Date.now(); + return { rows: [] as T[] }; + } + + if (normalized.startsWith('SELECT bucket_key, tokens FROM')) { + const bucketKey = String(params[0]); + const row = this.rows.get(bucketKey); + + return { + rows: row ? ([{ bucket_key: row.bucket_key, tokens: row.tokens }] as T[]) : [], + }; + } + + throw new Error(`Unhandled SQL in mock pool: ${normalized}`); + }, + release: () => { + this.releaseLock(clientId, lockedKey); + lockedKey = null; + }, + }; + } + + private async acquireLock(bucketKey: string, clientId: number): Promise { + while (true) { + const owner = this.lockOwners.get(bucketKey); + + if (owner === undefined || owner === clientId) { + this.lockOwners.set(bucketKey, clientId); + return; + } + + await new Promise((resolve) => { + const queue = this.lockQueues.get(bucketKey) ?? []; + queue.push(resolve); + this.lockQueues.set(bucketKey, queue); + }); + } + } + + private releaseLock(clientId: number, bucketKey: string | null): void { + if (!bucketKey) { + return; + } + + if (this.lockOwners.get(bucketKey) !== clientId) { + return; + } + + this.lockOwners.delete(bucketKey); + const queue = this.lockQueues.get(bucketKey); + const next = queue?.shift(); + + if (!queue || queue.length === 0) { + this.lockQueues.delete(bucketKey); + } + + next?.(); + } +} + +describe('InMemoryRateLimiter', () => { + let now = 0; + + beforeEach(() => { + now = new Date('2026-03-30T00:00:00.000Z').getTime(); + jest.spyOn(Date, 'now').mockImplementation(() => now); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + test('allows up to maxRequests then rejects until the window elapses', async () => { + const rl = createRateLimiter(2, 1000); + const apiKey = 'test-key'; + + assert.deepEqual(await rl.check(apiKey), { allowed: true }); + assert.deepEqual(await rl.check(apiKey), { allowed: true }); + + const rejected = await rl.check(apiKey); + assert.equal(rejected.allowed, false); + assert.equal(rejected.retryAfterMs, 1000); + + now += 500; + const retrying = await rl.check(apiKey); + assert.equal(retrying.allowed, false); + assert.equal(retrying.retryAfterMs, 500); + + now += 500; + assert.deepEqual(await rl.check(apiKey), { allowed: true }); + }); + + test('tracks buckets independently per API key', async () => { + const rl = createRateLimiter(1, 1_000); + + assert.deepEqual(await rl.check('first-key'), { allowed: true }); + assert.deepEqual(await rl.check('second-key'), { allowed: true }); + + const firstRejected = await rl.check('first-key'); + assert.equal(firstRejected.allowed, false); + assert.equal(firstRejected.retryAfterMs, 1000); + }); + + test('exhaust helper forces the next request to be rejected', async () => { + const rl = createRateLimiter(3, 5_000); + + rl.exhaust('force-blocked'); + + const result = await rl.check('force-blocked'); + assert.equal(result.allowed, false); + assert.equal(result.retryAfterMs, 5000); + }); + + test('reset helper clears all in-memory buckets', async () => { + const rl = createRateLimiter(1, 1_000); + + await rl.check('reset-me'); + rl.reset(); + + assert.deepEqual(await rl.check('reset-me'), { allowed: true }); + }); + + test('rejects invalid limiter dimensions', () => { + assert.throws(() => createRateLimiter(0, 1_000), /maxRequests must be a positive integer/); + assert.throws(() => createRateLimiter(1, 0), /windowMs must be a positive integer/); + }); +}); + +describe('InMemoryRateLimiterStore', () => { + test('persists the updated bucket state between checks', async () => { + const store = new InMemoryRateLimiterStore(); + const options = { maxRequests: 2, now: 10, windowMs: 1000 }; + + assert.deepEqual(await store.check('bucket', options), { allowed: true }); + assert.deepEqual(await store.check('bucket', options), { allowed: true }); + + const blocked = await store.check('bucket', options); + assert.equal(blocked.allowed, false); + assert.equal(blocked.retryAfterMs, 1000); + }); +}); + +describe('PostgresRateLimiterStore', () => { + let now = 0; -describe("InMemoryRateLimiter", () => { beforeEach(() => { - // Use modern fake timers so Date.now() and time advances are deterministic - // cast to any to satisfy TS defs in this project - jest.useFakeTimers("modern" as unknown as any); - jest.setSystemTime(new Date("2026-03-30T00:00:00.000Z").getTime()); + now = new Date('2026-03-30T00:00:00.000Z').getTime(); + jest.spyOn(Date, 'now').mockImplementation(() => now); }); afterEach(() => { - jest.useRealTimers(); + jest.restoreAllMocks(); }); - test("allows up to maxRequests then rejects until window elapses", () => { - const maxRequests = 2; - const windowMs = 1000; - const rl = createRateLimiter(maxRequests, windowMs); - const apiKey = "test-key"; + test('shares buckets across limiter instances backed by the same pool', async () => { + const pool = new MockPersistentPool(); + const limiterA = createConfiguredRateLimiter( + { maxRequests: 2, store: 'postgres', windowMs: 1000 }, + pool, + ); + const limiterB = createConfiguredRateLimiter( + { maxRequests: 2, store: 'postgres', windowMs: 1000 }, + pool, + ); + + assert.deepEqual(await limiterA.check('shared-key'), { allowed: true }); + assert.deepEqual(await limiterB.check('shared-key'), { allowed: true }); + + const blocked = await limiterA.check('shared-key'); + assert.equal(blocked.allowed, false); + assert.equal(blocked.retryAfterMs, 1000); + + assert.deepEqual(pool.rows.get('shared-key'), { + bucket_key: 'shared-key', + last_refill_ms: now, + tokens: 0, + updated_at: now, + }); + }); - const r1 = rl.check(apiKey); - expect(r1.allowed).toBe(true); + test('refills exhausted buckets after the window elapses', async () => { + const pool = new MockPersistentPool(); + const limiter = createConfiguredRateLimiter( + { maxRequests: 1, store: 'postgres', windowMs: 1_000 }, + pool, + ); - const r2 = rl.check(apiKey); - expect(r2.allowed).toBe(true); + assert.deepEqual(await limiter.check('refill-key'), { allowed: true }); + + const blocked = await limiter.check('refill-key'); + assert.equal(blocked.allowed, false); + assert.equal(blocked.retryAfterMs, 1_000); + + now += 1_000; + assert.deepEqual(await limiter.check('refill-key'), { allowed: true }); + }); + + test('serializes concurrent checks so only maxRequests calls succeed', async () => { + const pool = new MockPersistentPool(); + const limiter = createConfiguredRateLimiter( + { maxRequests: 2, store: 'postgres', windowMs: 60_000 }, + pool, + ); + + const results = await Promise.all( + Array.from({ length: 5 }, () => limiter.check('concurrent-key')), + ); + + assert.equal(results.filter((result) => result.allowed).length, 2); + assert.equal(results.filter((result) => !result.allowed).length, 3); + for (const rejected of results.filter((result) => !result.allowed)) { + assert.equal(rejected.retryAfterMs, 60_000); + } + }); - // third request should be rejected - const r3 = rl.check(apiKey); - expect(r3.allowed).toBe(false); - expect(r3.retryAfterMs).toBe(windowMs); + test('creates the backing table lazily on first use', async () => { + const pool = new MockPersistentPool(); + const store = new PostgresRateLimiterStore(pool, { + tableName: 'custom_rate_limit_table', + }); + + assert.deepEqual( + await store.check('lazy-table-key', { + maxRequests: 1, + now: Date.now(), + windowMs: 1_000, + }), + { allowed: true }, + ); + + assert.ok(pool.createdTables.has('custom_rate_limit_table')); + }); + + test('rejects unsafe table names before querying the database', () => { + const pool = new MockPersistentPool(); + + assert.throws( + () => new PostgresRateLimiterStore(pool, { tableName: 'rate-limits;DROP TABLE users' }), + /letters, numbers, and underscores/, + ); + }); + + test('rejects malformed persisted string values', async () => { + const pool: PersistentRateLimiterPool = { + async connect() { + return { + async query(text: string) { + if (text.includes('CREATE TABLE') || text.includes('CREATE INDEX')) { + return { rows: [] as T[] }; + } + if (text === 'BEGIN' || text === 'COMMIT' || text === 'ROLLBACK') { + return { rows: [] as T[] }; + } + if (text.includes('INSERT INTO')) { + return { rows: [] as T[] }; + } + if (text.includes('FOR UPDATE')) { + return { + rows: [ + { + bucket_key: 'bad-string', + tokens: '1.5', + last_refill_ms: '123', + }, + ] as T[], + }; + } + return { rows: [] as T[] }; + }, + release() {}, + }; + }, + }; + + const store = new PostgresRateLimiterStore(pool); + + await assert.rejects( + store.check('bad-string', { maxRequests: 2, now, windowMs: 1_000 }), + /tokens must be stored as a non-negative integer/, + ); + }); + + test('rejects malformed persisted numeric values', async () => { + const pool: PersistentRateLimiterPool = { + async connect() { + return { + async query(text: string) { + if (text.includes('CREATE TABLE') || text.includes('CREATE INDEX')) { + return { rows: [] as T[] }; + } + if (text === 'BEGIN' || text === 'COMMIT' || text === 'ROLLBACK') { + return { rows: [] as T[] }; + } + if (text.includes('INSERT INTO')) { + return { rows: [] as T[] }; + } + if (text.includes('FOR UPDATE')) { + return { + rows: [ + { + bucket_key: 'bad-number', + tokens: -1, + last_refill_ms: 123, + }, + ] as T[], + }; + } + return { rows: [] as T[] }; + }, + release() {}, + }; + }, + }; + + const store = new PostgresRateLimiterStore(pool); + + await assert.rejects( + store.check('bad-number', { maxRequests: 2, now, windowMs: 1_000 }), + /tokens must be stored as a non-negative integer/, + ); + }); - // advance time by less than window -> still rejected - jest.advanceTimersByTime(500); - const r4 = rl.check(apiKey); - expect(r4.allowed).toBe(false); + test('accepts persisted integer strings from the backing store', async () => { + const pool: PersistentRateLimiterPool = { + async connect() { + return { + async query(text: string) { + if (text.includes('CREATE TABLE') || text.includes('CREATE INDEX')) { + return { rows: [] as T[] }; + } + if (text === 'BEGIN' || text === 'COMMIT' || text === 'ROLLBACK') { + return { rows: [] as T[] }; + } + if (text.includes('INSERT INTO') || text.trim().startsWith('UPDATE')) { + return { rows: [] as T[] }; + } + if (text.includes('FOR UPDATE')) { + return { + rows: [ + { + bucket_key: 'string-values', + tokens: '2', + last_refill_ms: String(now), + }, + ] as T[], + }; + } + return { rows: [] as T[] }; + }, + release() {}, + }; + }, + }; + + const store = new PostgresRateLimiterStore(pool); + + assert.deepEqual( + await store.check('string-values', { maxRequests: 2, now, windowMs: 1_000 }), + { allowed: true }, + ); + }); + + test('rolls back the transaction when a database write fails', async () => { + let rollbackCount = 0; + const pool: PersistentRateLimiterPool = { + async connect() { + return { + async query(text: string) { + if (text.includes('CREATE TABLE') || text.includes('CREATE INDEX')) { + return { rows: [] as T[] }; + } + if (text === 'BEGIN') { + return { rows: [] as T[] }; + } + if (text === 'ROLLBACK') { + rollbackCount += 1; + return { rows: [] as T[] }; + } + if (text.includes('INSERT INTO')) { + return { rows: [] as T[] }; + } + if (text.includes('FOR UPDATE')) { + return { + rows: [ + { + bucket_key: 'rollback-key', + tokens: 2, + last_refill_ms: now, + }, + ] as T[], + }; + } + if (text.includes('UPDATE')) { + throw new Error('update failed'); + } + return { rows: [] as T[] }; + }, + release() {}, + }; + }, + }; + + const store = new PostgresRateLimiterStore(pool); + + await assert.rejects( + store.check('rollback-key', { maxRequests: 2, now, windowMs: 1_000 }), + /update failed/, + ); + assert.equal(rollbackCount, 1); + }); + + test('throws if a bucket still cannot be read after initialization', async () => { + const pool: PersistentRateLimiterPool = { + async connect() { + return { + async query(text: string) { + if (text.includes('CREATE TABLE') || text.includes('CREATE INDEX')) { + return { rows: [] as T[] }; + } + return { rows: [] as T[] }; + }, + release() {}, + }; + }, + }; + + const store = new PostgresRateLimiterStore(pool); + + await assert.rejects( + store.check('missing-row', { maxRequests: 2, now, windowMs: 1_000 }), + /was not found after initialization/, + ); + }); + + test('retries table creation after an initialization failure', async () => { + let failCreate = true; + const pool = new MockPersistentPool(); + const originalConnect = pool.connect.bind(pool); + + pool.connect = async () => { + const client = await originalConnect(); + return { + async query(text: string, params?: unknown[]) { + if (text.includes('CREATE TABLE') && failCreate) { + failCreate = false; + throw new Error('create failed'); + } + return client.query(text, params); + }, + release: client.release, + }; + }; + + const store = new PostgresRateLimiterStore(pool); + + await assert.rejects( + store.check('retry-table', { maxRequests: 1, now, windowMs: 1_000 }), + /create failed/, + ); + assert.deepEqual( + await store.check('retry-table', { maxRequests: 1, now, windowMs: 1_000 }), + { allowed: true }, + ); + }); +}); + +describe('createConfiguredRateLimiter', () => { + test('uses default settings when no explicit limiter options are provided', async () => { + const limiter = createConfiguredRateLimiter({}); + + assert.ok(limiter instanceof InMemoryRateLimiter); + assert.deepEqual(await limiter.check('default-key'), { allowed: true }); + }); + + test('falls back to the in-memory limiter when persistence is not configured', () => { + const limiter = createConfiguredRateLimiter({ + maxRequests: 4, + store: 'memory', + windowMs: 2_000, + }); + + assert.ok(limiter instanceof InMemoryRateLimiter); + }); + + test('requires a pool when the postgres store is selected', () => { + assert.throws( + () => + createConfiguredRateLimiter({ + maxRequests: 4, + store: 'postgres', + windowMs: 2_000, + }), + /PostgreSQL pool is required/, + ); + }); + + test('creates an in-memory limiter with default constructor values', async () => { + const limiter = createRateLimiter(); + + assert.deepEqual(await limiter.check('constructor-defaults'), { allowed: true }); + }); - // advance to end of window -> tokens should be refilled - jest.advanceTimersByTime(500); - const r5 = rl.check(apiKey); - expect(r5.allowed).toBe(true); + test('identifies persistent stores for callers that need special handling', () => { + assert.equal(isPersistentRateLimiterStore(new InMemoryRateLimiterStore()), false); + assert.equal( + isPersistentRateLimiterStore(new PostgresRateLimiterStore(new MockPersistentPool())), + true, + ); }); }); diff --git a/src/services/rateLimiter.ts b/src/services/rateLimiter.ts index 868348f..66d6ff6 100644 --- a/src/services/rateLimiter.ts +++ b/src/services/rateLimiter.ts @@ -1,63 +1,387 @@ -import { RateLimiter, RateLimitResult } from '../types/gateway.js'; +import type { PoolClient } from 'pg'; +import type { RateLimiter, RateLimitResult } from '../types/gateway.js'; interface TokenBucket { tokens: number; lastRefill: number; } -/** - * Simple token-bucket rate limiter. - * Each API key gets `maxRequests` tokens per `windowMs` window. - */ -export class InMemoryRateLimiter implements RateLimiter { - private buckets = new Map(); - private maxRequests: number; - private windowMs: number; +export interface RateLimiterStoreCheckOptions { + maxRequests: number; + now: number; + windowMs: number; +} - constructor(maxRequests: number, windowMs: number) { - this.maxRequests = maxRequests; - this.windowMs = windowMs; +export interface RateLimiterStore { + check(bucketKey: string, options: RateLimiterStoreCheckOptions): Promise; +} + +export interface PersistentRateLimiterClient { + query(text: string, params?: unknown[]): Promise<{ rows: T[] }>; + release(): void; +} + +export interface PersistentRateLimiterPool { + connect(): Promise; +} + +export interface PersistentRateLimiterStoreOptions { + tableName?: string; +} + +export interface ConfiguredRateLimiterOptions { + maxRequests?: number; + windowMs?: number; +} + +export interface PersistentRateLimiterConfig extends ConfiguredRateLimiterOptions { + store: 'postgres'; + tableName?: string; +} + +export interface InMemoryRateLimiterConfig extends ConfiguredRateLimiterOptions { + store?: 'memory'; +} + +export type RateLimiterConfig = + | InMemoryRateLimiterConfig + | PersistentRateLimiterConfig; + +const DEFAULT_MAX_REQUESTS = 100; +const DEFAULT_WINDOW_MS = 60_000; +const DEFAULT_PERSISTENT_TABLE = 'gateway_rate_limit_buckets'; +const TABLE_NAME_PATTERN = /^[a-z_][a-z0-9_]*$/i; + +type TokenBucketRow = { + bucket_key: string; + tokens: number | string; + last_refill_ms: number | string; +}; + +function normalizePositiveInteger(value: number, label: string): number { + if (!Number.isInteger(value) || value <= 0) { + throw new Error(`${label} must be a positive integer.`); } - check(apiKey: string): RateLimitResult { - const now = Date.now(); - let bucket = this.buckets.get(apiKey); + return value; +} - if (!bucket) { - bucket = { tokens: this.maxRequests, lastRefill: now }; - this.buckets.set(apiKey, bucket); +function normalizeTokenBucketValue( + value: number | string, + label: string, +): number { + if (typeof value === 'string') { + if (!/^\d+$/.test(value.trim())) { + throw new Error(`${label} must be stored as a non-negative integer.`); } + } + + const parsed = typeof value === 'number' ? value : Number.parseInt(value, 10); + + if (!Number.isFinite(parsed) || !Number.isInteger(parsed) || parsed < 0) { + throw new Error(`${label} must be stored as a non-negative integer.`); + } + + return parsed; +} + +function computeRateLimitResult( + bucket: TokenBucket | undefined, + maxRequests: number, + windowMs: number, + now: number, +): { bucket: TokenBucket; result: RateLimitResult } { + const currentBucket = bucket + ? { ...bucket } + : { tokens: maxRequests, lastRefill: now }; + + const elapsed = now - currentBucket.lastRefill; + + if (elapsed >= windowMs) { + currentBucket.tokens = maxRequests; + currentBucket.lastRefill = now; + } + + if (currentBucket.tokens <= 0) { + const retryAfterMs = Math.max( + windowMs - (now - currentBucket.lastRefill), + 0, + ); + + return { + bucket: currentBucket, + result: { allowed: false, retryAfterMs }, + }; + } + + currentBucket.tokens -= 1; + + return { + bucket: currentBucket, + result: { allowed: true }, + }; +} + +function buildLimiterOptions( + maxRequests = DEFAULT_MAX_REQUESTS, + windowMs = DEFAULT_WINDOW_MS, +): RateLimiterStoreCheckOptions { + return { + maxRequests: normalizePositiveInteger(maxRequests, 'maxRequests'), + now: 0, + windowMs: normalizePositiveInteger(windowMs, 'windowMs'), + }; +} + +function assertSafeTableName(tableName: string): string { + if (!TABLE_NAME_PATTERN.test(tableName)) { + throw new Error( + 'Rate limiter tableName must contain only letters, numbers, and underscores.', + ); + } + + return tableName; +} + +async function rollbackQuietly(client: PersistentRateLimiterClient): Promise { + try { + await client.query('ROLLBACK'); + } catch { + // Ignore rollback errors so we surface the original failure. + } +} + +export class InMemoryRateLimiterStore implements RateLimiterStore { + private readonly buckets = new Map(); + + async check( + bucketKey: string, + options: RateLimiterStoreCheckOptions, + ): Promise { + const existingBucket = this.buckets.get(bucketKey); + const { bucket, result } = computeRateLimitResult( + existingBucket, + options.maxRequests, + options.windowMs, + options.now, + ); + + this.buckets.set(bucketKey, bucket); + return result; + } - // Refill tokens if the window has elapsed - const elapsed = now - bucket.lastRefill; - if (elapsed >= this.windowMs) { - bucket.tokens = this.maxRequests; - bucket.lastRefill = now; + exhaust(bucketKey: string): void { + this.buckets.set(bucketKey, { tokens: 0, lastRefill: Date.now() }); + } + + reset(): void { + this.buckets.clear(); + } +} + +export class PostgresRateLimiterStore implements RateLimiterStore { + private readonly pool: PersistentRateLimiterPool; + private readonly tableName: string; + private tableReadyPromise: Promise | null = null; + + constructor( + pool: PersistentRateLimiterPool, + options: PersistentRateLimiterStoreOptions = {}, + ) { + this.pool = pool; + this.tableName = assertSafeTableName( + options.tableName ?? DEFAULT_PERSISTENT_TABLE, + ); + } + + async check( + bucketKey: string, + options: RateLimiterStoreCheckOptions, + ): Promise { + await this.ensureTable(); + + const client = await this.pool.connect(); + + try { + await client.query('BEGIN'); + + await client.query( + `INSERT INTO ${this.tableName} ( + bucket_key, + tokens, + last_refill_ms + ) VALUES ($1, $2, $3) + ON CONFLICT (bucket_key) DO NOTHING`, + [bucketKey, options.maxRequests, options.now], + ); + + const existingRow = await client.query( + `SELECT bucket_key, tokens, last_refill_ms + FROM ${this.tableName} + WHERE bucket_key = $1 + FOR UPDATE`, + [bucketKey], + ); + + if (!existingRow.rows[0]) { + throw new Error(`Rate limiter bucket "${bucketKey}" was not found after initialization.`); + } + + const bucket = { + tokens: normalizeTokenBucketValue(existingRow.rows[0].tokens, 'tokens'), + lastRefill: normalizeTokenBucketValue( + existingRow.rows[0].last_refill_ms, + 'last_refill_ms', + ), + }; + + const { bucket: nextBucket, result } = computeRateLimitResult( + bucket, + options.maxRequests, + options.windowMs, + options.now, + ); + + await client.query( + `UPDATE ${this.tableName} + SET tokens = $2, + last_refill_ms = $3, + updated_at = NOW() + WHERE bucket_key = $1`, + [bucketKey, nextBucket.tokens, nextBucket.lastRefill], + ); + + await client.query('COMMIT'); + return result; + } catch (error) { + await rollbackQuietly(client); + throw error; + } finally { + client.release(); } + } + + private async ensureTable(): Promise { + if (!this.tableReadyPromise) { + this.tableReadyPromise = this.createTableIfNeeded().catch((error) => { + this.tableReadyPromise = null; + throw error; + }); + } + + await this.tableReadyPromise; + } + + private async createTableIfNeeded(): Promise { + const client = await this.pool.connect(); + + try { + await client.query(` + CREATE TABLE IF NOT EXISTS ${this.tableName} ( + bucket_key TEXT PRIMARY KEY, + tokens INTEGER NOT NULL CHECK (tokens >= 0), + last_refill_ms BIGINT NOT NULL CHECK (last_refill_ms >= 0), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + `); - if (bucket.tokens <= 0) { - const retryAfterMs = this.windowMs - (now - bucket.lastRefill); - return { allowed: false, retryAfterMs: Math.max(retryAfterMs, 0) }; + await client.query(` + CREATE INDEX IF NOT EXISTS ${this.tableName}_updated_at_idx + ON ${this.tableName} (updated_at) + `); + } finally { + client.release(); } + } +} + +export class StoreBackedRateLimiter implements RateLimiter { + protected readonly maxRequests: number; + protected readonly store: RateLimiterStore; + protected readonly windowMs: number; + + constructor( + maxRequests: number, + windowMs: number, + store: RateLimiterStore, + ) { + const baseOptions = buildLimiterOptions(maxRequests, windowMs); + + this.maxRequests = baseOptions.maxRequests; + this.windowMs = baseOptions.windowMs; + this.store = store; + } + + check(apiKey: string): Promise { + return this.store.check(apiKey, { + maxRequests: this.maxRequests, + now: Date.now(), + windowMs: this.windowMs, + }); + } +} + +/** + * Simple token-bucket rate limiter. + * Each API key gets `maxRequests` tokens per `windowMs` window. + */ +export class InMemoryRateLimiter extends StoreBackedRateLimiter { + private readonly inMemoryStore: InMemoryRateLimiterStore; - bucket.tokens -= 1; - return { allowed: true }; + constructor(maxRequests: number, windowMs: number) { + const inMemoryStore = new InMemoryRateLimiterStore(); + super(maxRequests, windowMs, inMemoryStore); + this.inMemoryStore = inMemoryStore; } - /** Helper for tests — exhaust all tokens for a given key. */ + /** Helper for tests - exhaust all tokens for a given key. */ exhaust(apiKey: string): void { - this.buckets.set(apiKey, { tokens: 0, lastRefill: Date.now() }); + this.inMemoryStore.exhaust(apiKey); } - /** Helper for tests — reset all buckets. */ + /** Helper for tests - reset all buckets. */ reset(): void { - this.buckets.clear(); + this.inMemoryStore.reset(); } } export function createRateLimiter( - maxRequests = 100, - windowMs = 60_000, + maxRequests = DEFAULT_MAX_REQUESTS, + windowMs = DEFAULT_WINDOW_MS, ): InMemoryRateLimiter { return new InMemoryRateLimiter(maxRequests, windowMs); } + +export function createConfiguredRateLimiter( + config: RateLimiterConfig, + persistentPool?: PersistentRateLimiterPool, +): RateLimiter { + const maxRequests = config.maxRequests ?? DEFAULT_MAX_REQUESTS; + const windowMs = config.windowMs ?? DEFAULT_WINDOW_MS; + + if (config.store === 'postgres') { + if (!persistentPool) { + throw new Error( + 'A PostgreSQL pool is required when RATE_LIMIT_STORE is set to "postgres".', + ); + } + + return new StoreBackedRateLimiter( + maxRequests, + windowMs, + new PostgresRateLimiterStore(persistentPool, { + tableName: config.tableName, + }), + ); + } + + return createRateLimiter(maxRequests, windowMs); +} + +export function isPersistentRateLimiterStore( + store: RateLimiterStore, +): store is PostgresRateLimiterStore { + return store instanceof PostgresRateLimiterStore; +} + +export type RateLimiterPgClient = PoolClient; diff --git a/src/types/gateway.ts b/src/types/gateway.ts index 0b7ac8f..0df5b5c 100644 --- a/src/types/gateway.ts +++ b/src/types/gateway.ts @@ -70,7 +70,7 @@ export interface BillingService { /** Interface for rate limiting. */ export interface RateLimiter { - check(apiKey: string): RateLimitResult; + check(apiKey: string): Promise; } /** Interface for recording and querying usage events. */