Skip to content

Commit 4d549ae

Browse files
committed
feat: CORS support + API rate limiting middleware
CORS: opt-in via CONTENTPIPE_CORS_ORIGINS env var (comma-separated). Rate limit: 60/min per IP on POST /api/* endpoints by default. Configurable via CONTENTPIPE_RATE_LIMIT (e.g. '100/hour', '0' to disable).
1 parent 991e6d8 commit 4d549ae

2 files changed

Lines changed: 93 additions & 0 deletions

File tree

scripts/web/app.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99

1010
from __future__ import annotations
1111

12+
import os
1213
import sys
1314
from pathlib import Path
1415

1516
from fastapi import FastAPI
17+
from fastapi.middleware.cors import CORSMiddleware
1618
from fastapi.staticfiles import StaticFiles
1719

1820
from web.auth import AuthMiddleware
21+
from web.ratelimit import RateLimitMiddleware
1922

2023
# 确保 scripts/ 在 sys.path
2124
SCRIPTS_DIR = Path(__file__).parent.parent
@@ -34,6 +37,18 @@
3437
version="0.8.1",
3538
)
3639
app.add_middleware(AuthMiddleware)
40+
app.add_middleware(RateLimitMiddleware)
41+
42+
# CORS — 允许前后端分离部署
43+
_cors_origins = os.environ.get("CONTENTPIPE_CORS_ORIGINS", "").strip()
44+
if _cors_origins:
45+
app.add_middleware(
46+
CORSMiddleware,
47+
allow_origins=[o.strip() for o in _cors_origins.split(",")],
48+
allow_credentials=True,
49+
allow_methods=["*"],
50+
allow_headers=["*"],
51+
)
3752

3853
# 静态文件
3954
STATIC_DIR = Path(__file__).parent / "static"

scripts/web/ratelimit.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
ContentPipe — 简易 API 速率限制
3+
4+
基于 IP 的滑动窗口计数器,无外部依赖。
5+
通过 CONTENTPIPE_RATE_LIMIT 环境变量控制(默认:60/min)。
6+
设为 0 或空值表示不限制。
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import os
12+
import time
13+
from collections import defaultdict
14+
from typing import Callable
15+
16+
from fastapi import Request
17+
from fastapi.responses import JSONResponse
18+
from starlette.middleware.base import BaseHTTPMiddleware
19+
20+
21+
def _parse_rate_limit() -> tuple[int, int]:
22+
"""解析速率限制配置,格式: '{count}/{period}' 例如 '60/min' '100/hour'"""
23+
raw = os.environ.get("CONTENTPIPE_RATE_LIMIT", "60/min").strip()
24+
if not raw or raw == "0":
25+
return 0, 0
26+
27+
try:
28+
count_str, period_str = raw.split("/", 1)
29+
count = int(count_str)
30+
periods = {"sec": 1, "min": 60, "hour": 3600, "day": 86400}
31+
period = periods.get(period_str, 60)
32+
return count, period
33+
except Exception:
34+
return 60, 60 # 默认 60/min
35+
36+
37+
class RateLimitMiddleware(BaseHTTPMiddleware):
38+
"""基于 IP 的简易速率限制"""
39+
40+
def __init__(self, app, max_requests: int = 0, window_seconds: int = 60):
41+
super().__init__(app)
42+
if max_requests == 0:
43+
limit, window = _parse_rate_limit()
44+
self.max_requests = limit
45+
self.window = window
46+
else:
47+
self.max_requests = max_requests
48+
self.window = window_seconds
49+
self._hits: dict[str, list[float]] = defaultdict(list)
50+
51+
async def dispatch(self, request: Request, call_next):
52+
if self.max_requests <= 0:
53+
return await call_next(request)
54+
55+
# 只限制 API 写入端点
56+
path = request.url.path
57+
if not path.startswith("/api/") or request.method in ("GET", "HEAD", "OPTIONS"):
58+
return await call_next(request)
59+
60+
client_ip = request.client.host if request.client else "unknown"
61+
now = time.time()
62+
cutoff = now - self.window
63+
64+
# 清理过期记录
65+
hits = self._hits[client_ip]
66+
self._hits[client_ip] = [t for t in hits if t > cutoff]
67+
hits = self._hits[client_ip]
68+
69+
if len(hits) >= self.max_requests:
70+
retry_after = int(hits[0] - cutoff) + 1
71+
return JSONResponse(
72+
{"detail": "Rate limit exceeded", "retry_after": retry_after},
73+
status_code=429,
74+
headers={"Retry-After": str(retry_after)},
75+
)
76+
77+
hits.append(now)
78+
return await call_next(request)

0 commit comments

Comments
 (0)