Skip to content

Commit 31fe44b

Browse files
committed
Add unit tests for retrieval-citation linking, none-result classifier, reranker
- _link_retrieval_citations: 4 tests covering happy path, defensive filtering of non-int / negative ids, missing-id graceful fallback, and noop-on-empty
- _classify_none_result: 7 tests pinning each classification (empty, empty_history with messages-but-no-response, agent_committed_none for text and output_tool parts, no_final_response for single tool call and thinking, tool_loop_no_output for repeated tool calls, and the text-after-loop precedence rule)
- CrossEncoderReranker._rerank_impl: 3 tests covering happy-path scoring, scalar-response normalization, and length-mismatch -inf padding — uses a mocked CrossEncoder loader so CI doesn't download weights
1 parent a2aa94d commit 31fe44b

1 file changed

Lines changed: 327 additions & 0 deletions

File tree

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
"""Targeted unit tests for the small helpers introduced alongside the
2+
benchmark harness: retrieval-citation linking and the
3+
``result is None`` failure-mode classifier.
4+
5+
These are pure-Python helpers that don't go through the agent runtime,
6+
so they can be exercised with mocked message logs and lightweight
7+
fixtures without spinning up a full extraction.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from types import SimpleNamespace
13+
from unittest.mock import MagicMock
14+
15+
from django.contrib.auth import get_user_model
16+
from django.test import TestCase
17+
18+
from opencontractserver.annotations.models import (
19+
SPAN_LABEL,
20+
Annotation,
21+
AnnotationLabel,
22+
)
23+
from opencontractserver.corpuses.models import Corpus
24+
from opencontractserver.documents.models import Document
25+
from opencontractserver.extracts.models import (
26+
Column,
27+
Datacell,
28+
Extract,
29+
Fieldset,
30+
)
31+
from opencontractserver.tasks.data_extract_tasks import _classify_none_result
32+
33+
User = get_user_model()
34+
35+
36+
def _build_datacell_with_annotations(test):
    """Create the minimal user/corpus/document/extract object graph.

    Returns a ``(datacell, annotation_ids)`` tuple wired up just enough
    for the citation-linking tests.  ``test`` is accepted for call-site
    symmetry with ``TestCase`` helpers and is not otherwise used.
    """
    owner = User.objects.create_user(
        username="extract_helpers_user", password="testpass"
    )

    corpus = Corpus.objects.create(title="ExtractHelpers Corpus", creator=owner)
    doc = Document.objects.create(
        title="ExtractHelpers Doc",
        creator=owner,
        file_type="text/plain",
    )
    corpus.add_document(document=doc, user=owner)

    span_label = AnnotationLabel.objects.create(
        text="ExtractHelpersLabel", creator=owner, label_type=SPAN_LABEL
    )

    # Three small span annotations to act as retrieval "hits".
    created = []
    for offset in range(3):
        created.append(
            Annotation.objects.create(
                document=doc,
                corpus=corpus,
                annotation_label=span_label,
                annotation_type=SPAN_LABEL,
                raw_text=f"hit {offset}",
                json={"start": offset, "end": offset + 4},
                creator=owner,
                page=1,
            )
        )

    fieldset = Fieldset.objects.create(name="ExtractHelpers FS", creator=owner)
    column = Column.objects.create(
        fieldset=fieldset,
        name="Hits",
        query="anything",
        output_type="str",
        creator=owner,
    )
    extract = Extract.objects.create(
        name="ExtractHelpers Extract",
        corpus=corpus,
        fieldset=fieldset,
        creator=owner,
    )
    datacell = Datacell.objects.create(
        extract=extract,
        column=column,
        document=doc,
        creator=owner,
        data={"data": "anything"},
    )
    return datacell, [annotation.id for annotation in created]
89+
90+
91+
class LinkRetrievalCitationsTests(TestCase):
    """Cover ``_link_retrieval_citations``'s defensive filtering path."""

    @staticmethod
    def _call_linker(datacell, annotation_ids):
        """Run the async linker synchronously against *datacell*.

        Imports stay deferred (as in the original per-test bodies) so the
        task module is only loaded when a test actually runs.
        """
        from asgiref.sync import async_to_sync

        from opencontractserver.tasks.data_extract_tasks import (
            _link_retrieval_citations,
        )

        async_to_sync(_link_retrieval_citations)(datacell, annotation_ids)

    @staticmethod
    def _linked_ids(datacell):
        """Return the set of linked source ids after a DB refresh."""
        datacell.refresh_from_db()
        return set(datacell.sources.values_list("id", flat=True))

    def test_real_ids_are_linked_to_sources(self) -> None:
        datacell, annotation_ids = _build_datacell_with_annotations(self)

        self._call_linker(datacell, annotation_ids)

        self.assertEqual(self._linked_ids(datacell), set(annotation_ids))

    def test_non_int_and_negative_ids_are_dropped(self) -> None:
        datacell, annotation_ids = _build_datacell_with_annotations(self)

        # Inject a hostile mix: floats, negative ints, strings, valid ids
        self._call_linker(datacell, [None, -1, 0, "5", 1.5, *annotation_ids])

        # Only the real positive ints survive the filter
        self.assertEqual(self._linked_ids(datacell), set(annotation_ids))

    def test_missing_ids_silently_ignored(self) -> None:
        """A row deleted between retrieval and link must not blow up."""
        datacell, annotation_ids = _build_datacell_with_annotations(self)
        # Reference an annotation id that doesn't exist plus the real ones
        bogus_id = max(annotation_ids) + 9999

        self._call_linker(datacell, [bogus_id, *annotation_ids])

        # The bogus id is silently filtered; real ids still link
        self.assertEqual(self._linked_ids(datacell), set(annotation_ids))

    def test_empty_input_is_a_noop(self) -> None:
        datacell, _ = _build_datacell_with_annotations(self)

        self._call_linker(datacell, [])
        self._call_linker(datacell, [None, "abc", -1])

        datacell.refresh_from_db()
        self.assertEqual(datacell.sources.count(), 0)
