Skip to content

Commit 6689081

Browse files
committed
Boost benchmark/reranker/embedder coverage past codecov target
Patch coverage on PR #1380 was 85.02% (target 87.30%). Added: - run_benchmark management command: full happy-path test (real adapter, mocked runner) plus retrieval-only/corpus-wide flag passthrough and user-not-found error path. Brings ~50 lines of run_benchmark.py from 0% coverage to fully covered. - CohereReranker fault-tolerance: request-exception, non-200, non-JSON body, missing 'results' key, malformed item skipping, and empty results all fall back to identity ordering. - CrossEncoderReranker (no prior tests): success path with score extraction, max_length forwarding, padding when the model returns fewer scores than passages, scalar→list normalization, and the per-key model cache fast path. Uses an injected stub model so CI runs without sentence-transformers/torch installed. - OpenAI embedder truncation: oversize single-text and oversize batch inputs are clipped to OPENAI_EMBEDDER_MAX_INPUT_CHARS before the wire call; empty/whitespace batch inputs become None slots.
1 parent 3ff12bb commit 6689081

3 files changed

Lines changed: 338 additions & 0 deletions

File tree

opencontractserver/tests/test_benchmarks.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,3 +842,114 @@ def test_sampling_drops_malformed_tests(self):
842842
self.assertEqual(len(out), 194)
843843
for t in out:
844844
self.assertTrue(t["snippets"][0].get("file_path"))
845+
846+
847+
# --------------------------------------------------------------------------- #
848+
# Management command tests (``python manage.py run_benchmark``)
849+
# --------------------------------------------------------------------------- #
850+
851+
852+
class RunBenchmarkCommandTest(TransactionTestCase):
853+
"""Cover the CLI entry point so user lookup, adapter wiring and
854+
aggregate printing are exercised end-to-end with the runner mocked.
855+
"""
856+
857+
def test_user_not_found_raises(self):
858+
from django.core.management import call_command
859+
from django.core.management.base import CommandError
860+
861+
with self.assertRaisesRegex(CommandError, "not found"):
862+
call_command(
863+
"run_benchmark",
864+
"--path=/tmp/does-not-matter",
865+
"--user=nope-no-such-user",
866+
)
867+
868+
@patch(
869+
"opencontractserver.benchmarks.management.commands.run_benchmark."
870+
"run_benchmark"
871+
)
872+
def test_happy_path_invokes_runner_and_prints_aggregates(self, mock_run):
873+
from io import StringIO
874+
875+
from django.core.management import call_command
876+
877+
from opencontractserver.benchmarks.report import BenchmarkReport
878+
879+
user = User.objects.create_user(username="bench-cli-user")
880+
881+
# Use the real BenchmarkReport so __post_init__ populates aggregates
882+
# — covers the float/dict/value formatting branches in handle().
883+
report = BenchmarkReport(
884+
adapter={"name": "test"},
885+
config={"model": "test:m"},
886+
corpus_id=42,
887+
extract_id=7,
888+
task_results=[],
889+
run_dir=Path("/tmp/run-x"),
890+
)
891+
mock_run.return_value = report
892+
893+
out = StringIO()
894+
call_command(
895+
"run_benchmark",
896+
f"--path={MICRO_FIXTURE}",
897+
"--user=bench-cli-user",
898+
"--top-k=3",
899+
"--limit=5",
900+
stdout=out,
901+
)
902+
903+
# Runner called with the parsed CLI options.
904+
run_kwargs = mock_run.call_args.kwargs
905+
self.assertEqual(run_kwargs["top_k"], 3)
906+
self.assertEqual(run_kwargs["user"], user)
907+
self.assertFalse(run_kwargs["retrieval_only"])
908+
self.assertFalse(run_kwargs["corpus_wide"])
909+
# The adapter passed to run_benchmark is a real LegalBenchRAGAdapter.
910+
from opencontractserver.benchmarks.adapters.legalbench_rag import (
911+
LegalBenchRAGAdapter,
912+
)
913+
914+
self.assertIsInstance(run_kwargs["adapter"], LegalBenchRAGAdapter)
915+
916+
text = out.getvalue()
917+
self.assertIn("Benchmark run complete", text)
918+
self.assertIn("Corpus ID: 42", text)
919+
self.assertIn("Extract ID: 7", text)
920+
self.assertIn("Report dir: /tmp/run-x", text)
921+
# Aggregate lines emitted (BenchmarkReport.__post_init__ populates
922+
# task_count and float metrics, exercising the int-else and float
923+
# branches of the handle() output loop).
924+
self.assertIn("task_count", text)
925+
self.assertIn("answer_token_f1", text)
926+
927+
@patch(
928+
"opencontractserver.benchmarks.management.commands.run_benchmark."
929+
"run_benchmark"
930+
)
931+
def test_retrieval_only_and_corpus_wide_flags_pass_through(self, mock_run):
932+
from django.core.management import call_command
933+
934+
from opencontractserver.benchmarks.report import BenchmarkReport
935+
936+
User.objects.create_user(username="bench-flag-user")
937+
mock_run.return_value = BenchmarkReport(
938+
adapter={},
939+
config={},
940+
corpus_id=1,
941+
extract_id=1,
942+
task_results=[],
943+
run_dir=None, # exercise the no-run-dir branch in handle()
944+
)
945+
call_command(
946+
"run_benchmark",
947+
f"--path={MICRO_FIXTURE}",
948+
"--user=bench-flag-user",
949+
"--retrieval-only",
950+
"--corpus-wide",
951+
"--no-paper-sampling",
952+
)
953+
run_kwargs = mock_run.call_args.kwargs
954+
self.assertTrue(run_kwargs["retrieval_only"])
955+
self.assertTrue(run_kwargs["corpus_wide"])

