Skip to content

Commit 7d5dfe6

Browse files
committed
Fix CI: pre-commit auto-fixes + mypy ignore on dynamic adapter kwargs
- trailing-whitespace + end-of-file-fixer applied to committed benchmark run artifacts (CSV/JSON in docs/benchmarks/runs/) - black + isort reformatted the perf-tuning files (no logic change) - mypy: narrow type: ignore[arg-type] on run_benchmark.py:198 'adapter_cls(**adapter_kwargs)' — adapter_kwargs is dict[str, object] sourced from argparse and the runtime types match each adapter's signature, but mypy can't statically narrow across the polymorphic adapter union
1 parent 66b1f5c commit 7d5dfe6

12 files changed

Lines changed: 64 additions & 135 deletions

File tree

opencontractserver/benchmarks/adapters/legalbench_rag.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,7 @@ def _ensure_loaded(self) -> None:
244244
tests = payload.get("tests", [])
245245
if self.paper_sampling:
246246
pre_count = len(tests)
247-
tests = _paper_sample_tests(
248-
tests, max_per_subset=self.max_per_subset
249-
)
247+
tests = _paper_sample_tests(tests, max_per_subset=self.max_per_subset)
250248
logger.info(
251249
"LegalBench-RAG paper-faithful sampling for subset %s: "
252250
"%d -> %d tasks (sorted by random(seed=file_path), cap=%d)",

opencontractserver/benchmarks/management/commands/run_benchmark.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,10 @@ def handle(self, *args, **options) -> None:
195195
if benchmark_name == "legalbench-rag":
196196
adapter_kwargs["paper_sampling"] = options.get("paper_sampling", True)
197197
adapter_kwargs["max_per_subset"] = options.get("max_per_subset", 194)
198-
adapter = adapter_cls(**adapter_kwargs)
198+
# mypy can't statically narrow ``adapter_kwargs: dict[str, object]``
199+
# against each adapter subclass's specific parameter types. The dict
200+
# values are sourced from argparse, which already validated them.
201+
adapter = adapter_cls(**adapter_kwargs) # type: ignore[arg-type]
199202

200203
self.stdout.write(
201204
self.style.NOTICE(

opencontractserver/documents/signals.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,7 @@ def _gc_orphan_structural_set(sender, instance, **kwargs):
194194
set_id = getattr(instance, "_structural_set_id_at_delete", None)
195195
if set_id is None:
196196
return
197-
StructuralAnnotationSet = apps.get_model(
198-
"annotations", "StructuralAnnotationSet"
199-
)
197+
StructuralAnnotationSet = apps.get_model("annotations", "StructuralAnnotationSet")
200198
Document = apps.get_model("documents", "Document")
201199
if Document.objects.filter(structural_annotation_set_id=set_id).exists():
202200
return

opencontractserver/pipeline/embedders/openai_embedder.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,11 @@ def _embed_text_impl(self, text: str, **all_kwargs) -> Optional[list[float]]:
229229
except openai.BadRequestError as e:
230230
logger.error(f"OpenAI API bad request: {e}")
231231
return None
232-
except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError):
232+
except (
233+
openai.RateLimitError,
234+
openai.APITimeoutError,
235+
openai.APIConnectionError,
236+
):
233237
# Transient: re-raise so callers can retry. See the matching
234238
# block in ``embed_texts_batch`` for the rationale.
235239
logger.warning(
@@ -304,9 +308,7 @@ def embed_texts_batch( # type: ignore[override]
304308
all_kwargs = {**self.get_component_settings(), **direct_kwargs}
305309
model = all_kwargs.get("openai_embedding_model", s.openai_embedding_model)
306310
dimensions = int(
307-
all_kwargs.get(
308-
"openai_embedding_dimensions", s.openai_embedding_dimensions
309-
)
311+
all_kwargs.get("openai_embedding_dimensions", s.openai_embedding_dimensions)
310312
)
311313

312314
try:
@@ -336,15 +338,21 @@ def embed_texts_batch( # type: ignore[override]
336338
return out
337339
except openai.AuthenticationError:
338340
# Permanent: a wrong API key won't fix itself with retry.
339-
logger.error("OpenAI API authentication failed (batch). Check your API key.")
341+
logger.error(
342+
"OpenAI API authentication failed (batch). Check your API key."
343+
)
340344
return None
341345
except openai.BadRequestError as e:
342346
# Permanent: malformed input (oversize, bad dimensions, etc.).
343347
# Returning None prevents celery from burning retries on
344348
# something that will fail every time.
345349
logger.error("OpenAI API bad request (batch): %s", e)
346350
return None
347-
except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError):
351+
except (
352+
openai.RateLimitError,
353+
openai.APITimeoutError,
354+
openai.APIConnectionError,
355+
):
348356
# Transient: re-raise so the celery task's autoretry_for=(Exception,)
349357
# can take over with proper backoff. The OpenAI SDK already
350358
# absorbed up to OPENAI_CLIENT_MAX_RETRIES of these

opencontractserver/pipeline/parsers/text_chunkers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,7 @@ def chunk(self, text: str) -> Iterator[TextChunk]:
250250
# NaN and aborts the entire ingest pipeline. Observed in CUAD documents
251251
# (e.g. JuniperPharmaceuticalsInc_…) where copy-paste artifacts left
252252
# runs of ``​`` characters between real paragraphs.
253-
_INVISIBLE_CHARS_RE = re.compile(
254-
r"[   -‏
-  -⁠ ­]"
255-
)
253+
_INVISIBLE_CHARS_RE = re.compile(r"[   -‏
-  -⁠ ­]")
256254

257255

258256
@register_chunker

opencontractserver/tasks/embeddings_task.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,7 @@ def _batch_embed_text_annotations(
483483

484484
# Carve into sub-batches up front so we can fan them out concurrently.
485485
chunks: list[list[tuple[Annotation, str]]] = [
486-
items[i : i + api_batch_size]
487-
for i in range(0, len(items), api_batch_size)
486+
items[i : i + api_batch_size] for i in range(0, len(items), api_batch_size)
488487
]
489488

490489
# ------------------------------------------------------------------ #
@@ -523,8 +522,7 @@ def _embed_one(chunk):
523522
# Map future -> chunk index for logging/sub-batch numbering.
524523
with ThreadPoolExecutor(max_workers=max_workers) as executor:
525524
future_to_idx = {
526-
executor.submit(_embed_one, chunk): idx
527-
for idx, chunk in enumerate(chunks)
525+
executor.submit(_embed_one, chunk): idx for idx, chunk in enumerate(chunks)
528526
}
529527
for future in as_completed(future_to_idx):
530528
idx = future_to_idx[future]
@@ -616,9 +614,7 @@ def _embed_one(chunk):
616614
f"Failed to store embedding for annotation {annot.id}: {e}"
617615
)
618616
result["failed"] += 1
619-
result["errors"].append(
620-
f"Annotation {annot.id}: store failed: {e}"
621-
)
617+
result["errors"].append(f"Annotation {annot.id}: store failed: {e}")
622618

623619

624620
@shared_task(
@@ -744,9 +740,7 @@ def calculate_embeddings_for_annotation_batch(
744740
# Per-embedder ``api_batch_size`` falls back to the global default
745741
# for embedders that haven't overridden it (and for legacy paths
746742
# that pass an embedder instance without the attribute).
747-
api_batch_size = getattr(
748-
embedder, "api_batch_size", EMBEDDING_API_BATCH_SIZE
749-
)
743+
api_batch_size = getattr(embedder, "api_batch_size", EMBEDDING_API_BATCH_SIZE)
750744
if text_only_annots:
751745
logger.info(
752746
f"Batch-embedding {len(text_only_annots)} text-only annotations "

opencontractserver/tests/permissioning/test_permissioning.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -993,9 +993,7 @@ class SetPermissionsIsNewTests(TestCase):
993993

994994
@classmethod
995995
def setUpTestData(cls):
996-
cls.user = User.objects.create_user(
997-
username="is_new_test_user", password="pw"
998-
)
996+
cls.user = User.objects.create_user(username="is_new_test_user", password="pw")
999997

1000998
def test_is_new_skips_remove_perm_calls(self):
1001999
"""When called with ``is_new=True`` on a fresh object, the function
@@ -1005,9 +1003,7 @@ def test_is_new_skips_remove_perm_calls(self):
10051003

10061004
new_corpus = Corpus.objects.create(title="brand new", creator=self.user)
10071005

1008-
with patch(
1009-
"opencontractserver.utils.permissioning.remove_perm"
1010-
) as mock_remove:
1006+
with patch("opencontractserver.utils.permissioning.remove_perm") as mock_remove:
10111007
set_permissions_for_obj_to_user(
10121008
self.user, new_corpus, [PermissionTypes.ALL], is_new=True
10131009
)
@@ -1090,7 +1086,5 @@ def test_is_new_resolves_int_user_id(self):
10901086
self.user.id, new_corpus, [PermissionTypes.ALL], is_new=True
10911087
)
10921088
self.assertTrue(
1093-
user_has_permission_for_obj(
1094-
self.user, new_corpus, PermissionTypes.READ
1095-
)
1089+
user_has_permission_for_obj(self.user, new_corpus, PermissionTypes.READ)
10961090
)

opencontractserver/tests/test_batch_embedding.py

Lines changed: 17 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,7 @@ def _mock_response(self, status_code, embeddings=None):
464464
resp.json.return_value = {"embeddings": embeddings}
465465
return resp
466466

467-
@patch(
468-
"requests.Session.post"
469-
)
467+
@patch("requests.Session.post")
470468
def test_successful_batch(self, mock_post):
471469
"""Successful batch returns list of vectors."""
472470
embedder = self._make_embedder()
@@ -481,9 +479,7 @@ def test_successful_batch(self, mock_post):
481479
call_kwargs = mock_post.call_args
482480
self.assertIn("/embeddings/batch", call_kwargs[0][0])
483481

484-
@patch(
485-
"requests.Session.post"
486-
)
482+
@patch("requests.Session.post")
487483
def test_client_error_raises(self, mock_post):
488484
"""4xx response raises EmbeddingClientError.
489485
@@ -499,9 +495,7 @@ def test_client_error_raises(self, mock_post):
499495
with self.assertRaises(EmbeddingClientError):
500496
embedder.embed_texts_batch(["hello"])
501497

502-
@patch(
503-
"requests.Session.post"
504-
)
498+
@patch("requests.Session.post")
505499
def test_server_error_raises(self, mock_post):
506500
"""5xx response raises EmbeddingServerError for Celery retry."""
507501
embedder = self._make_embedder()
@@ -510,9 +504,7 @@ def test_server_error_raises(self, mock_post):
510504
with self.assertRaises(EmbeddingServerError):
511505
embedder.embed_texts_batch(["hello"])
512506

513-
@patch(
514-
"requests.Session.post"
515-
)
507+
@patch("requests.Session.post")
516508
def test_non_retriable_exception_returns_none(self, mock_post):
517509
"""Non-retriable exception (e.g., builtin ConnectionError) returns None."""
518510
embedder = self._make_embedder()
@@ -522,9 +514,7 @@ def test_non_retriable_exception_returns_none(self, mock_post):
522514

523515
self.assertIsNone(result)
524516

525-
@patch(
526-
"requests.Session.post"
527-
)
517+
@patch("requests.Session.post")
528518
def test_timeout_raises_for_retry(self, mock_post):
529519
"""requests.Timeout re-raises for Celery retry."""
530520
embedder = self._make_embedder()
@@ -533,9 +523,7 @@ def test_timeout_raises_for_retry(self, mock_post):
533523
with self.assertRaises(requests.exceptions.Timeout):
534524
embedder.embed_texts_batch(["hello"])
535525

536-
@patch(
537-
"requests.Session.post"
538-
)
526+
@patch("requests.Session.post")
539527
def test_connection_error_raises_for_retry(self, mock_post):
540528
"""requests.ConnectionError re-raises for Celery retry."""
541529
embedder = self._make_embedder()
@@ -568,9 +556,7 @@ def test_no_service_url_returns_none(self):
568556
result = embedder.embed_texts_batch(["hello"])
569557
self.assertIsNone(result)
570558

571-
@patch(
572-
"requests.Session.post"
573-
)
559+
@patch("requests.Session.post")
574560
def test_3d_response_squeezed(self, mock_post):
575561
"""3D response array is squeezed to 2D."""
576562
embedder = self._make_embedder()
@@ -584,9 +570,7 @@ def test_3d_response_squeezed(self, mock_post):
584570
self.assertEqual(len(result), 2)
585571
self.assertEqual(len(result[0]), 384)
586572

587-
@patch(
588-
"requests.Session.post"
589-
)
573+
@patch("requests.Session.post")
590574
def test_vector_count_mismatch_returns_none(self, mock_post):
591575
"""Mismatched vector count returns None."""
592576
embedder = self._make_embedder()
@@ -598,9 +582,7 @@ def test_vector_count_mismatch_returns_none(self, mock_post):
598582

599583
self.assertIsNone(result)
600584

601-
@patch(
602-
"requests.Session.post"
603-
)
585+
@patch("requests.Session.post")
604586
def test_malformed_200_missing_embeddings_key(self, mock_post):
605587
"""200 response missing 'embeddings' key returns None."""
606588
embedder = self._make_embedder()
@@ -613,9 +595,7 @@ def test_malformed_200_missing_embeddings_key(self, mock_post):
613595

614596
self.assertIsNone(result)
615597

616-
@patch(
617-
"requests.Session.post"
618-
)
598+
@patch("requests.Session.post")
619599
def test_nan_values_handled_per_item(self, mock_post):
620600
"""NaN values in individual embeddings return None for those items only."""
621601
embedder = self._make_embedder()
@@ -654,9 +634,7 @@ def _mock_response(self, status_code, embeddings=None, body=None):
654634
resp.json.return_value = {"embeddings": embeddings}
655635
return resp
656636

657-
@patch(
658-
"requests.Session.post"
659-
)
637+
@patch("requests.Session.post")
660638
def test_embed_text_success_1d(self, mock_post):
661639
"""Successful single-text embedding with 1D response."""
662640
embedder = self._make_embedder()
@@ -669,9 +647,7 @@ def test_embed_text_success_1d(self, mock_post):
669647
self.assertEqual(len(result), 384)
670648
mock_post.assert_called_once()
671649

672-
@patch(
673-
"requests.Session.post"
674-
)
650+
@patch("requests.Session.post")
675651
def test_embed_text_success_2d(self, mock_post):
676652
"""Successful single-text embedding with 2D response."""
677653
embedder = self._make_embedder()
@@ -683,9 +659,7 @@ def test_embed_text_success_2d(self, mock_post):
683659
self.assertIsNotNone(result)
684660
self.assertEqual(len(result), 384)
685661

686-
@patch(
687-
"requests.Session.post"
688-
)
662+
@patch("requests.Session.post")
689663
def test_embed_text_malformed_200(self, mock_post):
690664
"""200 response missing 'embeddings' key returns None."""
691665
embedder = self._make_embedder()
@@ -695,9 +669,7 @@ def test_embed_text_malformed_200(self, mock_post):
695669

696670
self.assertIsNone(result)
697671

698-
@patch(
699-
"requests.Session.post"
700-
)
672+
@patch("requests.Session.post")
701673
def test_embed_text_nan_returns_none(self, mock_post):
702674
"""NaN in single-text response returns None."""
703675
embedder = self._make_embedder()
@@ -707,9 +679,7 @@ def test_embed_text_nan_returns_none(self, mock_post):
707679

708680
self.assertIsNone(result)
709681

710-
@patch(
711-
"requests.Session.post"
712-
)
682+
@patch("requests.Session.post")
713683
def test_embed_text_client_error(self, mock_post):
714684
"""4xx response returns None."""
715685
embedder = self._make_embedder()
@@ -719,9 +689,7 @@ def test_embed_text_client_error(self, mock_post):
719689

720690
self.assertIsNone(result)
721691

722-
@patch(
723-
"requests.Session.post"
724-
)
692+
@patch("requests.Session.post")
725693
def test_embed_text_server_error(self, mock_post):
726694
"""5xx response returns None."""
727695
embedder = self._make_embedder()
@@ -731,9 +699,7 @@ def test_embed_text_server_error(self, mock_post):
731699

732700
self.assertIsNone(result)
733701

734-
@patch(
735-
"requests.Session.post"
736-
)
702+
@patch("requests.Session.post")
737703
def test_embed_text_exception(self, mock_post):
738704
"""Network exception returns None."""
739705
embedder = self._make_embedder()

opencontractserver/tests/test_benchmarks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,10 +699,10 @@ def _random_spans(
699699
return out
700700

701701
def test_recall_matches_upstream_on_random_inputs(self):
702-
from opencontractserver.benchmarks.metrics import char_recall_paper
703-
704702
import random as rng_module
705703

704+
from opencontractserver.benchmarks.metrics import char_recall_paper
705+
706706
rng = rng_module.Random(42)
707707
for trial in range(200):
708708
n_pred = rng.randint(0, 30)
@@ -728,10 +728,10 @@ def test_recall_matches_upstream_on_random_inputs(self):
728728
)
729729

730730
def test_precision_matches_upstream_on_random_inputs(self):
731-
from opencontractserver.benchmarks.metrics import char_precision_paper
732-
733731
import random as rng_module
734732

733+
from opencontractserver.benchmarks.metrics import char_precision_paper
734+
735735
rng = rng_module.Random(43)
736736
for trial in range(200):
737737
n_pred = rng.randint(0, 30)

0 commit comments

Comments
 (0)