168+
169+
170+
def _response_msg(part_kinds):
171+
"""Build a minimal duck-typed ``response``-kind message.
172+
173+
The classifier only reads ``msg.kind`` and ``msg.parts[i].part_kind``,
174+
so a ``SimpleNamespace`` is enough — no need to drag in pydantic-ai's
175+
real ``ModelResponse`` and its strict validation.
176+
"""
177+
parts = [SimpleNamespace(part_kind=kind) for kind in part_kinds]
178+
return SimpleNamespace(kind="response", parts=parts)
179+
180+
181+
class ClassifyNoneResultTests(TestCase):
    """Cover the four failure-mode classifications the agent emits."""

    @staticmethod
    def _mode_for(messages):
        """Classify *messages* and return only the mode string."""
        mode, _ = _classify_none_result(messages)
        return mode

    def test_no_messages_is_empty_history(self) -> None:
        mode, detail = _classify_none_result(None)
        self.assertEqual(mode, "empty_history")
        self.assertIn("no messages", detail)

        mode, detail = _classify_none_result([])
        self.assertEqual(mode, "empty_history")

    def test_no_response_messages_is_empty_history(self) -> None:
        """Messages exist, but none of them are ``response``-kind."""
        history = [SimpleNamespace(kind="request", parts=[])]
        mode, detail = _classify_none_result(history)
        self.assertEqual(mode, "empty_history")
        self.assertIn("no response messages", detail)

    def test_text_only_response_is_committed_none(self) -> None:
        """Last response carries a text part → model committed."""
        self.assertEqual(
            self._mode_for([_response_msg(["text"])]), "agent_committed_none"
        )

    def test_output_tool_part_is_committed_none(self) -> None:
        """``output_tool`` parts (final structured response) → committed."""
        self.assertEqual(
            self._mode_for([_response_msg(["output_tool"])]),
            "agent_committed_none",
        )

    def test_single_tool_call_only_is_no_final(self) -> None:
        """One response that ends on a tool call never reached final."""
        self.assertEqual(
            self._mode_for([_response_msg(["tool-call"])]), "no_final_response"
        )

    def test_repeated_tool_call_only_is_tool_loop(self) -> None:
        """Multiple response messages, all tool-call parts, no final."""
        history = [_response_msg(["tool-call"]) for _ in range(3)]
        self.assertEqual(self._mode_for(history), "tool_loop_no_output")

    def test_thinking_only_is_no_final_response(self) -> None:
        """``thinking`` parts don't count as final output (they're internal)."""
        self.assertEqual(
            self._mode_for([_response_msg(["thinking"])]), "no_final_response"
        )

    def test_text_after_tool_loop_is_committed(self) -> None:
        """If the *last* response has a text part, that's commitment."""
        history = [
            _response_msg(["tool-call"]),
            _response_msg(["tool-call"]),
            _response_msg(["text"]),
        ]
        self.assertEqual(self._mode_for(history), "agent_committed_none")