opencontractserver/tests/test_openai_embedder.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,74 @@ def test_embed_image_not_supported(self):
299299
result = embedder.embed_image("base64data", "jpeg")
300300
self.assertIsNone(result)
301301

302+
@patch("opencontractserver.pipeline.embedders.openai_embedder.openai.OpenAI")
303+
def test_embed_text_truncates_oversize_input(self, mock_openai_cls):
304+
"""Inputs longer than OPENAI_EMBEDDER_MAX_INPUT_CHARS are truncated."""
305+
from opencontractserver.constants.document_processing import (
306+
OPENAI_EMBEDDER_MAX_INPUT_CHARS,
307+
)
308+
309+
fake_embedding = [0.1] * DEFAULT_OPENAI_EMBEDDING_DIMENSIONS
310+
mock_client = MagicMock()
311+
mock_client.embeddings.create.return_value = self._make_mock_response(
312+
fake_embedding
313+
)
314+
mock_openai_cls.return_value = mock_client
315+
316+
oversize = "a" * (OPENAI_EMBEDDER_MAX_INPUT_CHARS + 5_000)
317+
embedder = OpenAIEmbedder()
318+
result = embedder.embed_text(oversize, openai_api_key="test-key")
319+
320+
self.assertIsNotNone(result)
321+
sent = mock_client.embeddings.create.call_args.kwargs["input"]
322+
self.assertEqual(len(sent), OPENAI_EMBEDDER_MAX_INPUT_CHARS)
323+
324+
@patch("opencontractserver.pipeline.embedders.openai_embedder.openai.OpenAI")
325+
def test_embed_texts_batch_truncates_and_skips_blanks(self, mock_openai_cls):
326+
"""Batch path: blanks become None slots, oversize inputs are clipped."""
327+
from opencontractserver.constants.document_processing import (
328+
OPENAI_EMBEDDER_MAX_INPUT_CHARS,
329+
)
330+
331+
oversize = "z" * (OPENAI_EMBEDDER_MAX_INPUT_CHARS + 1_000)
332+
texts = ["hi", "", oversize, " "]
333+
# Two non-empty inputs survive — return one fake embedding for each.
334+
fake_a = [0.1] * DEFAULT_OPENAI_EMBEDDING_DIMENSIONS
335+
fake_b = [0.2] * DEFAULT_OPENAI_EMBEDDING_DIMENSIONS
336+
mock_response = MagicMock()
337+
d0, d1 = MagicMock(), MagicMock()
338+
d0.embedding = fake_a
339+
d1.embedding = fake_b
340+
mock_response.data = [d0, d1]
341+
mock_client = MagicMock()
342+
mock_client.embeddings.create.return_value = mock_response
343+
mock_openai_cls.return_value = mock_client
344+
345+
embedder = OpenAIEmbedder()
346+
result = embedder.embed_texts_batch(texts, openai_api_key="test-key")
347+
348+
self.assertEqual(len(result), 4)
349+
self.assertEqual(result[0], fake_a) # "hi"
350+
self.assertIsNone(result[1]) # "" filtered out
351+
self.assertEqual(result[2], fake_b) # oversize → truncated then sent
352+
self.assertIsNone(result[3]) # whitespace-only filtered out
353+
354+
# Confirm the wire payload carried only the two surviving inputs and
355+
# that the oversize one was truncated to the cap.
356+
sent = mock_client.embeddings.create.call_args.kwargs["input"]
357+
self.assertEqual(len(sent), 2)
358+
self.assertEqual(sent[0], "hi")
359+
self.assertEqual(len(sent[1]), OPENAI_EMBEDDER_MAX_INPUT_CHARS)
360+
361+
def test_embed_texts_batch_empty_returns_empty_list(self):
362+
embedder = OpenAIEmbedder()
363+
self.assertEqual(embedder.embed_texts_batch([]), [])
364+
365+
def test_embed_texts_batch_all_blank_returns_all_none(self):
366+
embedder = OpenAIEmbedder()
367+
result = embedder.embed_texts_batch(["", " ", None]) # type: ignore[list-item]
368+
self.assertEqual(result, [None, None, None])
369+
302370

