|
105 | 105 | T = TypeVar("T") |
106 | 106 |
|
107 | 107 |
|
| 108 | +def _make_similarity_search_tool(vector_store: Any) -> Callable: |
| 109 | + """Build the citation-capturing similarity_search tool for a vector store. |
| 110 | +
|
| 111 | + Both the document and corpus agent factories used to define this closure |
| 112 | + inline; the only thing that differed was which vector store was bound. |
| 113 | + Centralising it here keeps the citation-accumulation contract — push |
| 114 | + every real annotation PK into ``ctx.deps.retrieved_annotation_ids`` — |
| 115 | + in a single place. The tool name remains ``similarity_search`` so |
| 116 | + downstream event handlers and source-linking logic are unaffected. |
| 117 | + """ |
| 118 | + |
| 119 | + async def similarity_search( |
| 120 | + ctx: RunContext[PydanticAIDependencies], |
| 121 | + query: str, |
| 122 | + k: int = 8, |
| 123 | + modalities: Optional[list[str]] = None, |
| 124 | + ) -> list[dict[str, Any]]: |
| 125 | + """Semantic vector search over the corpus annotations. |
| 126 | +
|
| 127 | + Returns the top-k nearest annotations for ``query`` as a list of |
| 128 | + dicts with keys ``annotation_id``, ``content``, ``document_id``, |
| 129 | + ``corpus_id``, ``page``, ``similarity_score``, ``label``, and |
| 130 | + ``json``. Each real annotation's ID is captured into |
| 131 | + ``ctx.deps.retrieved_annotation_ids`` so the caller can later link |
| 132 | + citations to the owning object (e.g. ``Datacell.sources``). |
| 133 | + """ |
| 134 | + results = await vector_store.similarity_search( |
| 135 | + query, k=k, modalities=modalities |
| 136 | + ) |
| 137 | + for r in results: |
| 138 | + if not isinstance(r, dict): |
| 139 | + continue |
| 140 | + aid = r.get("annotation_id") |
| 141 | + # Real annotation PKs are positive ints; synthetic / ad-hoc |
| 142 | + # match IDs are negative and must not be persisted. |
| 143 | + if isinstance(aid, int) and aid > 0: |
| 144 | + ctx.deps.retrieved_annotation_ids.append(aid) |
| 145 | + return results |
| 146 | + |
| 147 | + return similarity_search |
| 148 | + |
| 149 | + |
108 | 150 | def _get_function_tools(agent: PydanticAIAgent) -> dict: |
109 | 151 | """Return the function-tools dict from a pydantic-ai Agent. |
110 | 152 |
|
@@ -2059,42 +2101,10 @@ async def create( |
2059 | 2101 | **_vs_kwargs |
2060 | 2102 | ) |
2061 | 2103 |
|
2062 | | - # Default vector search tool: wraps the store's bound method so we can |
2063 | | - # append real annotation IDs returned by the retrieval to the per-run |
2064 | | - # citation accumulator on ``ctx.deps``. Pydantic-AI inspects the |
2065 | | - # signature and injects ``ctx`` because its first parameter is typed |
2066 | | - # as ``RunContext[PydanticAIDependencies]``. The tool name is |
2067 | | - # preserved as ``similarity_search`` so existing event handlers that |
2068 | | - # match on the tool name continue to work. |
2069 | | - async def similarity_search( |
2070 | | - ctx: RunContext[PydanticAIDependencies], |
2071 | | - query: str, |
2072 | | - k: int = 8, |
2073 | | - modalities: Optional[list[str]] = None, |
2074 | | - ) -> list[dict[str, Any]]: |
2075 | | - """Semantic vector search over the corpus annotations. |
2076 | | -
|
2077 | | - Returns the top-k nearest annotations for ``query`` as a list of |
2078 | | - dicts with keys ``annotation_id``, ``content``, ``document_id``, |
2079 | | - ``corpus_id``, ``page``, ``similarity_score``, ``label``, and |
2080 | | - ``json``. Each real annotation's ID is captured into |
2081 | | - ``ctx.deps.retrieved_annotation_ids`` so the caller can later link |
2082 | | - citations to the owning object (e.g. ``Datacell.sources``). |
2083 | | - """ |
2084 | | - results = await vector_store.similarity_search( |
2085 | | - query, k=k, modalities=modalities |
2086 | | - ) |
2087 | | - for r in results: |
2088 | | - if not isinstance(r, dict): |
2089 | | - continue |
2090 | | - aid = r.get("annotation_id") |
2091 | | - # Real annotation PKs are positive ints; synthetic / ad-hoc |
2092 | | - # match IDs are negative and must not be persisted. |
2093 | | - if isinstance(aid, int) and aid > 0: |
2094 | | - ctx.deps.retrieved_annotation_ids.append(aid) |
2095 | | - return results |
2096 | | - |
2097 | | - default_vs_tool: Callable = similarity_search |
| 2104 | + # See ``_make_similarity_search_tool`` for the citation-accumulation |
| 2105 | + # contract; the tool name remains ``similarity_search`` so existing |
| 2106 | + # event handlers that match on the tool name continue to work. |
| 2107 | + default_vs_tool: Callable = _make_similarity_search_tool(vector_store) |
2098 | 2108 |
|
2099 | 2109 | # ----------------------------- |
2100 | 2110 | # Auto-build pure passthrough tools from registry |
@@ -2598,36 +2608,9 @@ async def create( |
2598 | 2608 | **_vs_kwargs |
2599 | 2609 | ) |
2600 | 2610 |
|
2601 | | - # Default vector search tool: wraps the store's bound method to |
2602 | | - # capture real annotation IDs returned during retrieval. See the |
2603 | | - # equivalent wrapper in ``PydanticAIDocumentAgent.create`` for the |
2604 | | - # rationale — we preserve the tool name ``similarity_search`` so |
2605 | | - # downstream event / source handling is unaffected. |
2606 | | - async def similarity_search( |
2607 | | - ctx: RunContext[PydanticAIDependencies], |
2608 | | - query: str, |
2609 | | - k: int = 8, |
2610 | | - modalities: Optional[list[str]] = None, |
2611 | | - ) -> list[dict[str, Any]]: |
2612 | | - """Semantic vector search over the corpus annotations. |
2613 | | -
|
2614 | | - Returns the top-k nearest annotations for ``query`` as dicts. |
2615 | | - Appends every real annotation PK returned to |
2616 | | - ``ctx.deps.retrieved_annotation_ids`` so the caller can link |
2617 | | - citations to the owning object after the run completes. |
2618 | | - """ |
2619 | | - results = await vector_store.similarity_search( |
2620 | | - query, k=k, modalities=modalities |
2621 | | - ) |
2622 | | - for r in results: |
2623 | | - if not isinstance(r, dict): |
2624 | | - continue |
2625 | | - aid = r.get("annotation_id") |
2626 | | - if isinstance(aid, int) and aid > 0: |
2627 | | - ctx.deps.retrieved_annotation_ids.append(aid) |
2628 | | - return results |
2629 | | - |
2630 | | - default_vs_tool: Callable = similarity_search |
| 2611 | + # See ``_make_similarity_search_tool`` for the shared citation-capturing |
| 2612 | + # closure used by both the document and corpus agent factories. |
| 2613 | + default_vs_tool: Callable = _make_similarity_search_tool(vector_store) |
2631 | 2614 |
|
2632 | 2615 | # ----------------------------- |
2633 | 2616 | # Auto-build passthrough tools from registry |
|
0 commit comments