From 2328a54ab3c89cc9e28ed2e219fbc10fc54a40fb Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Sat, 13 Jun 2026 00:53:34 +0800 Subject: [PATCH 1/2] [fix] Keep numpy off the async pool for cross-language RAG queries A Python vector store's query path runs numpy (e.g. chroma's embedding normalization). numpy releases and re-acquires the GIL during the conversion, which deadlocks on an async pemja worker thread since pemja keeps a single PyThreadState. Split the async RAG query so only the numpy normalization runs on the operator thread: embed and query stay async, normalize is sync. - PythonVectorStore: resolve the embedding model in open(); add embedQuery / normalizeEmbedding / queryNormalized hooks; forward pre-computed vectors. - BaseVectorStore: expose getEmbeddingModel to subclasses. - ContextRetrievalAction: async embed -> sync normalize -> async query. - ChromaVectorStore: add _normalize_embeddings; query accepts pre-normalized. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../api/vectorstores/BaseVectorStore.java | 2 +- .../python/PythonVectorStore.java | 139 +++++++++++------- ...onCollectionManageableVectorStoreTest.java | 26 ++-- .../plan/actions/ContextRetrievalAction.java | 106 ++++++++++--- .../api/vector_stores/vector_store.py | 10 ++ .../chroma/chroma_vector_store.py | 18 +++ 6 files changed, 216 insertions(+), 85 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java index 442e705ed..d0fc4ce72 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java @@ -339,7 +339,7 @@ protected void ensureEmbeddings(List documents) { } } - private BaseEmbeddingModelSetup getEmbeddingModel() { + protected BaseEmbeddingModelSetup getEmbeddingModel() { if (embeddingModel == null) { throw new IllegalStateException( 
"No embedding model configured on this vector store. " diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java index 6b2d5d71a..69025cc11 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java @@ -24,13 +24,12 @@ import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.Document; -import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; -import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; import pemja.core.object.PyObject; import javax.annotation.Nullable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -45,10 +44,11 @@ *

<p>This class serves as a connection layer between Java and Python vector store environments, * enabling seamless integration of Python-based vector stores within Java applications. * - *

<p>The {@code *Embedding} hooks ({@link #queryEmbedding}, {@link #addEmbedding}, {@link - * #updateEmbedding}) are no-ops here: this bridge forwards each public method directly to its - * Python counterpart, which already handles auto-embedding internally, so the Java auto-embed path - * in {@link BaseVectorStore} is not used. + *

Embedding is generated on the Java side via {@link BaseVectorStore}'s public add/update/query; + * the {@code *Embedding} hooks then forward the pre-computed vectors to the Python {@code + * _add_embedding}/{@code _update_embedding}/{@code _query_embedding}. This keeps each store + * operation a single Java->Python crossing, avoiding a Python->Java re-entry that deadlocks when + * run on the async pool thread. */ public class PythonVectorStore extends BaseVectorStore implements PythonResourceWrapper { protected final PyObject vectorStore; @@ -74,52 +74,14 @@ public PythonVectorStore( } @Override - public void open() { + public void open() throws Exception { + // Resolve the Java-side embedding model so embeddings are generated on the mailbox thread + // (single Java->Python crossing per op). Without this, add/query would re-embed inside + // Python and re-enter Java, which deadlocks when run on the async pool thread. + super.open(); adapter.callMethod(vectorStore, "open", Collections.emptyMap()); } - @Override - @SuppressWarnings("unchecked") - public List add( - List documents, @Nullable String collection, Map extraArgs) - throws IOException { - Object pythonDocuments = adapter.toPythonDocuments(documents); - - Map kwargs = new HashMap<>(extraArgs); - kwargs.put("documents", pythonDocuments); - - if (collection != null) { - kwargs.put("collection_name", collection); - } - - return (List) adapter.callMethod(vectorStore, "add", kwargs); - } - - @Override - public void update( - List documents, @Nullable String collection, Map extraArgs) - throws IOException { - Object pythonDocuments = adapter.toPythonDocuments(documents); - - Map kwargs = new HashMap<>(extraArgs); - kwargs.put("documents", pythonDocuments); - - if (collection != null) { - kwargs.put("collection_name", collection); - } - - adapter.callMethod(vectorStore, "update", kwargs); - } - - @Override - public VectorStoreQueryResult query(VectorStoreQuery query) { - Object pythonQuery = 
adapter.toPythonVectorStoreQuery(query); - - PyObject pythonResult = (PyObject) vectorStore.invokeMethod("query", pythonQuery); - - return adapter.fromPythonVectorStoreQueryResult(pythonResult); - } - @Override @SuppressWarnings("unchecked") public List get( @@ -170,30 +132,101 @@ public void delete( @Override public Map getStoreKwargs() { - return Map.of(); + return new HashMap<>(); } @Override + @SuppressWarnings("unchecked") public List queryEmbedding( float[] embedding, int limit, @Nullable String collection, @Nullable Map filters, Map args) { - return List.of(); + Map kwargs = new HashMap<>(args); + // pemja maps float[] to a Python tuple, which Chroma rejects; pass a list instead. + List embeddingList = new ArrayList<>(embedding.length); + for (float v : embedding) { + embeddingList.add(v); + } + kwargs.put("embedding", embeddingList); + kwargs.put("limit", limit); + if (collection != null) { + kwargs.put("collection_name", collection); + } + if (filters != null) { + kwargs.put("filters", filters); + } + Object pythonDocuments = adapter.callMethod(vectorStore, "_query_embedding", kwargs); + return adapter.fromPythonDocuments((List) pythonDocuments); + } + + /** Embed query text via the configured model (no numpy, so it stays on the async pool). */ + public float[] embedQuery(String text) { + return getEmbeddingModel().embed(text); + } + + /** + * Convert the raw embedding to the Python store's native vector form. This runs the numpy + * conversion on the mailbox thread: numpy releases/re-acquires the GIL during the copy, and + * pemja keeps a single PyThreadState, so doing it on an async worker thread can stall the + * interpreter (observed as a hang in CI; benign with spare cores locally). The returned Python + * object is forwarded back into {@link #queryNormalized}. See + * https://github.com/apache/flink-agents/issues/844. 
+ */ + public Object normalizeEmbedding(float[] embedding) { + List embeddingList = new ArrayList<>(embedding.length); + for (float v : embedding) { + embeddingList.add(v); + } + Map kwargs = new HashMap<>(); + kwargs.put("embeddings", embeddingList); + return adapter.callMethod(vectorStore, "_normalize_embeddings", kwargs); + } + + /** Query with a pre-normalized embedding; numpy-free, so it stays on the async pool. */ + @SuppressWarnings("unchecked") + public List queryNormalized( + Object normalizedEmbedding, + int limit, + @Nullable String collection, + @Nullable Map filters, + Map args) { + Map kwargs = new HashMap<>(args); + kwargs.put("embedding", normalizedEmbedding); + kwargs.put("limit", limit); + if (collection != null) { + kwargs.put("collection_name", collection); + } + if (filters != null) { + kwargs.put("filters", filters); + } + Object pythonDocuments = adapter.callMethod(vectorStore, "_query_embedding", kwargs); + return adapter.fromPythonDocuments((List) pythonDocuments); } @Override + @SuppressWarnings("unchecked") public List addEmbedding( List documents, @Nullable String collection, Map extraArgs) throws IOException { - return List.of(); + Map kwargs = new HashMap<>(extraArgs); + kwargs.put("documents", adapter.toPythonDocuments(documents)); + if (collection != null) { + kwargs.put("collection_name", collection); + } + return (List) adapter.callMethod(vectorStore, "_add_embedding", kwargs); } @Override public void updateEmbedding( List documents, @Nullable String collection, Map extraArgs) { - // no-op; Python forwards public update() directly + Map kwargs = new HashMap<>(extraArgs); + kwargs.put("documents", adapter.toPythonDocuments(documents)); + if (collection != null) { + kwargs.put("collection_name", collection); + } + adapter.callMethod(vectorStore, "_update_embedding", kwargs); } @Override diff --git a/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java 
b/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java index 0a5e04cb1..37b27f856 100644 --- a/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java @@ -161,17 +161,19 @@ void testDeleteCollection() throws Exception { @Test void testAddDocuments() throws Exception { - List documents = - Arrays.asList( - new Document("content1", Map.of("key", "value1"), "doc1"), - new Document("content2", Map.of("key", "value2"), "doc2")); + Document d1 = new Document("content1", Map.of("key", "value1"), "doc1"); + Document d2 = new Document("content2", Map.of("key", "value2"), "doc2"); + // Pre-computed embeddings so the Java auto-embed path is skipped (no model configured). + d1.setEmbedding(new float[] {0.1f, 0.2f}); + d2.setEmbedding(new float[] {0.3f, 0.4f}); + List documents = Arrays.asList(d1, d2); String collection = "test_collection"; Map extraArgs = Map.of("batch_size", 10); List expectedIds = Arrays.asList("doc1", "doc2"); when(mockAdapter.toPythonDocuments(documents)).thenReturn(new Object()); - when(mockAdapter.callMethod(eq(mockVectorStore), eq("add"), any(Map.class))) + when(mockAdapter.callMethod(eq(mockVectorStore), eq("_add_embedding"), any(Map.class))) .thenReturn(expectedIds); List result = vectorStore.add(documents, collection, extraArgs); @@ -184,7 +186,7 @@ void testAddDocuments() throws Exception { verify(mockAdapter) .callMethod( eq(mockVectorStore), - eq("add"), + eq("_add_embedding"), argThat( kwargs -> { assertThat(kwargs).containsKey("documents"); @@ -196,10 +198,12 @@ void testAddDocuments() throws Exception { @Test void testUpdateDocuments() throws Exception { - List documents = - Arrays.asList( - new Document("c1", Map.of("k", "v1"), "doc1"), - new Document("c2", Map.of("k", "v2"), "doc2")); + Document d1 = new 
Document("c1", Map.of("k", "v1"), "doc1"); + Document d2 = new Document("c2", Map.of("k", "v2"), "doc2"); + // Pre-computed embeddings so the Java auto-embed path is skipped (no model configured). + d1.setEmbedding(new float[] {0.1f, 0.2f}); + d2.setEmbedding(new float[] {0.3f, 0.4f}); + List documents = Arrays.asList(d1, d2); String collection = "test_collection"; Map extraArgs = Map.of("batch_size", 5); @@ -211,7 +215,7 @@ void testUpdateDocuments() throws Exception { verify(mockAdapter) .callMethod( eq(mockVectorStore), - eq("update"), + eq("_update_embedding"), argThat( kwargs -> { assertThat(kwargs).containsKey("documents"); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java index 28a0853b6..6b91c05cd 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java @@ -26,6 +26,7 @@ import org.apache.flink.agents.api.event.ContextRetrievalResponseEvent; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; import org.apache.flink.agents.api.vectorstores.python.PythonVectorStore; @@ -71,26 +72,34 @@ public static void processContextRetrievalRequest(Event event, RunnerContext ctx contextRetrievalRequestEvent.getQuery(), contextRetrievalRequestEvent.getMaxResults()); - DurableCallable callable = - new DurableCallable() { - @Override - public String getId() { - return "rag-async"; - } - - @Override - public Class getResultClass() { - return VectorStoreQueryResult.class; - } - - @Override - public VectorStoreQueryResult call() throws Exception { - return 
vectorStore.query(vectorStoreQuery); - } - }; - - VectorStoreQueryResult result = - ragAsync ? ctx.durableExecuteAsync(callable) : ctx.durableExecute(callable); + final VectorStoreQueryResult result; + if (ragAsync && vectorStore instanceof PythonVectorStore) { + // A Python store's query path runs numpy, which can stall on the async pool + // (pemja keeps a single PyThreadState; numpy releasing/re-acquiring the GIL on a + // worker thread hangs intermittently — seen in CI, fine locally with spare cores). + // Keep only that numpy step on the mailbox thread; embed and query stay async. + result = queryPythonAsync((PythonVectorStore) vectorStore, vectorStoreQuery, ctx); + } else { + DurableCallable callable = + new DurableCallable() { + @Override + public String getId() { + return "rag-async"; + } + + @Override + public Class getResultClass() { + return VectorStoreQueryResult.class; + } + + @Override + public VectorStoreQueryResult call() throws Exception { + return vectorStore.query(vectorStoreQuery); + } + }; + result = + ragAsync ? ctx.durableExecuteAsync(callable) : ctx.durableExecute(callable); + } ctx.sendEvent( new ContextRetrievalResponseEvent( @@ -99,4 +108,61 @@ public VectorStoreQueryResult call() throws Exception { result.getDocuments())); } } + + /** + * Run a Python vector-store RAG query while keeping numpy off the async pool: embed async, + * normalize the embedding synchronously on the mailbox thread (numpy on a worker thread can + * stall under pemja's single PyThreadState), then query async with the pre-normalized vector. + * See https://github.com/apache/flink-agents/issues/844. 
+ */ + private static VectorStoreQueryResult queryPythonAsync( + PythonVectorStore store, VectorStoreQuery query, RunnerContext ctx) throws Exception { + final float[] embedding = + ctx.durableExecuteAsync( + new DurableCallable() { + @Override + public String getId() { + return "rag-embed"; + } + + @Override + public Class getResultClass() { + return float[].class; + } + + @Override + public float[] call() { + return store.embedQuery(query.getQueryText()); + } + }); + + final Object normalized = store.normalizeEmbedding(embedding); + + final List documents = + ctx.durableExecuteAsync( + new DurableCallable>() { + @Override + public String getId() { + return "rag-query"; + } + + @SuppressWarnings("unchecked") + @Override + public Class> getResultClass() { + return (Class>) (Class) List.class; + } + + @Override + public List call() { + return store.queryNormalized( + normalized, + query.getLimit(), + query.getCollection(), + query.getFilters(), + store.getStoreKwargs()); + } + }); + + return new VectorStoreQueryResult(documents); + } } diff --git a/python/flink_agents/api/vector_stores/vector_store.py b/python/flink_agents/api/vector_stores/vector_store.py index d16d00c76..2cf04c49a 100644 --- a/python/flink_agents/api/vector_stores/vector_store.py +++ b/python/flink_agents/api/vector_stores/vector_store.py @@ -397,6 +397,16 @@ def delete( **kwargs: Vector store specific parameters. """ + @staticmethod + def _normalize_embeddings(embeddings: list[float]) -> Any: + """Pre-process a query embedding before the search call. + + Hook for backends whose query path performs CPU/numpy work that must run + on the mailbox thread rather than an async cross-language worker (see the + ChromaDB override). Default is identity. 
+ """ + return embeddings + @abstractmethod def _query_embedding( self, diff --git a/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py b/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py index 7028ccd0e..29dacf901 100644 --- a/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py +++ b/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py @@ -21,6 +21,7 @@ import chromadb from chromadb import ClientAPI as ChromaClient from chromadb import CloudClient +from chromadb.api.types import normalize_embeddings from chromadb.config import Settings from pydantic import Field from typing_extensions import override @@ -306,6 +307,20 @@ def _update_embedding( metadatas=[doc.metadata for doc in documents], ) + @staticmethod + @override + def _normalize_embeddings(embeddings: List[float]) -> Any: + """Convert the raw embedding to chroma's numpy form on the caller's thread. + + chroma's query normally runs this np.array conversion internally. numpy + releases and re-acquires the GIL during the copy, which can stall on an + async pemja worker thread (pemja keeps a single PyThreadState) — seen as a + hang in CI, benign locally with spare cores. Running it here lets the + mailbox thread do the numpy step, so the async query sees a ready ndarray. + See https://github.com/apache/flink-agents/issues/844. + """ + return normalize_embeddings([embeddings])[0] + @override def _query_embedding( self, @@ -315,6 +330,9 @@ def _query_embedding( filters: Dict[str, Any] | None = None, **kwargs: Any, ) -> List[Document]: + # ``embedding`` may be a pre-normalized ndarray (async path) or a raw list + # (sync path); chroma takes the ndarray branch for the former, avoiding any + # numpy work on this thread. 
collection = self._resolve_collection(collection_name, kwargs) results = collection.query( query_embeddings=[embedding], From 5dc4fdedfe88ea1cfa70d069915521c4b380dac4 Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Fri, 12 Jun 2026 14:13:00 +0800 Subject: [PATCH 2/2] [dependency] Bump flink to 1.20.5 and 2.1.3 --- dist/pom.xml | 4 ++-- .../flink-agents-end-to-end-tests-integration/pom.xml | 8 ++++---- pom.xml | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dist/pom.xml b/dist/pom.xml index 9e0c0c6c5..0a72542a4 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -30,9 +30,9 @@ under the License. pom - 1.20.4 + 1.20.5 2.0.2 - 2.1.2 + 2.1.3 2.2.1 diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml index 119edaf84..bfbdf1e6a 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml @@ -29,10 +29,10 @@ under the License. Flink Agents : E2E Tests: Integration - 1.20.3 - 2.0.1 - 2.1.1 - 2.2.0 + 1.20.5 + 2.0.2 + 2.1.3 + 2.2.1 ${flink.2.2.version} flink-agents-dist-flink-2.2 diff --git a/pom.xml b/pom.xml index e77b0a51c..4728fa589 100644 --- a/pom.xml +++ b/pom.xml @@ -41,12 +41,12 @@ under the License. ${target.java.version} 2.27.1 false - 2.2.0 + 2.2.1 4.0.0 0.9.0-incubating 5.10.1 2.18.2 - 0.5.5 + 0.5.7 2.23.1 1.7.36 3.27.7