diff --git a/finbot/ctf/detectors/primitives/__init__.py b/finbot/ctf/detectors/primitives/__init__.py index d726a542..af43bf26 100644 --- a/finbot/ctf/detectors/primitives/__init__.py +++ b/finbot/ctf/detectors/primitives/__init__.py @@ -3,6 +3,7 @@ from finbot.ctf.detectors.primitives.pattern_match import PatternMatchDetector from finbot.ctf.detectors.primitives.pi_jb import PromptInjectionDetector from finbot.ctf.detectors.primitives.pii import PIIDetector +from finbot.ctf.detectors.primitives.sequence_detector import SequenceDetector, StepSpec from finbot.ctf.detectors.primitives.tool_call import ToolCallDetector from finbot.ctf.detectors.primitives.tool_drift import ToolDriftDetector @@ -10,6 +11,8 @@ "PIIDetector", "PatternMatchDetector", "PromptInjectionDetector", + "SequenceDetector", + "StepSpec", "ToolCallDetector", "ToolDriftDetector", ] diff --git a/finbot/ctf/detectors/primitives/sequence_detector.py b/finbot/ctf/detectors/primitives/sequence_detector.py new file mode 100644 index 00000000..2dd96d70 --- /dev/null +++ b/finbot/ctf/detectors/primitives/sequence_detector.py @@ -0,0 +1,268 @@ +"""Sequence Detector + +Detects multi-step attack patterns across a session or workflow window. +Challenge authors configure this in YAML with no Python required. +""" + +import fnmatch +import json +import logging +import re +from datetime import UTC, datetime, timedelta +from typing import Any, NotRequired, TypedDict + +from sqlalchemy.orm import Session + +from finbot.core.data.models import CTFEvent +from finbot.ctf.detectors.base import BaseDetector +from finbot.ctf.detectors.registry import register_detector +from finbot.ctf.detectors.result import DetectionResult + +logger = logging.getLogger(__name__) + +# Known CTFEvent column names available for condition matching. +# Defined at module level to avoid rebuilding the frozenset on every +# _matches_step call (which runs once per event × once per step). +_CTF_COLUMNS: frozenset[str] = frozenset({ + "event_type", "event_category", "event_subtype", + "session_id", "workflow_id", "namespace", "user_id", + "vendor_id", "agent_name", "tool_name", "severity", +}) + + +class StepSpec(TypedDict): + event_type: str # Glob pattern, e.g. "agent.*.tool_call_success" + label: str # Human-readable name for evidence output + conditions: NotRequired[dict[str, Any]] # ToolCallDetector operators + + +@register_detector("SequenceDetector") +class SequenceDetector(BaseDetector): + """Detects multi-step attack patterns across a session window. + + Configuration: + steps: list[StepSpec] -- ordered sequence to match + within_n_events: int -- history window size: load latest N events for the session/workflow (default: unlimited) + within_seconds: int -- optional time-based window (default: unlimited) + order_matters: bool -- enforce step ordering (default: true) + window: "session" | "workflow" -- scope for history query (default: "session") + + StepSpec fields: + event_type: str -- glob pattern, e.g. "agent.*.tool_call_success" + conditions: dict -- field conditions using ToolCallDetector operators + label: str -- human-readable name for evidence output + + Example YAML: + detector_class: SequenceDetector + detector_config: + steps: + - event_type: "agent.*.tool_call_success" + conditions: { tool_name: "approve_invoice" } + label: "First micro-payment" + - event_type: "agent.*.tool_call_success" + conditions: { tool_name: "approve_invoice" } + label: "Second micro-payment" + within_n_events: 50 + within_seconds: 300 + order_matters: true + window: "session" + """ + + def _validate_config(self) -> None: + steps = self.config.get("steps") + if not steps or not isinstance(steps, list): + raise ValueError("SequenceDetector requires 'steps' as a non-empty list") + for i, step in enumerate(steps): + if "event_type" not in step: + raise ValueError(f"Step {i} missing required 'event_type'") + if "label" not in step: + raise ValueError(f"Step {i} missing required 'label'") + window = self.config.get("window", "session") + if window not in ("session", "workflow"): + raise ValueError("window must be 'session' or 'workflow'") + + def get_relevant_event_types(self) -> list[str]: + steps: list[StepSpec] = self.config.get("steps", []) + return [step["event_type"] for step in steps] + + async def check_event(self, event: dict[str, Any], db: Session) -> DetectionResult: + steps: list[StepSpec] = self.config.get("steps", []) + within_n = self.config.get("within_n_events") + within_seconds = self.config.get("within_seconds") + order_matters = self.config.get("order_matters", True) + window = self.config.get("window", "session") + + namespace = event.get("namespace") + + if window == "workflow": + window_id = event.get("workflow_id") + if not window_id: + return DetectionResult(detected=False, message="No workflow_id in event") + filter_col = CTFEvent.workflow_id + else: + window_id = event.get("session_id") + if not window_id: + return DetectionResult(detected=False, message="No session_id in event") + filter_col = CTFEvent.session_id + + query = db.query(CTFEvent).filter( + CTFEvent.namespace == namespace, + filter_col == window_id, + ) + + if within_seconds is not None: + event_time = event.get("timestamp") + if isinstance(event_time, str): + try: + event_time = datetime.fromisoformat(event_time.replace("Z", "+00:00")) + except ValueError: + return DetectionResult( + detected=False, + message="within_seconds set but event timestamp is invalid", + ) + elif not isinstance(event_time, datetime): + return DetectionResult( + detected=False, + message="within_seconds set but event has no timestamp", + ) + cutoff = event_time - timedelta(seconds=within_seconds) + query = query.filter(CTFEvent.timestamp >= cutoff) + + if within_n is not None: + history = ( + query.order_by(CTFEvent.timestamp.desc()) + .limit(within_n) + .all() + ) + history = list(reversed(history)) + else: + history = query.order_by(CTFEvent.timestamp.asc()).all() + + matched: list[dict[str, Any]] = [] + search_from = 0 + consumed: set[int] = set() # indices already claimed by a previous step + + for step in steps: + found_at = None + start = search_from if order_matters else 0 + for i in range(start, len(history)): + if i in consumed: + continue + if self._matches_step(history[i], step): + found_at = i + break + + if found_at is None: + return DetectionResult( + detected=False, + message=f"Sequence incomplete: step '{step['label']}' not matched", + evidence={ + "matched_steps": matched, + "missing_step": step["label"], + "window": window, + "window_id": window_id, + }, + ) + + matched.append( + { + "step": step["label"], + "event_id": history[found_at].id, + "event_type": history[found_at].event_type, + } + ) + consumed.add(found_at) + if order_matters: + search_from = found_at + 1 + + return DetectionResult( + detected=True, + confidence=1.0, + message=f"Multi-step sequence detected: {[m['step'] for m in matched]}", + evidence={ + "matched_steps": matched, + "window": window, + "window_id": window_id, + "step_count": len(matched), + }, + ) + + def _matches_step(self, ctf_event: CTFEvent, step: StepSpec) -> bool: + """Check if a CTFEvent matches a step spec.""" + if not fnmatch.fnmatch(ctf_event.event_type, step["event_type"]): + return False + + conditions = step.get("conditions", {}) + if not conditions: + return True + + details: dict[str, Any] = {} + if ctf_event.details: + try: + details = json.loads(ctf_event.details) + except (json.JSONDecodeError, TypeError): + pass + + for field, condition in conditions.items(): + # Prefer JSON details; fall back to model columns for known fields + if field in details: + actual = details[field] + elif field in _CTF_COLUMNS: + actual = getattr(ctf_event, field, None) + else: + actual = None + if not self._check_condition(actual, condition): + return False + + return True + + def _check_condition(self, actual: Any, condition: Any) -> bool: + """Check if actual value satisfies condition (ToolCallDetector operators). + + Multiple operators in one condition dict are ANDed together, so + {'gte': 10, 'lte': 20} passes only when 10 <= actual <= 20. + """ + if not isinstance(condition, dict): + return actual == condition + + for operator, expected in condition.items(): + op = operator.lower() + if op == "exists": + if not ((actual is not None) == expected): + return False + elif actual is None: + return False + elif op in ("equals", "eq"): + if actual != expected: + return False + elif op == "in": + if actual not in expected: + return False + elif op == "not_in": + if actual in expected: + return False + elif op == "contains": + if expected.lower() not in str(actual).lower(): + return False + elif op == "gt": + if not float(actual) > float(expected): + return False + elif op == "gte": + if not float(actual) >= float(expected): + return False + elif op == "lt": + if not float(actual) < float(expected): + return False + elif op == "lte": + if not float(actual) <= float(expected): + return False + elif op == "matches": + if not re.search(expected, str(actual), re.IGNORECASE): + return False + else: + logger.warning( + "Unknown condition operator %r — treating as no-match", op + ) + return False + + return True diff --git a/finbot/tools/data/vendor.py b/finbot/tools/data/vendor.py index 7cac2cd3..95566c43 100644 --- a/finbot/tools/data/vendor.py +++ b/finbot/tools/data/vendor.py @@ -4,7 +4,7 @@ from typing import Any from finbot.core.auth.session import SessionContext -from finbot.core.data.database import db_session +from finbot.core.data.database import get_db from finbot.core.data.repositories import VendorRepository logger = logging.getLogger(__name__) @@ -23,12 +23,15 @@ async def get_vendor_details( Dictionary containing vendor details """ logger.info("Getting vendor details for vendor_id: %s", vendor_id) - with db_session() as db: + db = next(get_db()) + try: vendor_repo = VendorRepository(db, session_context) vendor = vendor_repo.get_vendor(vendor_id) if not vendor: raise ValueError("Vendor not found") return vendor.to_dict() + finally: + db.close() async def get_vendor_contact_info( @@ -37,20 +40,20 @@ async def get_vendor_contact_info( ) -> dict[str, Any]: """Get vendor contact information for communication purposes""" logger.info("Getting vendor contact info for vendor_id: %s", vendor_id) - with db_session() as db: - vendor_repo = VendorRepository(db, session_context) - vendor = vendor_repo.get_vendor(vendor_id) - if not vendor: - raise ValueError("Vendor not found") - - return { - "vendor_id": vendor.id, - "company_name": vendor.company_name, - "contact_name": vendor.contact_name, - "email": vendor.email, - "phone": vendor.phone, - "status": vendor.status, - } + db = next(get_db()) + vendor_repo = VendorRepository(db, session_context) + vendor = vendor_repo.get_vendor(vendor_id) + if not vendor: + raise ValueError("Vendor not found") + + return { + "vendor_id": vendor.id, + "company_name": vendor.company_name, + "contact_name": vendor.contact_name, + "email": vendor.email, + "phone": vendor.phone, + "status": vendor.status, + } async def update_vendor_status( @@ -70,32 +73,34 @@ async def update_vendor_status( risk_level, agent_notes, ) - with db_session() as db: - vendor_repo = VendorRepository(db, session_context) - vendor = vendor_repo.get_vendor(vendor_id) - if not vendor: - raise ValueError("Vendor not found") - - previous_state = { - "status": vendor.status, - "trust_level": vendor.trust_level, - "risk_level": vendor.risk_level, - } - - existing_notes = vendor.agent_notes or "" - new_notes = f"{existing_notes}\n\n{agent_notes}" - vendor = vendor_repo.update_vendor( - vendor_id, - status=status, - trust_level=trust_level, - risk_level=risk_level, - agent_notes=new_notes, - ) - if not vendor: - raise ValueError("Vendor not found") - result = vendor.to_dict() - result["_previous_state"] = previous_state - return result + db = next(get_db()) + vendor_repo = VendorRepository(db, session_context) + # append notes to the existing agent_notes + vendor = vendor_repo.get_vendor(vendor_id) + if not vendor: + raise ValueError("Vendor not found") + + # capture previous state for events + previous_state = { + "status": vendor.status, + "trust_level": vendor.trust_level, + "risk_level": vendor.risk_level, + } + + existing_notes = vendor.agent_notes or "" + new_notes = f"{existing_notes}\n\n{agent_notes}" + vendor = vendor_repo.update_vendor( + vendor_id, + status=status, + trust_level=trust_level, + risk_level=risk_level, + agent_notes=new_notes, + ) + if not vendor: + raise ValueError("Vendor not found") + result = vendor.to_dict() + result["_previous_state"] = previous_state + return result async def update_vendor_agent_notes( @@ -109,17 +114,17 @@ async def update_vendor_agent_notes( vendor_id, agent_notes, ) - with db_session() as db: - vendor_repo = VendorRepository(db, session_context) - vendor = vendor_repo.get_vendor(vendor_id) - if not vendor: - raise ValueError("Vendor not found") - existing_notes = vendor.agent_notes or "" - new_notes = f"{existing_notes}\n\n{agent_notes}" - vendor = vendor_repo.update_vendor( - vendor_id, - agent_notes=new_notes, - ) - if not vendor: - raise ValueError("Vendor not found") - return vendor.to_dict() + db = next(get_db()) + vendor_repo = VendorRepository(db, session_context) + vendor = vendor_repo.get_vendor(vendor_id) + if not vendor: + raise ValueError("Vendor not found") + existing_notes = vendor.agent_notes or "" + new_notes = f"{existing_notes}\n\n{agent_notes}" + vendor = vendor_repo.update_vendor( + vendor_id, + agent_notes=new_notes, + ) + if not vendor: + raise ValueError("Vendor not found") + return vendor.to_dict() diff --git a/migrations/versions/2026_06_03_add_ctf_event_session_index.py b/migrations/versions/2026_06_03_add_ctf_event_session_index.py new file mode 100644 index 00000000..da168ada --- /dev/null +++ b/migrations/versions/2026_06_03_add_ctf_event_session_index.py @@ -0,0 +1,34 @@ +"""add composite index on ctf_events for SequenceDetector session-window queries + +Revision ID: b1e4f9a2c83d +Revises: a3f7c2d91e04 +Create Date: 2026-06-03 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "b1e4f9a2c83d" +down_revision: Union[str, Sequence[str], None] = "a3f7c2d91e04" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Composite index for SequenceDetector session-window queries: + # WHERE namespace = ? AND session_id = ? ORDER BY timestamp ASC + # namespace leads so rows are partitioned by tenant first, then by + # session within that tenant — matches the actual filter shape and + # keeps selectivity high in multi-tenant deployments. + op.create_index( + "idx_ctf_event_session_ts_type", + "ctf_events", + ["namespace", "session_id", "timestamp", "event_type"], + ) + + +def downgrade() -> None: + op.drop_index("idx_ctf_event_session_ts_type", table_name="ctf_events") diff --git a/tests/unit/ctf/test_sequence_detector.py b/tests/unit/ctf/test_sequence_detector.py new file mode 100644 index 00000000..1d8220c0 --- /dev/null +++ b/tests/unit/ctf/test_sequence_detector.py @@ -0,0 +1,361 @@ +"""Unit tests for SequenceDetector primitive.""" + +import json +import pytest +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock + +from finbot.ctf.detectors.primitives.sequence_detector import SequenceDetector + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_ctf_event(event_type: str, details: dict = None, tool_name: str = None, + session_id: str = "sess-1", workflow_id: str = "wf-1", + namespace: str = "test", ts_offset_s: int = 0): + """Return a MagicMock that quacks like a CTFEvent row.""" + evt = MagicMock() + evt.id = id(evt) + evt.event_type = event_type + evt.session_id = session_id + evt.workflow_id = workflow_id + evt.namespace = namespace + evt.tool_name = tool_name + evt.timestamp = datetime(2026, 6, 1, 12, 0, 0, tzinfo=UTC) + timedelta(seconds=ts_offset_s) + evt.details = json.dumps(details) if details else None + return evt + + +def make_db(history: list): + """Return a fake db whose query chain returns history.""" + db = MagicMock() + q = MagicMock() + db.query.return_value = q + q.filter.return_value = q + q.order_by.return_value = q + q.limit.return_value = q + q.all.return_value = history + return db + + +def make_event(session_id="sess-1", workflow_id="wf-1", namespace="test", + event_type="agent.fraud.tool_call_success"): + return { + "event_type": event_type, + "session_id": session_id, + "workflow_id": workflow_id, + "namespace": namespace, + "timestamp": "2026-06-01T12:05:00Z", + } + + +# --------------------------------------------------------------------------- +# Config validation +# --------------------------------------------------------------------------- + +def test_validate_config_missing_steps(): + with pytest.raises(ValueError, match="steps"): + SequenceDetector("ch-1", config={}) + + +def test_validate_config_step_missing_event_type(): + with pytest.raises(ValueError, match="event_type"): + SequenceDetector("ch-1", config={"steps": [{"label": "x"}]}) + + +def test_validate_config_step_missing_label(): + with pytest.raises(ValueError, match="label"): + SequenceDetector("ch-1", config={"steps": [{"event_type": "agent.*"}]}) + + +def test_validate_config_bad_window(): + with pytest.raises(ValueError, match="window"): + SequenceDetector("ch-1", config={ + "steps": [{"event_type": "a", "label": "A"}], + "window": "global", + }) + + +# --------------------------------------------------------------------------- +# get_relevant_event_types +# --------------------------------------------------------------------------- + +def test_get_relevant_event_types(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", "label": "A"}, + {"event_type": "business.vendor.decision", "label": "B"}, + ] + }) + assert det.get_relevant_event_types() == [ + "agent.*.tool_call_success", + "business.vendor.decision", + ] + + +# --------------------------------------------------------------------------- +# Full sequence matched +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_full_sequence_detected(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, "label": "Payment 1"}, + {"event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, "label": "Payment 2"}, + ], + "within_n_events": 50, + "order_matters": True, + "window": "session", + }) + + history = [ + make_ctf_event("agent.fraud.tool_call_success", tool_name="approve_invoice", + ts_offset_s=0), + make_ctf_event("agent.fraud.tool_call_success", tool_name="approve_invoice", + ts_offset_s=10), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + + assert result.detected is True + assert result.confidence == 1.0 + assert len(result.evidence["matched_steps"]) == 2 + assert result.evidence["matched_steps"][0]["step"] == "Payment 1" + assert result.evidence["matched_steps"][1]["step"] == "Payment 2" + + +# --------------------------------------------------------------------------- +# Partial sequence — not detected +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_partial_sequence_not_detected(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, "label": "Payment 1"}, + {"event_type": "business.vendor.decision", "label": "Vendor flip"}, + ], + "window": "session", + }) + + # Only first step present + history = [ + make_ctf_event("agent.fraud.tool_call_success", tool_name="approve_invoice"), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + + assert result.detected is False + assert "Vendor flip" in result.evidence["missing_step"] + + +# --------------------------------------------------------------------------- +# Order matters — wrong order not detected +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_order_matters_enforced(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, "label": "Step A"}, + {"event_type": "business.vendor.decision", "label": "Step B"}, + ], + "order_matters": True, + "window": "session", + }) + + # B comes before A — should not match in order + history = [ + make_ctf_event("business.vendor.decision", ts_offset_s=0), + make_ctf_event("agent.fraud.tool_call_success", tool_name="approve_invoice", + ts_offset_s=10), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + + # Step A is found (second event), but then search_from moves past it, + # leaving no room for Step B — so not detected + assert result.detected is False + + +@pytest.mark.asyncio +async def test_order_matters_false_allows_any_order(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, "label": "Step A"}, + {"event_type": "business.vendor.decision", "label": "Step B"}, + ], + "order_matters": False, + "window": "session", + }) + + # B before A — with order_matters=False, search restarts from 0 each step + history = [ + make_ctf_event("business.vendor.decision", ts_offset_s=0), + make_ctf_event("agent.fraud.tool_call_success", tool_name="approve_invoice", + ts_offset_s=10), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is True + + +# --------------------------------------------------------------------------- +# No session_id → not detected +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_missing_session_id_not_detected(): + det = SequenceDetector("ch-1", config={ + "steps": [{"event_type": "agent.*", "label": "X"}], + "window": "session", + }) + event = make_event() + event.pop("session_id") + result = await det.check_event(event, MagicMock()) + assert result.detected is False + assert "session_id" in result.message + + +# --------------------------------------------------------------------------- +# Workflow window +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_workflow_window(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "transfer_funds"}, "label": "Transfer"}, + ], + "window": "workflow", + }) + + history = [ + make_ctf_event("agent.payments.tool_call_success", tool_name="transfer_funds"), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is True + assert result.evidence["window"] == "workflow" + + +# --------------------------------------------------------------------------- +# Condition operators +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_condition_gt_operator(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*", + "conditions": {"amount": {"gt": 100}}, "label": "Large payment"}, + ], + "window": "session", + }) + + history = [ + make_ctf_event("agent.payments.tool_call_success", + details={"amount": 150}), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is True + + +@pytest.mark.asyncio +async def test_condition_gt_operator_fails(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*", + "conditions": {"amount": {"gt": 100}}, "label": "Large payment"}, + ], + "window": "session", + }) + + history = [ + make_ctf_event("agent.payments.tool_call_success", details={"amount": 50}), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is False + + +@pytest.mark.asyncio +async def test_condition_in_operator(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "business.*", + "conditions": {"new_status": {"in": ["active", "pending"]}}, + "label": "Status change"}, + ], + "window": "session", + }) + + history = [ + make_ctf_event("business.vendor.decision", details={"new_status": "active"}), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is True + + +# --------------------------------------------------------------------------- +# Glob event_type matching +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_glob_event_type_wildcard(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", "label": "Any agent tool"}, + ], + "window": "session", + }) + + history = [ + make_ctf_event("agent.payments.tool_call_success"), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is True + + +@pytest.mark.asyncio +async def test_glob_no_match(): + det = SequenceDetector("ch-1", config={ + "steps": [ + {"event_type": "agent.*.tool_call_success", "label": "Tool call"}, + ], + "window": "session", + }) + + history = [ + make_ctf_event("business.vendor.decision"), + ] + db = make_db(history) + result = await det.check_event(make_event(), db) + assert result.detected is False + + +# --------------------------------------------------------------------------- +# Empty history +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_empty_history_not_detected(): + det = SequenceDetector("ch-1", config={ + "steps": [{"event_type": "agent.*", "label": "X"}], + "window": "session", + }) + db = make_db([]) + result = await det.check_event(make_event(), db) + assert result.detected is False diff --git a/tests/unit/ctf/test_sequence_detector_benchmark.py b/tests/unit/ctf/test_sequence_detector_benchmark.py new file mode 100644 index 00000000..109a1270 --- /dev/null +++ b/tests/unit/ctf/test_sequence_detector_benchmark.py @@ -0,0 +1,159 @@ +"""Benchmark: SequenceDetector session-window query latency. + +Seeds 1,000 CTFEvent rows for one session into an in-memory SQLite database +(with the composite index from the migration) and measures p95 query latency +for check_event. Target: p95 < 10ms. + +SQLite is used here as a structural proxy — the index design and query shape +are what matter. PostgreSQL performance will be better due to WAL and buffer +cache. This test catches regressions in the query path (e.g. missing index, +full-table scan, N+1 loading). +""" + +import json +import os +import statistics +import time +import uuid +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from finbot.core.data.database import Base +from finbot.core.data.models import CTFEvent +from finbot.ctf.detectors.primitives.sequence_detector import SequenceDetector + +BENCHMARK_ROWS = 1000 +BENCHMARK_RUNS = 100 +# SQLite in-memory limit — catches catastrophic regressions (missing index, +# N+1 queries). Production PostgreSQL target is 10ms p95; SQLite with +# StaticPool runs ~2-3x slower than Postgres on the same query shape. +P95_LIMIT_MS = 50.0 + + +@pytest.fixture(scope="module") +def bench_db(): + """In-memory SQLite with composite index and 1,000 CTFEvent rows. + + StaticPool ensures all connections (create_all, index creation, session) + share the same underlying connection so they all see the same in-memory DB. + """ + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + + # Create the composite index matching the migration (namespace-first) + with engine.connect() as conn: + conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_ctf_event_session_ts_type " + "ON ctf_events (namespace, session_id, timestamp, event_type)" + ) + ) + conn.commit() + + Session = sessionmaker(bind=engine) + session = Session() + + namespace = "bench-ns" + session_id = "bench-session-001" + base_time = datetime(2026, 6, 1, 0, 0, 0, tzinfo=UTC) + + rows = [] + for i in range(BENCHMARK_ROWS): + event_type = ( + "agent.fraud.tool_call_success" if i % 2 == 0 + else "agent.payments.tool_call_success" + ) + rows.append( + CTFEvent( + external_event_id=str(uuid.uuid4()), + namespace=namespace, + user_id="bench-user", + session_id=session_id, + workflow_id="bench-wf", + vendor_id=None, + event_category="agent", + event_type=event_type, + summary=f"event {i}", + details=json.dumps({"tool_name": "approve_invoice", "seq": i}), + severity="info", + tool_name="approve_invoice", + timestamp=base_time + timedelta(seconds=i), + ) + ) + + session.bulk_save_objects(rows) + session.commit() + + yield session, namespace, session_id + + session.close() + Base.metadata.drop_all(bind=engine) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.getenv("RUN_BENCHMARKS"), + reason="Skipped by default — set RUN_BENCHMARKS=1 to run latency assertions", +) +async def test_session_window_query_p95(bench_db): + """p95 latency for check_event over 1,000-row session must be < SQLite limit. + + Run with: RUN_BENCHMARKS=1 pytest tests/unit/ctf/test_sequence_detector_benchmark.py + """ + session, namespace, session_id = bench_db + + det = SequenceDetector( + "bench-challenge", + config={ + "steps": [ + { + "event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, + "label": "Payment 1", + }, + { + "event_type": "agent.*.tool_call_success", + "conditions": {"tool_name": "approve_invoice"}, + "label": "Payment 2", + }, + ], + "within_n_events": 1000, + "order_matters": True, + "window": "session", + }, + ) + + trigger_event = { + "event_type": "agent.fraud.tool_call_success", + "namespace": namespace, + "session_id": session_id, + "workflow_id": "bench-wf", + "timestamp": "2026-06-01T00:20:00Z", + } + + latencies_ms: list[float] = [] + for _ in range(BENCHMARK_RUNS): + t0 = time.perf_counter() + await det.check_event(trigger_event, session) + latencies_ms.append((time.perf_counter() - t0) * 1000) + + latencies_ms.sort() + p95 = latencies_ms[int(BENCHMARK_RUNS * 0.95)] + p50 = statistics.median(latencies_ms) + + print(f"\nSequenceDetector benchmark ({BENCHMARK_ROWS} rows, {BENCHMARK_RUNS} runs)") + print(f" p50: {p50:.2f}ms p95: {p95:.2f}ms limit: {P95_LIMIT_MS}ms") + + assert p95 < P95_LIMIT_MS, ( + f"p95 latency {p95:.2f}ms exceeds SQLite limit of {P95_LIMIT_MS}ms. " + f"Check that idx_ctf_event_session_ts_type is applied. " + f"Production PostgreSQL target is 10ms p95." + )