Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ public Object getPythonResource() {
return chatModel;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}

@Override
public ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> modelParams) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ public Object getPythonResource() {
return chatModelSetup;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}

@Override
public Map<String, Object> getParameters() {
return Map.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ public Object getPythonResource() {
return embeddingModel;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}

@Override
public void close() throws Exception {
this.embeddingModel.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,9 @@ public Map<String, Object> getParameters() {
public Object getPythonResource() {
return embeddingModelSetup;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.agents.api.resource;

import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.python.PythonResourceWrapper;

/**
* Base interface for all kinds of resources, including chat models, tools, prompts and so on.
Expand Down Expand Up @@ -60,6 +61,9 @@ public ResourceContext getResourceContext() {
*/
public void setMetricGroup(FlinkAgentsMetricGroup metricGroup) {
this.metricGroup = metricGroup;
if (this instanceof PythonResourceWrapper) {
((PythonResourceWrapper) this).setPythonResourceMetricGroup(metricGroup);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The abstract base Resource (package ...api.resource) now imports and instanceof-checks PythonResourceWrapper from the more specialized ...api.resource.python sub-package. That inverts the usual dependency direction — every Java-only resource (chat models, tools, prompts, vector stores) now carries a compile-time edge to the Python-bridge interface in its base type, and a hypothetical third forwarding flavor would mean editing this base method again rather than overriding it.

One option that keeps the base oblivious to the bridge: have each Python* wrapper override setMetricGroup to call super.setMetricGroup(...) and then setPythonResourceMetricGroup(...), pushing the Python concern down into the wrappers where getPythonResourceAdapter() already lives. The PR already adds a new getPythonResourceAdapter() override to all eight wrappers, so a per-wrapper setMetricGroup override would touch the same set of files — just in the Python layer instead of the base class. Was the centralized instanceof chosen deliberately, or mainly to avoid touching the wrappers?

}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.agents.api.resource.python;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.agents.api.vectorstores.Document;
import org.apache.flink.agents.api.vectorstores.VectorStoreQuery;
Expand Down Expand Up @@ -120,6 +121,14 @@ public interface PythonResourceAdapter {
*/
Object callMethod(Object obj, String methodName, Map<String, Object> kwargs);

/**
* Binds a Java metric group to a Python resource.
*
* @param pythonResource the Python resource object
* @param metricGroup the Java metric group to expose through Python's metric group API
*/
default void setMetricGroup(Object pythonResource, FlinkAgentsMetricGroup metricGroup) {}

/**
* Invokes a method with the specified name and arguments.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.flink.agents.api.resource.python;

import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;

/**
* Wrapper interface for Python resource objects. This interface provides a unified way to access
* the underlying Python resource from Java objects that encapsulate Python functionality.
Expand All @@ -29,4 +31,26 @@ public interface PythonResourceWrapper {
* @return the wrapped Python resource object
*/
Object getPythonResource();

/**
* Returns the adapter that owns the wrapped Python resource.
*
* @return the Python resource adapter, or null if metric forwarding is unsupported
*/
default PythonResourceAdapter getPythonResourceAdapter() {
return null;
}

/**
* Binds the current Java metric group to the wrapped Python resource.
*
* @param metricGroup the metric group to bind
*/
default void setPythonResourceMetricGroup(FlinkAgentsMetricGroup metricGroup) {
PythonResourceAdapter adapter = getPythonResourceAdapter();
Object pythonResource = getPythonResource();
if (adapter != null && pythonResource != null) {
adapter.setMetricGroup(pythonResource, metricGroup);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,9 @@ public void updateEmbedding(
public Object getPythonResource() {
return vectorStore;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.flink.agents.api.chat.model.python;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.python.PythonResourceAdapter;
Expand Down Expand Up @@ -161,4 +162,13 @@ void testImplementsPythonResourceWrapper() {
.isInstanceOf(
org.apache.flink.agents.api.resource.python.PythonResourceWrapper.class);
}

@Test
void testSetMetricGroupPropagatesToPythonResource() {
FlinkAgentsMetricGroup metricGroup = mock(FlinkAgentsMetricGroup.class);

pythonChatModelSetup.setMetricGroup(metricGroup);

verify(mockAdapter).setMetricGroup(mockChatModelSetup, metricGroup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public Object getPythonResource() {
return prompt;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}

public String getName() {
if (name == null) {
name = prompt.getAttr("name").toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ public Object getPythonResource() {
return server;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}

@Override
public ResourceType getResourceType() {
return ResourceType.MCP_SERVER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ public Object getPythonResource() {
return tool;
}

@Override
public PythonResourceAdapter getPythonResourceAdapter() {
return adapter;
}

@Override
public ToolType getToolType() {
return ToolType.MCP;
Expand Down
13 changes: 13 additions & 0 deletions python/flink_agents/runtime/java/java_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
)
from flink_agents.api.resource import ResourceType
from flink_agents.api.tools.tool import Tool
from flink_agents.runtime.java.java_resource_wrapper import (
set_java_resource_metric_group,
)


class JavaChatModelConnectionImpl(JavaChatModelConnection):
Expand All @@ -51,6 +54,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N
self._j_resource = j_resource
self._j_resource_adapter = j_resource_adapter

@override
def set_metric_group(self, metric_group: Any) -> None:
super().set_metric_group(metric_group)
set_java_resource_metric_group(self._j_resource, metric_group)

@override
def chat(
self,
Expand Down Expand Up @@ -114,6 +122,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N
self._j_resource = j_resource
self._j_resource_adapter = j_resource_adapter

@override
def set_metric_group(self, metric_group: Any) -> None:
super().set_metric_group(metric_group)
set_java_resource_metric_group(self._j_resource, metric_group)

@property
@override
def model_kwargs(self) -> Dict[str, Any]:
Expand Down
13 changes: 13 additions & 0 deletions python/flink_agents/runtime/java/java_embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
JavaEmbeddingModelConnection,
JavaEmbeddingModelSetup,
)
from flink_agents.runtime.java.java_resource_wrapper import (
set_java_resource_metric_group,
)


class JavaEmbeddingModelConnectionImpl(JavaEmbeddingModelConnection):
Expand All @@ -48,6 +51,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N
self._j_resource = j_resource
self._j_resource_adapter = j_resource_adapter

@override
def set_metric_group(self, metric_group: Any) -> None:
super().set_metric_group(metric_group)
set_java_resource_metric_group(self._j_resource, metric_group)

def embed(
self, text: str | Sequence[str], **kwargs: Any
) -> list[float] | list[list[float]]:
Expand Down Expand Up @@ -92,6 +100,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N
self._j_resource = j_resource
self._j_resource_adapter = j_resource_adapter

@override
def set_metric_group(self, metric_group: Any) -> None:
super().set_metric_group(metric_group)
set_java_resource_metric_group(self._j_resource, metric_group)

@property
def model_kwargs(self) -> Dict[str, Any]:
"""Return embedding model settings.
Expand Down
8 changes: 8 additions & 0 deletions python/flink_agents/runtime/java/java_resource_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType


def set_java_resource_metric_group(j_resource: Any, metric_group: Any) -> None:
"""Bind the underlying Java metric group to a wrapped Java resource."""
if j_resource is None:
return
j_metric_group = getattr(metric_group, "_j_metric_group", metric_group)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the real flow this is correct — metric_group is always a FlinkMetricGroup (it comes from action_metric_group), so the unwrap hits _j_metric_group and forwards the genuine Java group. Low confidence that the edge below matters in practice, so genuinely a question rather than a flag.

RunnerContext.get_resource(name, type, metric_group=...) is a documented public API accepting any MetricGroup — an abstract base a user could subclass without a _j_metric_group. If such a custom group reaches a Java resource wrapper, the silent , metric_group fallback would forward a raw Python object into Java setMetricGroup(FlinkAgentsMetricGroup), failing opaquely at the pemja boundary rather than with a clear error. Worth making the fallback explicit (unwrap only a real FlinkMetricGroup, else fail loudly), or is that out of scope today given no in-tree caller hits it? (The None case is benign — get_resource substitutes action_metric_group first, and a direct None just clears the field — so no concern there.)

j_resource.setMetricGroup(j_metric_group)


class JavaTool(Tool):
"""Java Tool that carries tool metadata and can be recognized by PythonChatModel."""

Expand Down
8 changes: 8 additions & 0 deletions python/flink_agents/runtime/java/java_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
Document,
_maybe_cast_to_list,
)
from flink_agents.runtime.java.java_resource_wrapper import (
set_java_resource_metric_group,
)
from flink_agents.runtime.python_java_utils import from_java_document


Expand Down Expand Up @@ -60,6 +63,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N
self._j_resource = j_resource
self._j_resource_adapter = j_resource_adapter

@override
def set_metric_group(self, metric_group: Any) -> None:
super().set_metric_group(metric_group)
set_java_resource_metric_group(self._j_resource, metric_group)

@property
@override
def store_kwargs(self) -> Dict[str, Any]:
Expand Down
10 changes: 10 additions & 0 deletions python/flink_agents/runtime/python_java_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,13 @@ def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any:

method = getattr(obj, method_name)
return method(**kwargs)


def set_metric_group(obj: Resource, j_metric_group: Any) -> None:
"""Bind a Java metric group to a Python resource."""
from flink_agents.runtime.flink_metric_group import FlinkMetricGroup

metric_group = (
FlinkMetricGroup(j_metric_group) if j_metric_group is not None else None
)
obj.set_metric_group(metric_group)
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
from typing import Any

import pytest

from flink_agents.runtime.java.java_chat_model import (
JavaChatModelConnectionImpl,
JavaChatModelSetupImpl,
)
from flink_agents.runtime.java.java_embedding_model import (
JavaEmbeddingModelConnectionImpl,
JavaEmbeddingModelSetupImpl,
)
from flink_agents.runtime.java.java_resource_wrapper import (
set_java_resource_metric_group,
)


class _JavaResource:
def __init__(self) -> None:
self.metric_group: Any = None

def setMetricGroup(self, metric_group: Any) -> None:
self.metric_group = metric_group


class _MetricGroup:
def __init__(self) -> None:
self._j_metric_group = object()


@pytest.mark.parametrize(
"resource",
[
JavaChatModelConnectionImpl(
j_resource=_JavaResource(), j_resource_adapter=None
),
JavaChatModelSetupImpl(
j_resource=_JavaResource(),
j_resource_adapter=None,
connection="connection",
model="model",
),
JavaEmbeddingModelConnectionImpl(
j_resource=_JavaResource(), j_resource_adapter=None
),
JavaEmbeddingModelSetupImpl(
j_resource=_JavaResource(),
j_resource_adapter=None,
connection="connection",
model="model",
),
],
)
def test_java_resource_wrappers_forward_metric_group(resource):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new file covers the Python→Java direction well, but only that direction — set_java_resource_metric_group plus the four Java*Impl wrappers. The Java→Python direction's Python half, python_java_utils.set_metric_group (python_java_utils.py:382), has no direct test: it's the function that wraps the raw Java group in FlinkMetricGroup and owns the only None branch in the feature, and the Java half is mock-verify only (testSetMetricGroup asserts the invoke string, it doesn't execute Python). So that seam isn't exercised end-to-end today.

Would a small unit test in this file's style be worth adding — a fake resource that captures its set_metric_group arg, then set_metric_group(fake, sentinel_j_group) asserting the captured arg is a FlinkMetricGroup whose _j_metric_group is sentinel_j_group, plus a None case asserting None is forwarded? That would pin the wrap-and-None behavior that's currently uncovered.

metric_group = _MetricGroup()

resource.set_metric_group(metric_group)

assert resource.metric_group is metric_group
assert resource._j_resource.metric_group is metric_group._j_metric_group


def test_set_java_resource_metric_group_unwraps_flink_metric_group():
java_resource = _JavaResource()
metric_group = _MetricGroup()

set_java_resource_metric_group(java_resource, metric_group)

assert java_resource.metric_group is metric_group._j_metric_group
Loading
Loading