242+
243+
244+
class CrossEncoderRerankerTests(TestCase):
    """Light coverage for ``CrossEncoderReranker._rerank_impl``.

    The cross-encoder weights are large and we don't want to download
    them in CI, so this exercises the scoring/ranking logic with a
    mocked ``CrossEncoder`` backend.
    """

    @staticmethod
    def _rerank_with_fake_scores(predict_return, query, passages):
        """Run ``_rerank_impl`` with ``_load_cross_encoder`` stubbed out.

        The stub returns a ``MagicMock`` whose ``predict`` yields
        *predict_return*, so no model weights are ever downloaded.
        ``patch.object`` guarantees the real loader is restored even if
        construction or reranking raises, replacing the previous
        hand-rolled assign/try/finally swap (and its misapplied
        ``# noqa: E731``) that was duplicated in every test.
        """
        from unittest.mock import patch

        from opencontractserver.pipeline.rerankers import cross_encoder_reranker

        fake_model = MagicMock()
        fake_model.predict.return_value = predict_return

        with patch.object(
            cross_encoder_reranker,
            "_load_cross_encoder",
            lambda model_name, device: fake_model,
        ):
            reranker = cross_encoder_reranker.CrossEncoderReranker()
            return reranker._rerank_impl(query=query, passages=passages)

    def test_scores_sort_passages_by_relevance(self) -> None:
        # Simulate scores: passage 0 (hit), 1 (miss), 2 (best hit)
        results = self._rerank_with_fake_scores(
            [0.4, 0.05, 0.95],
            query="capital of france",
            passages=["paris is the capital", "lyon", "paris france capital"],
        )

        # Result indices preserve input ordering; caller sorts by score.
        scores = {r.index: r.score for r in results}
        self.assertEqual(scores[0], 0.4)
        self.assertEqual(scores[1], 0.05)
        self.assertEqual(scores[2], 0.95)

    def test_scalar_score_response_is_normalized(self) -> None:
        """Single-pair scoring may come back as a numpy scalar — handle it."""
        # Some backends return a 0-d scalar instead of a length-1 list
        results = self._rerank_with_fake_scores(
            0.7, query="anything", passages=["only one passage"]
        )

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].score, 0.7)

    def test_score_count_mismatch_pads_with_neg_inf(self) -> None:
        """If predict returns fewer scores than passages, pad defensively."""
        # Only one score for two passages
        results = self._rerank_with_fake_scores(
            [0.5], query="anything", passages=["one", "two"]
        )

        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].score, 0.5)
        # Padded entries land at -inf so they sort to the bottom
        self.assertEqual(results[1].score, float("-inf"))
324+
325+
326+
# NOTE(review): the former "_ = SimpleNamespace" unused-import shim was
# removed — SimpleNamespace is genuinely used by _response_msg above, so
# the import triggers no warning and the shim (and its comment claiming
# otherwise) was dead code.

0 commit comments

Comments
 (0)