Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
from starlette.routing import Mount, Route

from channel_manager import ChannelManager
from claude_code_client import ClaudeCodeClient
Expand Down Expand Up @@ -80,6 +80,62 @@ async def dispatch(self, request: Request, call_next):
return await call_next(request)


# ---------------------------------------------------------------------------
# REST API 핸들러
# ---------------------------------------------------------------------------


async def api_list_sessions(request: Request) -> JSONResponse:
"""GET /api/sessions — 활성 세션 목록 조회"""
sessions = session_manager.list_sessions()
data = {
"active_count": len(sessions),
"sessions": [s.to_dict() for s in sessions.values()],
}
return JSONResponse(data)


async def api_get_session(request: Request) -> JSONResponse:
"""GET /api/sessions/{user_id} — 특정 유저 세션 상세 조회"""
user_id = request.path_params["user_id"]
message_limit = int(request.query_params.get("message_limit", "20"))
session = session_manager.get_session(user_id)
if session is None:
return JSONResponse(
{"error": f"유저 '{user_id}'의 세션을 찾을 수 없습니다"},
status_code=404,
)
return JSONResponse(session.to_dict(message_limit=message_limit))


async def api_delete_session(request: Request) -> JSONResponse:
"""DELETE /api/sessions/{user_id} — 특정 유저 세션 삭제"""
user_id = request.path_params["user_id"]
session = session_manager.get_session(user_id)
if session is None:
return JSONResponse(
{"error": f"유저 '{user_id}'의 세션을 찾을 수 없습니다"},
status_code=404,
)
session_manager.delete_session(user_id)
return JSONResponse({"message": "세션이 삭제되었습니다", "user_id": user_id})


async def api_cleanup_sessions(request: Request) -> JSONResponse:
"""POST /api/sessions/cleanup — 오래된 세션 일괄 정리"""
hours = 24
try:
body = await request.json()
hours = body.get("hours", 24)
except Exception:
pass
deleted = session_manager.cleanup_old_sessions(hours=hours)
return JSONResponse({
"deleted_count": deleted,
"remaining_active": session_manager.active_count,
})


@asynccontextmanager
async def lifespan(app):
async with session_mgr.run():
Expand All @@ -89,7 +145,13 @@ async def lifespan(app):


starlette_app = Starlette(
routes=[Mount("/mcp", app=session_mgr.handle_request)],
routes=[
Mount("/mcp", app=session_mgr.handle_request),
Route("/api/sessions", api_list_sessions, methods=["GET"]),
Route("/api/sessions/cleanup", api_cleanup_sessions, methods=["POST"]),
Route("/api/sessions/{user_id}", api_get_session, methods=["GET"]),
Route("/api/sessions/{user_id}", api_delete_session, methods=["DELETE"]),
],
lifespan=lifespan,
middleware=[Middleware(APIKeyMiddleware)],
)
Expand Down
147 changes: 147 additions & 0 deletions tests/test_session_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""세션 관리 REST API 엔드포인트 테스트"""

import os

os.environ.setdefault("DISCORD_TOKEN", "test-token")
os.environ.setdefault("DISCORD_GUILD_ID", "123456789")

from datetime import datetime, timedelta
from unittest.mock import patch

from starlette.testclient import TestClient

from server import starlette_app
from session_manager import SessionManager

API_HEADERS = {"X-API-Key": "test-key"}


def _client():
return TestClient(starlette_app, raise_server_exceptions=False)


class TestListSessions:
def test_empty(self):
sm = SessionManager()
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().get("/api/sessions", headers=API_HEADERS)
assert resp.status_code == 200
data = resp.json()
assert data["active_count"] == 0
assert data["sessions"] == []

def test_multiple_sessions(self):
sm = SessionManager()
sm.get_or_create_session("user1").add_message("user", "hello")
sm.get_or_create_session("user2")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().get("/api/sessions", headers=API_HEADERS)
data = resp.json()
assert data["active_count"] == 2
assert len(data["sessions"]) == 2
user_ids = {s["user_id"] for s in data["sessions"]}
assert user_ids == {"user1", "user2"}

def test_session_fields(self):
sm = SessionManager()
sm.get_or_create_session("user1").add_message("user", "hi")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().get("/api/sessions", headers=API_HEADERS)
session = resp.json()["sessions"][0]
assert "user_id" in session
assert "message_count" in session
assert "created_at" in session
assert "last_activity" in session
assert session["message_count"] == 1


class TestGetSession:
def test_success(self):
sm = SessionManager()
sm.get_or_create_session("user1").add_message("user", "안녕")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().get("/api/sessions/user1", headers=API_HEADERS)
assert resp.status_code == 200
data = resp.json()
assert data["user_id"] == "user1"
assert data["message_count"] == 1
assert len(data["recent_messages"]) == 1

def test_not_found(self):
sm = SessionManager()
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().get("/api/sessions/없는유저", headers=API_HEADERS)
assert resp.status_code == 404
assert "찾을 수 없습니다" in resp.json()["error"]

def test_message_limit(self):
sm = SessionManager()
s = sm.get_or_create_session("user1")
for i in range(30):
s.add_message("user", f"msg{i}")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().get(
"/api/sessions/user1?message_limit=5", headers=API_HEADERS
)
data = resp.json()
assert len(data["recent_messages"]) == 5
assert data["recent_messages"][0]["content"] == "msg25"


class TestDeleteSession:
def test_success(self):
sm = SessionManager()
sm.get_or_create_session("user1")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().delete("/api/sessions/user1", headers=API_HEADERS)
assert resp.status_code == 200
assert resp.json()["user_id"] == "user1"
assert sm.active_count == 0

def test_not_found(self):
sm = SessionManager()
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().delete("/api/sessions/없는유저", headers=API_HEADERS)
assert resp.status_code == 404


class TestCleanupSessions:
def test_cleanup_old(self):
sm = SessionManager()
old = sm.get_or_create_session("old_user")
old.last_activity = datetime.now() - timedelta(hours=25)
sm.get_or_create_session("new_user")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().post(
"/api/sessions/cleanup",
headers=API_HEADERS,
json={"hours": 24},
)
assert resp.status_code == 200
data = resp.json()
assert data["deleted_count"] == 1
assert data["remaining_active"] == 1

def test_cleanup_default_hours(self):
sm = SessionManager()
sm.get_or_create_session("user1")
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().post(
"/api/sessions/cleanup", headers=API_HEADERS
)
assert resp.status_code == 200
data = resp.json()
assert data["deleted_count"] == 0
assert data["remaining_active"] == 1

def test_cleanup_custom_hours(self):
sm = SessionManager()
s = sm.get_or_create_session("user1")
s.last_activity = datetime.now() - timedelta(hours=2)
with patch("server.API_KEY", "test-key"), patch("server.session_manager", sm):
resp = _client().post(
"/api/sessions/cleanup",
headers=API_HEADERS,
json={"hours": 1},
)
assert resp.json()["deleted_count"] == 1