Skip to content

Commit a38c079

Browse files
committed
Address PR #1380 review: model_override allowlist + doc clarifications
- Add a runtime allowlist guard for ``model_override`` in ``doc_extract_query_task``. When the optional Django setting ``BENCHMARK_ALLOWED_MODEL_OVERRIDES`` is unset (default), no enforcement runs — preserves operator-only workflows while giving operators a no-code-change path to lock down this surface if the task is ever exposed to untrusted input. Rejected overrides mark the Datacell as failed with a clear stacktrace and re-raise so celery workers log the violation. - Update the merge-from-main fix in ``test_tool_approval_gate``: the earlier refactor of ``_get_function_tools`` to use the public ``agent.toolsets`` API broke the test mock that exposed the old private ``_function_tools`` attribute. Mock now exposes a real ``FunctionToolset`` via ``inst.toolsets`` so all four approval-flow tests pass on this branch. - Document the failure-mode convention on ``Datacell.stacktrace`` (the field name implies unstructured exception text but we also persist the structured ``failure_mode=`` lines that ``_classify_none_result`` produces — operators ``grep failure_mode=`` to triage). - Document reranker cache invalidation semantics in ``get_default_reranker_instance`` (DB-write busts the cache; in-memory test patches don't — set ``STRICT_RERANKER`` or call ``invalidate_reranker_cache`` explicitly). - Document TLS verification posture on ``MicroserviceReranker`` (relies on system trust store, no per-instance opt-out). - Document one-shot semantics on the 0037 migration so operators don't expect re-running ``migrate`` to re-seed an already-set value. - Add ``ModelOverrideAllowlistTests`` covering the unknown-model rejection path end-to-end.
1 parent 6e4ab5f commit a38c079

6 files changed

Lines changed: 120 additions & 10 deletions

File tree

opencontractserver/documents/migrations/0037_add_default_reranker_to_pipeline_settings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ def seed_default_reranker(apps, schema_editor):
77
88
Intentionally a no-op when ``DEFAULT_RERANKER`` is not defined, so
99
existing deployments keep reranking disabled until an operator opts in.
10+
11+
One-shot semantics: re-running ``migrate`` after a value has already been
12+
persisted will NOT re-seed it (the existing value is preserved by the
13+
``not instance.default_reranker`` guard). Operators changing rerankers
14+
should update via the admin / pipeline settings UI, not by re-running
15+
this migration.
1016
"""
1117
PipelineSettings = apps.get_model("documents", "PipelineSettings")
1218
initial = getattr(django_settings, "DEFAULT_RERANKER", "")

opencontractserver/pipeline/rerankers/microservice_reranker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@
5656

5757

5858
class MicroserviceReranker(BaseReranker):
59-
"""Reranker that delegates to an external HTTP microservice."""
59+
"""Reranker that delegates to an external HTTP microservice.
60+
61+
TLS: HTTPS endpoints are verified via the system trust store
62+
(``requests`` defaults to ``verify=True``). Self-signed/internal CAs
63+
must be trusted at the OS level — there is no per-instance opt-out.
64+
"""
6065

6166
title = "Microservice Reranker"
6267
description = (

opencontractserver/pipeline/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,12 @@ def get_default_reranker_instance(
613613
in one worker must not pin that worker to degraded behaviour while
614614
siblings continue reranking successfully. See the module-level comment
615615
for the rationale.
616+
617+
Cache invalidation: the cache key includes ``PipelineSettings.modified``,
618+
so DB writes bust it process-wide. Tests that patch settings purely
619+
in-memory will hit stale instances — set ``STRICT_RERANKER`` (which
620+
bypasses the cache fast-path) or call :func:`invalidate_reranker_cache`
621+
explicitly if you need a fresh instance from a fixture.
616622
"""
617623
from django.conf import settings as django_settings
618624

opencontractserver/tasks/data_extract_tasks.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,22 @@ async def doc_extract_query_task(
106106
default extraction model. Used by the benchmark runner to sweep
107107
models without touching production defaults.
108108
109-
Trust assumption: this string is passed straight to the agent
110-
factory and ultimately to the model registry. Current call
109+
Trust boundary: this string is passed straight to the agent
110+
factory and ultimately to the model registry. Current call
111111
sites (CLI ``run_benchmark`` command, internal benchmark
112-
runner) are operator-controlled. If this task is ever
113-
exposed to user-controlled input (webhook, public API), gate
114-
it behind an allowlist of approved model identifiers — an
115-
arbitrary string here can redirect extraction traffic to an
112+
runner) are operator-controlled. The optional Django setting
113+
``BENCHMARK_ALLOWED_MODEL_OVERRIDES`` (iterable of allowed
114+
identifiers) gates this parameter at runtime — by default it is
115+
unset, meaning no enforcement (operator-only path). Operators
116+
exposing this task to user-controlled input (webhook, public
117+
API) must set the allowlist to lock down the surface so an
118+
arbitrary string cannot redirect extraction traffic to an
116119
unintended model endpoint.
117120
"""
118121
import traceback
119122
from typing import get_origin
120123

124+
from django.conf import settings
121125
from django.utils import timezone
122126
from pydantic import BaseModel
123127
from pydantic_ai import capture_run_messages
@@ -156,7 +160,15 @@ def sync_mark_completed(dc, data_dict, llm_log=None):
156160

157161
@sync_to_async
158162
def sync_mark_failed(dc, exc, tb, llm_log=None):
159-
"""Mark datacell as failed with error and optional LLM log."""
163+
"""Mark datacell as failed with error and optional LLM log.
164+
165+
Convention: ``Datacell.stacktrace`` is the only persisted text field
166+
for failure context, so we use it for both real exception
167+
tracebacks AND the structured ``failure_mode=`` lines that
168+
``_classify_none_result`` produces for None outcomes. Operators
169+
``grep failure_mode=`` to separate legitimate "data not present"
170+
outcomes from pipeline bugs.
171+
"""
160172
dc.stacktrace = f"Error: {exc}\n\nTraceback:\n{tb}"
161173
dc.failed = timezone.now()
162174
if llm_log:
@@ -188,6 +200,19 @@ def sync_get_corpus_id(document):
188200
await sync_mark_started(datacell)
189201
logger.info(f"Marked datacell {cell_id} as started")
190202

203+
# Optional allowlist guard for ``model_override``. When
204+
# ``BENCHMARK_ALLOWED_MODEL_OVERRIDES`` is unset (default), no
205+
# enforcement runs — preserves operator-only workflows while
206+
# giving operators a no-code-change path to lock down this
207+
# surface if the task is ever exposed to untrusted input.
208+
if model_override is not None:
209+
allowed = getattr(settings, "BENCHMARK_ALLOWED_MODEL_OVERRIDES", None)
210+
if allowed is not None and model_override not in allowed:
211+
raise ValueError(
212+
f"model_override {model_override!r} is not in "
213+
f"BENCHMARK_ALLOWED_MODEL_OVERRIDES"
214+
)
215+
191216
document = datacell.document
192217
column = datacell.column
193218
logger.info(f"Document: {document.id}, Column: {column.name}")

opencontractserver/tests/test_data_extract_helpers.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,5 +323,65 @@ def test_score_count_mismatch_pads_with_neg_inf(self) -> None:
323323
self.assertEqual(results[1].score, float("-inf"))
324324

325325

326+
class ModelOverrideAllowlistTests(TestCase):
327+
"""``BENCHMARK_ALLOWED_MODEL_OVERRIDES`` guard fires before any Datacell
328+
work runs, so a rejected override marks the cell as failed without
329+
touching the agent runtime.
330+
"""
331+
332+
def setUp(self) -> None:
333+
self.user = User.objects.create_user(
334+
username="allowlist_user", password="testpass"
335+
)
336+
corpus = Corpus.objects.create(title="Allowlist Corpus", creator=self.user)
337+
document = Document.objects.create(
338+
title="Allowlist Doc", creator=self.user, file_type="text/plain"
339+
)
340+
corpus.add_document(document=document, user=self.user)
341+
fieldset = Fieldset.objects.create(name="fs", creator=self.user)
342+
column = Column.objects.create(
343+
fieldset=fieldset,
344+
name="col",
345+
query="anything",
346+
output_type="str",
347+
creator=self.user,
348+
)
349+
extract = Extract.objects.create(
350+
corpus=corpus, fieldset=fieldset, name="ex", creator=self.user
351+
)
352+
self.cell = Datacell.objects.create(
353+
extract=extract,
354+
column=column,
355+
document=document,
356+
data_definition="x",
357+
creator=self.user,
358+
)
359+
360+
def test_unknown_model_override_marks_cell_failed(self) -> None:
361+
from django.test import override_settings
362+
363+
from opencontractserver.tasks.data_extract_tasks import (
364+
doc_extract_query_task,
365+
)
366+
367+
with override_settings(
368+
BENCHMARK_ALLOWED_MODEL_OVERRIDES=["openai:gpt-4o-mini"]
369+
):
370+
# The task re-raises after marking the cell failed; the
371+
# operator-facing celery worker logs the error and the cell
372+
# carries the explanation in its stacktrace for ops review.
373+
with self.assertRaises(ValueError):
374+
doc_extract_query_task.si(
375+
self.cell.id, model_override="anthropic:not-allowed"
376+
).apply().get()
377+
378+
self.cell.refresh_from_db()
379+
self.assertIsNotNone(self.cell.failed)
380+
self.assertIn(
381+
"BENCHMARK_ALLOWED_MODEL_OVERRIDES",
382+
self.cell.stacktrace or "",
383+
)
384+
385+
326386
# Suppress unused-import warning for the SimpleNamespace shim used elsewhere
327387
_ = SimpleNamespace

opencontractserver/tests/test_tool_approval_gate.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,22 @@ async def _run_side_effect(*_a, **_kw):
199199
inst.run = AsyncMock(side_effect=_run_side_effect)
200200
inst.iter = MagicMock(return_value=_IterCtx())
201201

202-
# Provide registry entry so resume_with_approval can execute tool
202+
# Provide registry entry so resume_with_approval can execute tool.
203+
# ``_get_function_tools`` now reads from the public ``agent.toolsets``
204+
# API and only consumes ``FunctionToolset`` instances, so the mock
205+
# exposes a real FunctionToolset whose ``tools`` dict carries the
206+
# stub functions keyed by name.
207+
from pydantic_ai.toolsets import FunctionToolset
208+
203209
async def _approved_tool(ctx, x: int): # noqa: D401 – minimal stub
204210
return x * 2
205211

206-
inst._function_tools = {
212+
toolset = FunctionToolset()
213+
toolset.tools = {
207214
"approved_tool": types.SimpleNamespace(function=_approved_tool),
208215
"second_gate_tool": types.SimpleNamespace(function=_approved_tool),
209216
}
217+
inst.toolsets = [toolset]
210218
mock_cls.return_value = inst
211219

212220
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)