Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ protected void ensureEmbeddings(List<Document> documents) {
}
}

private BaseEmbeddingModelSetup getEmbeddingModel() {
protected BaseEmbeddingModelSetup getEmbeddingModel() {
if (embeddingModel == null) {
throw new IllegalStateException(
"No embedding model configured on this vector store. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
import org.apache.flink.agents.api.resource.python.PythonResourceWrapper;
import org.apache.flink.agents.api.vectorstores.BaseVectorStore;
import org.apache.flink.agents.api.vectorstores.Document;
import org.apache.flink.agents.api.vectorstores.VectorStoreQuery;
import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult;
import pemja.core.object.PyObject;

import javax.annotation.Nullable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -45,10 +44,11 @@
* <p>This class serves as a connection layer between Java and Python vector store environments,
* enabling seamless integration of Python-based vector stores within Java applications.
*
* <p>The {@code *Embedding} hooks ({@link #queryEmbedding}, {@link #addEmbedding}, {@link
* #updateEmbedding}) are no-ops here: this bridge forwards each public method directly to its
* Python counterpart, which already handles auto-embedding internally, so the Java auto-embed path
* in {@link BaseVectorStore} is not used.
* <p>Embedding is generated on the Java side via {@link BaseVectorStore}'s public add/update/query;
* the {@code *Embedding} hooks then forward the pre-computed vectors to the Python {@code
* _add_embedding}/{@code _update_embedding}/{@code _query_embedding}. This keeps each store
* operation a single Java->Python crossing, avoiding a Python->Java re-entry that deadlocks when
* run on the async pool thread.
*/
public class PythonVectorStore extends BaseVectorStore implements PythonResourceWrapper {
protected final PyObject vectorStore;
Expand All @@ -74,52 +74,14 @@ public PythonVectorStore(
}

@Override
public void open() {
public void open() throws Exception {
// Resolve the Java-side embedding model so embeddings are generated on the mailbox thread
// (single Java->Python crossing per op). Without this, add/query would re-embed inside
// Python and re-enter Java, which deadlocks when run on the async pool thread.
super.open();
adapter.callMethod(vectorStore, "open", Collections.emptyMap());
}

@Override
@SuppressWarnings("unchecked")
public List<String> add(
List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs)
throws IOException {
Object pythonDocuments = adapter.toPythonDocuments(documents);

Map<String, Object> kwargs = new HashMap<>(extraArgs);
kwargs.put("documents", pythonDocuments);

if (collection != null) {
kwargs.put("collection_name", collection);
}

return (List<String>) adapter.callMethod(vectorStore, "add", kwargs);
}

@Override
public void update(
List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs)
throws IOException {
Object pythonDocuments = adapter.toPythonDocuments(documents);

Map<String, Object> kwargs = new HashMap<>(extraArgs);
kwargs.put("documents", pythonDocuments);

if (collection != null) {
kwargs.put("collection_name", collection);
}

adapter.callMethod(vectorStore, "update", kwargs);
}

@Override
public VectorStoreQueryResult query(VectorStoreQuery query) {
Object pythonQuery = adapter.toPythonVectorStoreQuery(query);

PyObject pythonResult = (PyObject) vectorStore.invokeMethod("query", pythonQuery);

return adapter.fromPythonVectorStoreQueryResult(pythonResult);
}

@Override
@SuppressWarnings("unchecked")
public List<Document> get(
Expand Down Expand Up @@ -170,30 +132,101 @@ public void delete(

@Override
public Map<String, Object> getStoreKwargs() {
return Map.of();
return new HashMap<>();
}

@Override
@SuppressWarnings("unchecked")
public List<Document> queryEmbedding(
float[] embedding,
int limit,
@Nullable String collection,
@Nullable Map<String, Object> filters,
Map<String, Object> args) {
return List.of();
Map<String, Object> kwargs = new HashMap<>(args);
// pemja maps float[] to a Python tuple, which Chroma rejects; pass a list instead.
List<Float> embeddingList = new ArrayList<>(embedding.length);
for (float v : embedding) {
embeddingList.add(v);
}
kwargs.put("embedding", embeddingList);
kwargs.put("limit", limit);
if (collection != null) {
kwargs.put("collection_name", collection);
}
if (filters != null) {
kwargs.put("filters", filters);
}
Object pythonDocuments = adapter.callMethod(vectorStore, "_query_embedding", kwargs);
return adapter.fromPythonDocuments((List<PyObject>) pythonDocuments);
}

/** Embed query text via the configured model (no numpy, so it stays on the async pool). */
public float[] embedQuery(String text) {
return getEmbeddingModel().embed(text);
}

/**
* Convert the raw embedding to the Python store's native vector form. This runs the numpy
* conversion on the mailbox thread: numpy releases/re-acquires the GIL during the copy, and
* pemja keeps a single PyThreadState, so doing it on an async worker thread can stall the
* interpreter (observed as a hang in CI; benign with spare cores locally). The returned Python
* object is forwarded back into {@link #queryNormalized}. See
* https://github.com/apache/flink-agents/issues/844.
*/
public Object normalizeEmbedding(float[] embedding) {
List<Float> embeddingList = new ArrayList<>(embedding.length);
for (float v : embedding) {
embeddingList.add(v);
}
Map<String, Object> kwargs = new HashMap<>();
kwargs.put("embeddings", embeddingList);
return adapter.callMethod(vectorStore, "_normalize_embeddings", kwargs);
}

/** Query with a pre-normalized embedding; numpy-free, so it stays on the async pool. */
@SuppressWarnings("unchecked")
public List<Document> queryNormalized(
Object normalizedEmbedding,
int limit,
@Nullable String collection,
@Nullable Map<String, Object> filters,
Map<String, Object> args) {
Map<String, Object> kwargs = new HashMap<>(args);
kwargs.put("embedding", normalizedEmbedding);
kwargs.put("limit", limit);
if (collection != null) {
kwargs.put("collection_name", collection);
}
if (filters != null) {
kwargs.put("filters", filters);
}
Object pythonDocuments = adapter.callMethod(vectorStore, "_query_embedding", kwargs);
return adapter.fromPythonDocuments((List<PyObject>) pythonDocuments);
}

@Override
@SuppressWarnings("unchecked")
public List<String> addEmbedding(
List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs)
throws IOException {
return List.of();
Map<String, Object> kwargs = new HashMap<>(extraArgs);
kwargs.put("documents", adapter.toPythonDocuments(documents));
if (collection != null) {
kwargs.put("collection_name", collection);
}
return (List<String>) adapter.callMethod(vectorStore, "_add_embedding", kwargs);
}

@Override
public void updateEmbedding(
List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs) {
// no-op; Python forwards public update() directly
Map<String, Object> kwargs = new HashMap<>(extraArgs);
kwargs.put("documents", adapter.toPythonDocuments(documents));
if (collection != null) {
kwargs.put("collection_name", collection);
}
adapter.callMethod(vectorStore, "_update_embedding", kwargs);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,19 @@ void testDeleteCollection() throws Exception {

@Test
void testAddDocuments() throws Exception {
List<Document> documents =
Arrays.asList(
new Document("content1", Map.of("key", "value1"), "doc1"),
new Document("content2", Map.of("key", "value2"), "doc2"));
Document d1 = new Document("content1", Map.of("key", "value1"), "doc1");
Document d2 = new Document("content2", Map.of("key", "value2"), "doc2");
// Pre-computed embeddings so the Java auto-embed path is skipped (no model configured).
d1.setEmbedding(new float[] {0.1f, 0.2f});
d2.setEmbedding(new float[] {0.3f, 0.4f});
List<Document> documents = Arrays.asList(d1, d2);
String collection = "test_collection";
Map<String, Object> extraArgs = Map.of("batch_size", 10);

List<String> expectedIds = Arrays.asList("doc1", "doc2");

when(mockAdapter.toPythonDocuments(documents)).thenReturn(new Object());
when(mockAdapter.callMethod(eq(mockVectorStore), eq("add"), any(Map.class)))
when(mockAdapter.callMethod(eq(mockVectorStore), eq("_add_embedding"), any(Map.class)))
.thenReturn(expectedIds);

List<String> result = vectorStore.add(documents, collection, extraArgs);
Expand All @@ -184,7 +186,7 @@ void testAddDocuments() throws Exception {
verify(mockAdapter)
.callMethod(
eq(mockVectorStore),
eq("add"),
eq("_add_embedding"),
argThat(
kwargs -> {
assertThat(kwargs).containsKey("documents");
Expand All @@ -196,10 +198,12 @@ void testAddDocuments() throws Exception {

@Test
void testUpdateDocuments() throws Exception {
List<Document> documents =
Arrays.asList(
new Document("c1", Map.of("k", "v1"), "doc1"),
new Document("c2", Map.of("k", "v2"), "doc2"));
Document d1 = new Document("c1", Map.of("k", "v1"), "doc1");
Document d2 = new Document("c2", Map.of("k", "v2"), "doc2");
// Pre-computed embeddings so the Java auto-embed path is skipped (no model configured).
d1.setEmbedding(new float[] {0.1f, 0.2f});
d2.setEmbedding(new float[] {0.3f, 0.4f});
List<Document> documents = Arrays.asList(d1, d2);
String collection = "test_collection";
Map<String, Object> extraArgs = Map.of("batch_size", 5);

Expand All @@ -211,7 +215,7 @@ void testUpdateDocuments() throws Exception {
verify(mockAdapter)
.callMethod(
eq(mockVectorStore),
eq("update"),
eq("_update_embedding"),
argThat(
kwargs -> {
assertThat(kwargs).containsKey("documents");
Expand Down
4 changes: 2 additions & 2 deletions dist/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ under the License.
<packaging>pom</packaging>

<properties>
<flink.1.20.version>1.20.4</flink.1.20.version>
<flink.1.20.version>1.20.5</flink.1.20.version>
<flink.2.0.version>2.0.2</flink.2.0.version>
<flink.2.1.version>2.1.2</flink.2.1.version>
<flink.2.1.version>2.1.3</flink.2.1.version>
<flink.2.2.version>2.2.1</flink.2.2.version>
</properties>

Expand Down
8 changes: 4 additions & 4 deletions e2e-test/flink-agents-end-to-end-tests-integration/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ under the License.
<name>Flink Agents : E2E Tests: Integration</name>

<properties>
<flink.1.20.version>1.20.3</flink.1.20.version>
<flink.2.0.version>2.0.1</flink.2.0.version>
<flink.2.1.version>2.1.1</flink.2.1.version>
<flink.2.2.version>2.2.0</flink.2.2.version>
<flink.1.20.version>1.20.5</flink.1.20.version>
<flink.2.0.version>2.0.2</flink.2.0.version>
<flink.2.1.version>2.1.3</flink.2.1.version>
<flink.2.2.version>2.2.1</flink.2.2.version>

<flink.version>${flink.2.2.version}</flink.version>
<flink.agents.dist.artifactId>flink-agents-dist-flink-2.2</flink.agents.dist.artifactId>
Expand Down
Loading
Loading