diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java index 2a362f7a7..80c1f0547 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java @@ -65,6 +65,11 @@ public Object getPythonResource() { return chatModel; } + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } + @Override public ChatMessage chat( List messages, List tools, Map modelParams) { diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java index ad2117d36..a478f6827 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java @@ -89,6 +89,11 @@ public Object getPythonResource() { return chatModelSetup; } + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } + @Override public Map getParameters() { return Map.of(); diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java index 974e362a1..bb6a593e1 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java @@ -123,6 +123,11 @@ public Object getPythonResource() { return embeddingModel; } + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } + @Override public void close() throws Exception { this.embeddingModel.close(); diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java index f0b9eca4b..cc0d1e169 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java @@ -132,4 +132,9 @@ public Map getParameters() { public Object getPythonResource() { return embeddingModelSetup; } + + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } } diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java index 52ba40ffd..97f8a53b5 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api.resource; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; /** * Base interface for all kinds of resources, including chat models, tools, prompts and so on. @@ -60,6 +61,9 @@ public ResourceContext getResourceContext() { */ public void setMetricGroup(FlinkAgentsMetricGroup metricGroup) { this.metricGroup = metricGroup; + if (this instanceof PythonResourceWrapper) { + ((PythonResourceWrapper) this).setPythonResourceMetricGroup(metricGroup); + } } /** diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java index 03eb8248c..527865d9f 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api.resource.python; import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.vectorstores.Document; import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; @@ -120,6 +121,14 @@ public interface PythonResourceAdapter { */ Object callMethod(Object obj, String methodName, Map kwargs); + /** + * Binds a Java metric group to a Python resource. + * + * @param pythonResource the Python resource object + * @param metricGroup the Java metric group to expose through Python's metric group API + */ + default void setMetricGroup(Object pythonResource, FlinkAgentsMetricGroup metricGroup) {} + /** * Invokes a method with the specified name and arguments. * diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java index c69cf59b8..7bd3a343d 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java @@ -17,6 +17,8 @@ */ package org.apache.flink.agents.api.resource.python; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; + /** * Wrapper interface for Python resource objects. This interface provides a unified way to access * the underlying Python resource from Java objects that encapsulate Python functionality. @@ -29,4 +31,26 @@ public interface PythonResourceWrapper { * @return the wrapped Python resource object */ Object getPythonResource(); + + /** + * Returns the adapter that owns the wrapped Python resource. + * + * @return the Python resource adapter, or null if metric forwarding is unsupported + */ + default PythonResourceAdapter getPythonResourceAdapter() { + return null; + } + + /** + * Binds the current Java metric group to the wrapped Python resource. + * + * @param metricGroup the metric group to bind + */ + default void setPythonResourceMetricGroup(FlinkAgentsMetricGroup metricGroup) { + PythonResourceAdapter adapter = getPythonResourceAdapter(); + Object pythonResource = getPythonResource(); + if (adapter != null && pythonResource != null) { + adapter.setMetricGroup(pythonResource, metricGroup); + } + } } diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java index 69025cc11..9a81728d2 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java @@ -233,4 +233,9 @@ public void updateEmbedding( public Object getPythonResource() { return vectorStore; } + + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java index 4327fb336..42463b083 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java @@ -18,6 +18,7 @@ package org.apache.flink.agents.api.chat.model.python; import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; @@ -161,4 +162,13 @@ void testImplementsPythonResourceWrapper() { .isInstanceOf( org.apache.flink.agents.api.resource.python.PythonResourceWrapper.class); } + + @Test + void testSetMetricGroupPropagatesToPythonResource() { + FlinkAgentsMetricGroup metricGroup = mock(FlinkAgentsMetricGroup.class); + + pythonChatModelSetup.setMetricGroup(metricGroup); + + verify(mockAdapter).setMetricGroup(mockChatModelSetup, metricGroup); + } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java index 625a89fd7..32a6159b6 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java @@ -47,6 +47,11 @@ public Object getPythonResource() { return prompt; } + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } + public String getName() { if (name == null) { name = prompt.getAttr("name").toString(); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java index a6268347d..030ac0abe 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java @@ -84,6 +84,11 @@ public Object getPythonResource() { return server; } + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } + @Override public ResourceType getResourceType() { return ResourceType.MCP_SERVER; diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java index 89a5435d7..576f451c9 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java @@ -75,6 +75,11 @@ public Object getPythonResource() { return tool; } + @Override + public PythonResourceAdapter getPythonResourceAdapter() { + return adapter; + } + @Override public ToolType getToolType() { return ToolType.MCP; diff --git a/python/flink_agents/runtime/java/java_chat_model.py b/python/flink_agents/runtime/java/java_chat_model.py index 10bc169c1..3f385a9ad 100644 --- a/python/flink_agents/runtime/java/java_chat_model.py +++ b/python/flink_agents/runtime/java/java_chat_model.py @@ -26,6 +26,9 @@ ) from flink_agents.api.resource import ResourceType from flink_agents.api.tools.tool import Tool +from flink_agents.runtime.java.java_resource_wrapper import ( + set_java_resource_metric_group, +) class JavaChatModelConnectionImpl(JavaChatModelConnection): @@ -51,6 +54,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N self._j_resource = j_resource self._j_resource_adapter = j_resource_adapter + @override + def set_metric_group(self, metric_group: Any) -> None: + super().set_metric_group(metric_group) + set_java_resource_metric_group(self._j_resource, metric_group) + @override def chat( self, @@ -114,6 +122,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N self._j_resource = j_resource self._j_resource_adapter = j_resource_adapter + @override + def set_metric_group(self, metric_group: Any) -> None: + super().set_metric_group(metric_group) + set_java_resource_metric_group(self._j_resource, metric_group) + @property @override def model_kwargs(self) -> Dict[str, Any]: diff --git a/python/flink_agents/runtime/java/java_embedding_model.py b/python/flink_agents/runtime/java/java_embedding_model.py index 2cb15b819..b2ea48724 100644 --- a/python/flink_agents/runtime/java/java_embedding_model.py +++ b/python/flink_agents/runtime/java/java_embedding_model.py @@ -23,6 +23,9 @@ JavaEmbeddingModelConnection, JavaEmbeddingModelSetup, ) +from flink_agents.runtime.java.java_resource_wrapper import ( + set_java_resource_metric_group, +) class JavaEmbeddingModelConnectionImpl(JavaEmbeddingModelConnection): @@ -48,6 +51,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N self._j_resource = j_resource self._j_resource_adapter = j_resource_adapter + @override + def set_metric_group(self, metric_group: Any) -> None: + super().set_metric_group(metric_group) + set_java_resource_metric_group(self._j_resource, metric_group) + def embed( self, text: str | Sequence[str], **kwargs: Any ) -> list[float] | list[list[float]]: @@ -92,6 +100,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N self._j_resource = j_resource self._j_resource_adapter = j_resource_adapter + @override + def set_metric_group(self, metric_group: Any) -> None: + super().set_metric_group(metric_group) + set_java_resource_metric_group(self._j_resource, metric_group) + @property def model_kwargs(self) -> Dict[str, Any]: """Return embedding model settings. diff --git a/python/flink_agents/runtime/java/java_resource_wrapper.py b/python/flink_agents/runtime/java/java_resource_wrapper.py index 886e4c84c..f11bfa482 100644 --- a/python/flink_agents/runtime/java/java_resource_wrapper.py +++ b/python/flink_agents/runtime/java/java_resource_wrapper.py @@ -27,6 +27,14 @@ from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType +def set_java_resource_metric_group(j_resource: Any, metric_group: Any) -> None: + """Bind the underlying Java metric group to a wrapped Java resource.""" + if j_resource is None: + return + j_metric_group = getattr(metric_group, "_j_metric_group", metric_group) + j_resource.setMetricGroup(j_metric_group) + + class JavaTool(Tool): """Java Tool that carries tool metadata and can be recognized by PythonChatModel.""" diff --git a/python/flink_agents/runtime/java/java_vector_store.py b/python/flink_agents/runtime/java/java_vector_store.py index 301d7a539..8c3399cad 100644 --- a/python/flink_agents/runtime/java/java_vector_store.py +++ b/python/flink_agents/runtime/java/java_vector_store.py @@ -27,6 +27,9 @@ Document, _maybe_cast_to_list, ) +from flink_agents.runtime.java.java_resource_wrapper import ( + set_java_resource_metric_group, +) from flink_agents.runtime.python_java_utils import from_java_document @@ -60,6 +63,11 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N self._j_resource = j_resource self._j_resource_adapter = j_resource_adapter + @override + def set_metric_group(self, metric_group: Any) -> None: + super().set_metric_group(metric_group) + set_java_resource_metric_group(self._j_resource, metric_group) + @property @override def store_kwargs(self) -> Dict[str, Any]: diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index 23ed4f5c1..93f1d9562 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -377,3 +377,13 @@ def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any: method = getattr(obj, method_name) return method(**kwargs) + + +def set_metric_group(obj: Resource, j_metric_group: Any) -> None: + """Bind a Java metric group to a Python resource.""" + from flink_agents.runtime.flink_metric_group import FlinkMetricGroup + + metric_group = ( + FlinkMetricGroup(j_metric_group) if j_metric_group is not None else None + ) + obj.set_metric_group(metric_group) diff --git a/python/flink_agents/runtime/tests/test_cross_language_metric_group.py b/python/flink_agents/runtime/tests/test_cross_language_metric_group.py new file mode 100644 index 000000000..45ff1c1f5 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_cross_language_metric_group.py @@ -0,0 +1,86 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import Any + +import pytest + +from flink_agents.runtime.java.java_chat_model import ( + JavaChatModelConnectionImpl, + JavaChatModelSetupImpl, +) +from flink_agents.runtime.java.java_embedding_model import ( + JavaEmbeddingModelConnectionImpl, + JavaEmbeddingModelSetupImpl, +) +from flink_agents.runtime.java.java_resource_wrapper import ( + set_java_resource_metric_group, +) + + +class _JavaResource: + def __init__(self) -> None: + self.metric_group: Any = None + + def setMetricGroup(self, metric_group: Any) -> None: + self.metric_group = metric_group + + +class _MetricGroup: + def __init__(self) -> None: + self._j_metric_group = object() + + +@pytest.mark.parametrize( + "resource", + [ + JavaChatModelConnectionImpl( + j_resource=_JavaResource(), j_resource_adapter=None + ), + JavaChatModelSetupImpl( + j_resource=_JavaResource(), + j_resource_adapter=None, + connection="connection", + model="model", + ), + JavaEmbeddingModelConnectionImpl( + j_resource=_JavaResource(), j_resource_adapter=None + ), + JavaEmbeddingModelSetupImpl( + j_resource=_JavaResource(), + j_resource_adapter=None, + connection="connection", + model="model", + ), + ], +) +def test_java_resource_wrappers_forward_metric_group(resource): + metric_group = _MetricGroup() + + resource.set_metric_group(metric_group) + + assert resource.metric_group is metric_group + assert resource._j_resource.metric_group is metric_group._j_metric_group + + +def test_set_java_resource_metric_group_unwraps_flink_metric_group(): + java_resource = _JavaResource() + metric_group = _MetricGroup() + + set_java_resource_metric_group(java_resource, metric_group) + + assert java_resource.metric_group is metric_group._j_metric_group diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java index f4284e48e..42ab86d52 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java @@ -19,6 +19,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; @@ -53,6 +54,8 @@ public class PythonResourceAdapterImpl implements PythonResourceAdapter { static final String CALL_METHOD = PYTHON_MODULE_PREFIX + "call_method"; + static final String SET_METRIC_GROUP = PYTHON_MODULE_PREFIX + "set_metric_group"; + static final String CREATE_RESOURCE = PYTHON_MODULE_PREFIX + "create_resource"; static final String FROM_JAVA_RESOURCE = PYTHON_MODULE_PREFIX + "from_java_resource"; @@ -200,6 +203,11 @@ public Object callMethod(Object obj, String methodName, Map kwar return interpreter.invoke(CALL_METHOD, obj, methodName, kwargs); } + @Override + public void setMetricGroup(Object pythonResource, FlinkAgentsMetricGroup metricGroup) { + interpreter.invoke(SET_METRIC_GROUP, pythonResource, metricGroup); + } + @Override public Object invoke(String name, Object... args) { return interpreter.invoke(name, args); diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java index f8821bfb2..e46e3ea68 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java @@ -18,6 +18,7 @@ package org.apache.flink.agents.runtime.python.utils; import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; @@ -181,6 +182,17 @@ void testCallMethod() { .invoke(PythonResourceAdapterImpl.CALL_METHOD, obj, methodName, kwargs); } + @Test + void testSetMetricGroup() { + Object pythonResource = new Object(); + FlinkAgentsMetricGroup metricGroup = mock(FlinkAgentsMetricGroup.class); + + pythonResourceAdapter.setMetricGroup(pythonResource, metricGroup); + + verify(mockInterpreter) + .invoke(PythonResourceAdapterImpl.SET_METRIC_GROUP, pythonResource, metricGroup); + } + @Test void testInvoke() { String name = "test_function";