Skip to content

Commit c50ba12

Browse files
committed
Fix test expectations for intentional behavior changes; address review feedback
Test fixes (8 tests broken by PR's production changes):
- test_openai_embedder: assert max_retries=OPENAI_CLIENT_MAX_RETRIES kwarg on OpenAI() construction; rename rate-limit test to assert re-raise (issue #1380 routes RateLimitError to celery for retry instead of returning None).
- test_embeddings_task.TestArrayFormatHandling: patch _get_session for the sent_transformer microservice (PR routes through shared session instead of bare requests.post).
- test_corpus_forking: compare annotation JSON via compact_annotation_json on both sides; forked annotations are saved through Annotation.save() which now lazily compacts v1 PAWLs to v2.
- test_structural_annotation_sets: rename test to reflect new orphan-GC behavior — SAS is preserved only while another document references it (orphan path covered separately in test_orphan_structural_set_gc.py).

Review feedback (PR #1380):
- openai_embedder: extract magic 30000 to OPENAI_EMBEDDER_MAX_INPUT_CHARS in constants/document_processing.py; reuse at both call sites.
- pydantic_ai_agents: extract duplicated similarity_search closure into _make_similarity_search_tool factory; both document and corpus agent factories now share one citation-capturing implementation.
- data_extract_tasks: hoist _link_retrieval_citations to module level so the @sync_to_async wrapper isn't rebuilt on every Celery invocation; drop unused sync_add_sources helper. Tighten _classify_none_result parameter type from object to Sequence[Any] | None.
1 parent 94c5480 commit c50ba12

8 files changed

Lines changed: 171 additions & 144 deletions

File tree

opencontractserver/constants/document_processing.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,16 @@
5353
# Larger than the single timeout because batches process multiple texts.
5454
EMBEDDER_BATCH_REQUEST_TIMEOUT_SECONDS = 60
5555

56+
# Character-count guard for OpenAI embedding input. The hosted /embeddings
57+
# endpoint caps input at 8192 tokens per text; truncating on the char side
58+
# at ~4x the token budget (English averages ~4 chars/token) keeps us well
59+
# under the cap for any realistic input. Mirrors the silent-tokenizer
60+
# truncation that ``sentence-transformers`` applies locally so OpenAI users
61+
# get the same robustness instead of a fatal 400 "maximum context length"
62+
# from a long whole-document chunk. See ``OpenAIEmbedder._embed_text_impl``
63+
# and ``OpenAIEmbedder.embed_texts_batch``.
64+
OPENAI_EMBEDDER_MAX_INPUT_CHARS = 30_000
65+
5666
# HTTP request timeout (seconds) for reranker microservice calls.
5767
# Reranking typically runs over tens of candidates (top_k * oversample), so
5868
# a modest timeout is sufficient. Retrieval degrades gracefully to the

opencontractserver/llms/agents/pydantic_ai_agents.py

Lines changed: 49 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,48 @@
105105
T = TypeVar("T")
106106

107107

108+
def _make_similarity_search_tool(vector_store: Any) -> Callable:
109+
"""Build the citation-capturing similarity_search tool for a vector store.
110+
111+
Both the document and corpus agent factories used to define this closure
112+
inline; the only thing that differed was which vector store was bound.
113+
Centralising it here keeps the citation-accumulation contract — push
114+
every real annotation PK into ``ctx.deps.retrieved_annotation_ids`` —
115+
in a single place. The tool name remains ``similarity_search`` so
116+
downstream event handlers and source-linking logic are unaffected.
117+
"""
118+
119+
async def similarity_search(
120+
ctx: RunContext[PydanticAIDependencies],
121+
query: str,
122+
k: int = 8,
123+
modalities: Optional[list[str]] = None,
124+
) -> list[dict[str, Any]]:
125+
"""Semantic vector search over the corpus annotations.
126+
127+
Returns the top-k nearest annotations for ``query`` as a list of
128+
dicts with keys ``annotation_id``, ``content``, ``document_id``,
129+
``corpus_id``, ``page``, ``similarity_score``, ``label``, and
130+
``json``. Each real annotation's ID is captured into
131+
``ctx.deps.retrieved_annotation_ids`` so the caller can later link
132+
citations to the owning object (e.g. ``Datacell.sources``).
133+
"""
134+
results = await vector_store.similarity_search(
135+
query, k=k, modalities=modalities
136+
)
137+
for r in results:
138+
if not isinstance(r, dict):
139+
continue
140+
aid = r.get("annotation_id")
141+
# Real annotation PKs are positive ints; synthetic / ad-hoc
142+
# match IDs are negative and must not be persisted.
143+
if isinstance(aid, int) and aid > 0:
144+
ctx.deps.retrieved_annotation_ids.append(aid)
145+
return results
146+
147+
return similarity_search
148+
149+
108150
def _get_function_tools(agent: PydanticAIAgent) -> dict:
109151
"""Return the function-tools dict from a pydantic-ai Agent.
110152
@@ -2059,42 +2101,10 @@ async def create(
20592101
**_vs_kwargs
20602102
)
20612103

2062-
# Default vector search tool: wraps the store's bound method so we can
2063-
# append real annotation IDs returned by the retrieval to the per-run
2064-
# citation accumulator on ``ctx.deps``. Pydantic-AI inspects the
2065-
# signature and injects ``ctx`` because its first parameter is typed
2066-
# as ``RunContext[PydanticAIDependencies]``. The tool name is
2067-
# preserved as ``similarity_search`` so existing event handlers that
2068-
# match on the tool name continue to work.
2069-
async def similarity_search(
2070-
ctx: RunContext[PydanticAIDependencies],
2071-
query: str,
2072-
k: int = 8,
2073-
modalities: Optional[list[str]] = None,
2074-
) -> list[dict[str, Any]]:
2075-
"""Semantic vector search over the corpus annotations.
2076-
2077-
Returns the top-k nearest annotations for ``query`` as a list of
2078-
dicts with keys ``annotation_id``, ``content``, ``document_id``,
2079-
``corpus_id``, ``page``, ``similarity_score``, ``label``, and
2080-
``json``. Each real annotation's ID is captured into
2081-
``ctx.deps.retrieved_annotation_ids`` so the caller can later link
2082-
citations to the owning object (e.g. ``Datacell.sources``).
2083-
"""
2084-
results = await vector_store.similarity_search(
2085-
query, k=k, modalities=modalities
2086-
)
2087-
for r in results:
2088-
if not isinstance(r, dict):
2089-
continue
2090-
aid = r.get("annotation_id")
2091-
# Real annotation PKs are positive ints; synthetic / ad-hoc
2092-
# match IDs are negative and must not be persisted.
2093-
if isinstance(aid, int) and aid > 0:
2094-
ctx.deps.retrieved_annotation_ids.append(aid)
2095-
return results
2096-
2097-
default_vs_tool: Callable = similarity_search
2104+
# See ``_make_similarity_search_tool`` for the citation-accumulation
2105+
# contract; the tool name remains ``similarity_search`` so existing
2106+
# event handlers that match on the tool name continue to work.
2107+
default_vs_tool: Callable = _make_similarity_search_tool(vector_store)
20982108

20992109
# -----------------------------
21002110
# Auto-build pure passthrough tools from registry
@@ -2598,36 +2608,9 @@ async def create(
25982608
**_vs_kwargs
25992609
)
26002610

2601-
# Default vector search tool: wraps the store's bound method to
2602-
# capture real annotation IDs returned during retrieval. See the
2603-
# equivalent wrapper in ``PydanticAIDocumentAgent.create`` for the
2604-
# rationale — we preserve the tool name ``similarity_search`` so
2605-
# downstream event / source handling is unaffected.
2606-
async def similarity_search(
2607-
ctx: RunContext[PydanticAIDependencies],
2608-
query: str,
2609-
k: int = 8,
2610-
modalities: Optional[list[str]] = None,
2611-
) -> list[dict[str, Any]]:
2612-
"""Semantic vector search over the corpus annotations.
2613-
2614-
Returns the top-k nearest annotations for ``query`` as dicts.
2615-
Appends every real annotation PK returned to
2616-
``ctx.deps.retrieved_annotation_ids`` so the caller can link
2617-
citations to the owning object after the run completes.
2618-
"""
2619-
results = await vector_store.similarity_search(
2620-
query, k=k, modalities=modalities
2621-
)
2622-
for r in results:
2623-
if not isinstance(r, dict):
2624-
continue
2625-
aid = r.get("annotation_id")
2626-
if isinstance(aid, int) and aid > 0:
2627-
ctx.deps.retrieved_annotation_ids.append(aid)
2628-
return results
2629-
2630-
default_vs_tool: Callable = similarity_search
2611+
# See ``_make_similarity_search_tool`` for the shared citation-capturing
2612+
# closure used by both the document and corpus agent factories.
2613+
default_vs_tool: Callable = _make_similarity_search_tool(vector_store)
26312614

26322615
# -----------------------------
26332616
# Auto-build passthrough tools from registry

opencontractserver/pipeline/embedders/openai_embedder.py

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,9 @@
44

55
import openai
66

7+
from opencontractserver.constants.document_processing import (
8+
OPENAI_EMBEDDER_MAX_INPUT_CHARS,
9+
)
710
from opencontractserver.constants.embeddings import (
811
DEFAULT_OPENAI_EMBEDDING_DIMENSIONS,
912
DEFAULT_OPENAI_EMBEDDING_MODEL,
@@ -184,23 +187,17 @@ def _embed_text_impl(self, text: str, **all_kwargs) -> Optional[list[float]]:
184187
)
185188
)
186189

187-
# OpenAI embeddings API caps input at 8192 tokens; a 400 "maximum
188-
# context length" is fatal to ingestion pipelines that produce
189-
# long chunks (e.g. whole-document summaries, un-capped paragraph
190-
# chunks of legalese). Local embedders like
191-
# ``sentence-transformers`` silently truncate via the tokenizer,
192-
# so users expect the same robustness here. Truncate on the char
193-
# side at ~4x the token budget (English averages ~4 chars/token)
194-
# to stay well under 8192 tokens for any realistic input.
195-
max_chars = 30000
196-
if len(text) > max_chars:
190+
# See OPENAI_EMBEDDER_MAX_INPUT_CHARS for the rationale behind the
191+
# truncation cap (mirrors the silent tokenizer truncation that
192+
# ``sentence-transformers`` applies locally).
193+
if len(text) > OPENAI_EMBEDDER_MAX_INPUT_CHARS:
197194
logger.warning(
198195
"OpenAIEmbedder truncating input from %d to %d chars to fit "
199196
"the 8192-token context window",
200197
len(text),
201-
max_chars,
198+
OPENAI_EMBEDDER_MAX_INPUT_CHARS,
202199
)
203-
text = text[:max_chars]
200+
text = text[:OPENAI_EMBEDDER_MAX_INPUT_CHARS]
204201

205202
client = self._build_client(**all_kwargs)
206203

opencontractserver/tasks/data_extract_tasks.py

Lines changed: 25 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
import json
33
import logging
44
import os
5+
from collections.abc import Sequence
6+
from typing import Any
57

68
from asgiref.sync import sync_to_async
79

@@ -162,36 +164,6 @@ def sync_get_corpus_id(document):
162164
return doc_path.corpus_id
163165
return None
164166

165-
@sync_to_async
166-
def sync_add_sources(datacell, sources):
167-
"""Add source annotations to datacell."""
168-
if sources:
169-
# Extract annotation IDs from SourceNode objects
170-
annotation_ids = [s.annotation_id for s in sources if s.annotation_id > 0]
171-
if annotation_ids:
172-
datacell.sources.add(*annotation_ids)
173-
174-
@sync_to_async
175-
def _link_retrieval_citations(datacell, annotation_ids):
176-
"""Link raw Annotation PKs retrieved by the agent to ``datacell.sources``.
177-
178-
Filters defensively: only positive ints that correspond to real
179-
Annotation rows are persisted. Duplicates are deduped by the M2M
180-
unique constraint, so ``add(*ids)`` with repeats is safe.
181-
"""
182-
from opencontractserver.annotations.models import Annotation
183-
184-
valid_ids = [a for a in annotation_ids if isinstance(a, int) and a > 0]
185-
if not valid_ids:
186-
return
187-
# Guard against IDs that don't exist (e.g. race with deletion).
188-
existing = set(
189-
Annotation.objects.filter(id__in=valid_ids).values_list("id", flat=True)
190-
)
191-
existing_ids = [aid for aid in valid_ids if aid in existing]
192-
if existing_ids:
193-
datacell.sources.add(*existing_ids)
194-
195167
# Initialize datacell to None to avoid UnboundLocalError
196168
datacell = None
197169

@@ -459,7 +431,29 @@ def _link_retrieval_citations(datacell, annotation_ids):
459431
raise
460432

461433

462-
def _classify_none_result(messages: object) -> tuple[str, str]:
434+
@sync_to_async
435+
def _link_retrieval_citations(datacell, annotation_ids):
436+
"""Link raw Annotation PKs retrieved by the agent to ``datacell.sources``.
437+
438+
Filters defensively: only positive ints that correspond to real
439+
Annotation rows are persisted. Duplicates are deduped by the M2M
440+
unique constraint, so ``add(*ids)`` with repeats is safe.
441+
"""
442+
from opencontractserver.annotations.models import Annotation
443+
444+
valid_ids = [a for a in annotation_ids if isinstance(a, int) and a > 0]
445+
if not valid_ids:
446+
return
447+
# Guard against IDs that don't exist (e.g. race with deletion).
448+
existing = set(
449+
Annotation.objects.filter(id__in=valid_ids).values_list("id", flat=True)
450+
)
451+
existing_ids = [aid for aid in valid_ids if aid in existing]
452+
if existing_ids:
453+
datacell.sources.add(*existing_ids)
454+
455+
456+
def _classify_none_result(messages: Sequence[Any] | None) -> tuple[str, str]:
463457
"""Categorise a ``result is None`` outcome from ``agent.run()``.
464458
465459
Reads the captured pydantic-ai message history (a list of

opencontractserver/tests/test_corpus_forking.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,18 @@ def test_forked_annotation_field_integrity(self):
336336
# Core data fields must match
337337
self.assertEqual(forked.page, orig.page)
338338
self.assertEqual(forked.raw_text, orig.raw_text)
339-
self.assertEqual(forked.json, orig.json)
339+
# Forked annotations are saved through Annotation.save(), which
340+
# auto-compacts v1 PAWLs JSON to v2 (issue: lazy migration, see
341+
# compact_annotation_json). Compare both sides in compact form so
342+
# the test is format-agnostic.
343+
from opencontractserver.annotations.compact_json import (
344+
compact_annotation_json,
345+
)
346+
347+
self.assertEqual(
348+
compact_annotation_json(forked.json),
349+
compact_annotation_json(orig.json),
350+
)
340351
self.assertEqual(forked.annotation_type, orig.annotation_type)
341352

342353
# Creator should propagate

opencontractserver/tests/test_embeddings_task.py

Lines changed: 36 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1294,16 +1294,22 @@ def test_microservice_embedder_handles_1d_array(self):
12941294
embedder = MicroserviceEmbedder()
12951295

12961296
# Simulate 1D response: [0.1, 0.2, 0.3]
1297-
with patch("requests.post") as mock_post:
1298-
mock_response = MagicMock()
1299-
mock_response.status_code = 200
1300-
mock_response.json.return_value = {"embeddings": [0.1, 0.2, 0.3]}
1301-
mock_post.return_value = mock_response
1302-
1297+
# PR #1380 routes embedder requests through a shared session, so
1298+
# patch the session getter instead of the global requests.post.
1299+
mock_session = MagicMock()
1300+
mock_response = MagicMock()
1301+
mock_response.status_code = 200
1302+
mock_response.json.return_value = {"embeddings": [0.1, 0.2, 0.3]}
1303+
mock_session.post.return_value = mock_response
1304+
1305+
with patch(
1306+
"opencontractserver.pipeline.embedders.sent_transformer_microservice._get_session",
1307+
return_value=mock_session,
1308+
):
13031309
result = embedder.embed_text("test")
13041310

1305-
self.assertEqual(result, [0.1, 0.2, 0.3])
1306-
self.assertIsInstance(result, list)
1311+
self.assertEqual(result, [0.1, 0.2, 0.3])
1312+
self.assertIsInstance(result, list)
13071313

13081314
def test_microservice_embedder_handles_2d_array(self):
13091315
"""Test that MicroserviceEmbedder correctly handles 2D array responses."""
@@ -1314,16 +1320,20 @@ def test_microservice_embedder_handles_2d_array(self):
13141320
embedder = MicroserviceEmbedder()
13151321

13161322
# Simulate 2D response: [[0.1, 0.2, 0.3]]
1317-
with patch("requests.post") as mock_post:
1318-
mock_response = MagicMock()
1319-
mock_response.status_code = 200
1320-
mock_response.json.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
1321-
mock_post.return_value = mock_response
1322-
1323+
mock_session = MagicMock()
1324+
mock_response = MagicMock()
1325+
mock_response.status_code = 200
1326+
mock_response.json.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
1327+
mock_session.post.return_value = mock_response
1328+
1329+
with patch(
1330+
"opencontractserver.pipeline.embedders.sent_transformer_microservice._get_session",
1331+
return_value=mock_session,
1332+
):
13231333
result = embedder.embed_text("test")
13241334

1325-
self.assertEqual(result, [0.1, 0.2, 0.3])
1326-
self.assertIsInstance(result, list)
1335+
self.assertEqual(result, [0.1, 0.2, 0.3])
1336+
self.assertIsInstance(result, list)
13271337

13281338
def test_multimodal_embedder_handles_1d_array(self):
13291339
"""Test that CLIPMicroserviceEmbedder correctly handles 1D array responses."""
@@ -1377,15 +1387,19 @@ def test_microservice_embedder_settings_none_fallback(self):
13771387
# Force settings to None to exercise the fallback path
13781388
embedder._settings = None
13791389

1380-
with patch("requests.post") as mock_post:
1381-
mock_response = MagicMock()
1382-
mock_response.status_code = 200
1383-
mock_response.json.return_value = {"embeddings": [0.1, 0.2, 0.3]}
1384-
mock_post.return_value = mock_response
1390+
mock_session = MagicMock()
1391+
mock_response = MagicMock()
1392+
mock_response.status_code = 200
1393+
mock_response.json.return_value = {"embeddings": [0.1, 0.2, 0.3]}
1394+
mock_session.post.return_value = mock_response
13851395

1396+
with patch(
1397+
"opencontractserver.pipeline.embedders.sent_transformer_microservice._get_session",
1398+
return_value=mock_session,
1399+
):
13861400
result = embedder.embed_text("test")
13871401

1388-
self.assertEqual(result, [0.1, 0.2, 0.3])
1402+
self.assertEqual(result, [0.1, 0.2, 0.3])
13891403

13901404
def test_clip_embedder_settings_none_fallback(self):
13911405
"""CLIPMicroserviceEmbedder._get_service_config falls back to Settings()."""

0 commit comments

Comments (0)