303371
class TestOpenAIEmbedderDiscovery(TestCase):
304372
"""Tests that OpenAIEmbedder is properly discovered by the registry."""

opencontractserver/tests/test_reranker.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,165 @@ def test_missing_api_key_returns_identity(self) -> None:
298298
results = reranker.rerank("query", ["a", "b"])
299299
self.assertEqual([r.index for r in results], [0, 1])
300300

301+
def test_request_exception_falls_back_to_identity(self) -> None:
302+
reranker = self._reranker()
303+
with patch(
304+
"requests.post",
305+
side_effect=requests.exceptions.ConnectionError("network down"),
306+
):
307+
results = reranker.rerank("query", ["p0", "p1"])
308+
self.assertEqual([r.index for r in results], [0, 1])
309+
310+
def test_non_200_status_falls_back_to_identity(self) -> None:
311+
reranker = self._reranker()
312+
with patch("requests.post", return_value=_FakeResponse(status_code=503)):
313+
results = reranker.rerank("query", ["p0", "p1"])
314+
self.assertEqual([r.index for r in results], [0, 1])
315+
316+
def test_non_json_body_falls_back_to_identity(self) -> None:
317+
class _BadJsonResponse:
318+
status_code = 200
319+
text = "not json"
320+
321+
def json(self):
322+
raise ValueError("bad json")
323+
324+
reranker = self._reranker()
325+
with patch("requests.post", return_value=_BadJsonResponse()):
326+
results = reranker.rerank("query", ["p0", "p1"])
327+
self.assertEqual([r.index for r in results], [0, 1])
328+
329+
def test_missing_results_key_falls_back_to_identity(self) -> None:
330+
reranker = self._reranker()
331+
with patch(
332+
"requests.post",
333+
return_value=_FakeResponse(status_code=200, payload={"nope": []}),
334+
):
335+
results = reranker.rerank("query", ["p0", "p1"])
336+
self.assertEqual([r.index for r in results], [0, 1])
337+
338+
def test_malformed_items_are_skipped(self) -> None:
339+
reranker = self._reranker()
340+
fake = _FakeResponse(
341+
status_code=200,
342+
payload={
343+
"results": [
344+
{"index": 0, "relevance_score": "not-a-float"},
345+
{"index": "bad", "relevance_score": 0.9},
346+
{"missing_keys": True},
347+
{"index": 1, "relevance_score": 0.42},
348+
]
349+
},
350+
)
351+
with patch("requests.post", return_value=fake):
352+
results = reranker.rerank("query", ["p0", "p1"])
353+
# Only the well-formed item survives.
354+
self.assertEqual([r.index for r in results], [1])
355+
356+
def test_empty_results_falls_back_to_identity(self) -> None:
357+
reranker = self._reranker()
358+
fake = _FakeResponse(status_code=200, payload={"results": []})
359+
with patch("requests.post", return_value=fake):
360+
results = reranker.rerank("query", ["p0", "p1"])
361+
self.assertEqual([r.index for r in results], [0, 1])
362+
363+
364+
# --------------------------------------------------------------------------- #
365+
# Cross-encoder reranker tests (uses an injected stub model so no
366+
# ``sentence-transformers`` / ``torch`` install is required in CI).
367+
# --------------------------------------------------------------------------- #
368+
369+
370+
class _StubCrossEncoder:
371+
"""Stand-in for sentence_transformers.CrossEncoder used in tests."""
372+
373+
def __init__(self, scores):
374+
self._scores = scores
375+
self.max_length: int | None = None
376+
self.last_pairs: list = []
377+
378+
def predict(self, pairs, batch_size=None, show_progress_bar=False):
379+
self.last_pairs = list(pairs)
380+
return self._scores
381+
382+
383+
class CrossEncoderRerankerTest(TestCase):
384+
def setUp(self) -> None:
385+
from opencontractserver.pipeline.rerankers import cross_encoder_reranker
386+
387+
# Reset the module-level cache between tests so each gets a fresh
388+
# stub model.
389+
cross_encoder_reranker._MODEL_CACHE.clear()
390+
391+
def _reranker(self):
392+
from opencontractserver.pipeline.rerankers.cross_encoder_reranker import (
393+
CrossEncoderReranker,
394+
)
395+
396+
rk = CrossEncoderReranker()
397+
rk._settings = CrossEncoderReranker.Settings(
398+
model_name="stub-model",
399+
device="cpu",
400+
batch_size=8,
401+
max_length=128,
402+
)
403+
return rk
404+
405+
def test_successful_rerank_returns_scores(self) -> None:
406+
stub = _StubCrossEncoder([0.1, 0.9, 0.5])
407+
with patch(
408+
"opencontractserver.pipeline.rerankers.cross_encoder_reranker."
409+
"_load_cross_encoder",
410+
return_value=stub,
411+
):
412+
results = self._reranker().rerank("q", ["a", "b", "c"])
413+
scores = {r.index: r.score for r in results}
414+
self.assertAlmostEqual(scores[0], 0.1)
415+
self.assertAlmostEqual(scores[1], 0.9)
416+
self.assertAlmostEqual(scores[2], 0.5)
417+
# max_length is forwarded onto the underlying model.
418+
self.assertEqual(stub.max_length, 128)
419+
# Empty/None passages are normalized to "" before pairing.
420+
self.assertEqual(stub.last_pairs[0], ("q", "a"))
421+
422+
def test_pads_when_model_returns_too_few_scores(self) -> None:
423+
stub = _StubCrossEncoder([0.7]) # only 1 score for 3 passages
424+
with patch(
425+
"opencontractserver.pipeline.rerankers.cross_encoder_reranker."
426+
"_load_cross_encoder",
427+
return_value=stub,
428+
):
429+
results = self._reranker().rerank("q", ["a", "b", "c"])
430+
# First passage keeps the real score; the rest get -inf padding.
431+
self.assertAlmostEqual(results[0].score, 0.7)
432+
self.assertEqual(results[1].score, float("-inf"))
433+
self.assertEqual(results[2].score, float("-inf"))
434+
435+
def test_scalar_score_is_normalized_to_list(self) -> None:
436+
# Some single-pair responses come back as a 0-D scalar instead of a
437+
# sequence; the reranker normalizes by wrapping it.
438+
stub = _StubCrossEncoder(0.42)
439+
with patch(
440+
"opencontractserver.pipeline.rerankers.cross_encoder_reranker."
441+
"_load_cross_encoder",
442+
return_value=stub,
443+
):
444+
results = self._reranker().rerank("q", ["only"])
445+
self.assertEqual(len(results), 1)
446+
self.assertAlmostEqual(results[0].score, 0.42)
447+
448+
def test_load_cross_encoder_caches_model_per_key(self) -> None:
449+
from opencontractserver.pipeline.rerankers import cross_encoder_reranker
450+
451+
sentinel = object()
452+
with patch(
453+
"opencontractserver.pipeline.rerankers.cross_encoder_reranker."
454+
"_MODEL_CACHE",
455+
{("model-a", "cpu"): sentinel},
456+
):
457+
got = cross_encoder_reranker._load_cross_encoder("model-a", "cpu")
458+
self.assertIs(got, sentinel)
459+
301460

302461
# --------------------------------------------------------------------------- #
303462
# Pipeline utility tests

0 commit comments

Comments (0)