diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index 34b2b2994..313e20d98 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -116,7 +116,22 @@ public void open() throws Exception { * @param completionTokens the number of completion tokens */ public void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) { - FlinkAgentsMetricGroup metricGroup = getMetricGroup(); + recordTokenMetrics(getMetricGroup(), modelName, promptTokens, completionTokens); + } + + /** + * Record token usage metrics for the given model on the provided metric group. + * + * @param metricGroup the metric group captured when the request was initiated + * @param modelName the name of the model used + * @param promptTokens the number of prompt tokens + * @param completionTokens the number of completion tokens + */ + public void recordTokenMetrics( + @Nullable FlinkAgentsMetricGroup metricGroup, + String modelName, + long promptTokens, + long completionTokens) { if (metricGroup == null) { return; } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java index 8e47105f0..99d03f1ef 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java @@ -62,6 +62,10 @@ public TestChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourc public Map getParameters() { return Collections.emptyMap(); } + + FlinkAgentsMetricGroup captureMetricGroup() { + return getMetricGroup(); + } } @BeforeEach @@ -104,6 +108,27 @@ void testRecordTokenMetricsWithoutMetricGroup() { verifyNoInteractions(mockMetricGroup); } + @Test + @DisplayName("Test token metrics use the request-scoped metric group") + void testRecordTokenMetricsWithRequestScopedMetricGroup() { + TestMetricGroup actionA = new TestMetricGroup(); + TestMetricGroup actionB = new TestMetricGroup(); + + setup.setMetricGroup(actionA); + FlinkAgentsMetricGroup requestMetricGroup = setup.captureMetricGroup(); + + setup.setMetricGroup(actionB); + setup.recordTokenMetrics(requestMetricGroup, "gpt-4", 100, 50); + + TestMetricGroup actionAModelGroup = (TestMetricGroup) actionA.getSubGroup("model", "gpt-4"); + assertEquals(100, actionAModelGroup.counters.get("promptTokens").getCount()); + assertEquals(50, actionAModelGroup.counters.get("completionTokens").getCount()); + + TestMetricGroup actionBModelGroup = (TestMetricGroup) actionB.getSubGroup("model", "gpt-4"); + assertEquals(0, actionBModelGroup.getCounter("promptTokens").getCount()); + assertEquals(0, actionBModelGroup.getCounter("completionTokens").getCount()); + } + @Test @DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> counters") void testTokenMetricsHierarchy() { diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index df28c1d41..f58b0b423 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -182,7 +182,10 @@ private static void recordRetryMetrics( } } - static void recordChatTokenMetrics(BaseChatModelSetup chatModel, ChatMessage response) { + static void recordChatTokenMetrics( + BaseChatModelSetup chatModel, + ChatMessage response, + FlinkAgentsMetricGroup requestMetricGroup) { Map extraArgs = response.getExtraArgs(); Object modelName = extraArgs.get("model_name"); Object promptTokens = extraArgs.get("promptTokens"); @@ -194,7 +197,8 @@ static void recordChatTokenMetrics(BaseChatModelSetup chatModel, ChatMessage res long prompt = ((Number) promptTokens).longValue(); long completion = ((Number) completionTokens).longValue(); if (prompt > 0 && completion > 0) { - chatModel.recordTokenMetrics(modelName.toString(), prompt, completion); + chatModel.recordTokenMetrics( + requestMetricGroup, modelName.toString(), prompt, completion); } } } @@ -322,6 +326,7 @@ public static void chat( throws Exception { BaseChatModelSetup chatModel = (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL); + FlinkAgentsMetricGroup requestMetricGroup = ctx.getActionMetricGroup(); boolean chatAsync = ctx.getConfig().get(AgentExecutionOptions.CHAT_ASYNC); @@ -372,7 +377,7 @@ public ChatMessage call() throws Exception { chatAsync ? ctx.durableExecuteAsync(callable) : ctx.durableExecute(callable); - recordChatTokenMetrics(chatModel, response); + recordChatTokenMetrics(chatModel, response, requestMetricGroup); // only generate structured output for final response. if (outputSchema != null && response.getToolCalls().isEmpty()) { response = generateStructuredOutput(response, outputSchema); diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java index 8c1395059..f93a0836f 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java @@ -127,6 +127,35 @@ void chatSucceedsWithoutRetry_retryCountIsZero() throws Exception { verify(mockActionMetricGroup, never()).getSubGroup(anyString(), anyString()); } + @Test + void chatRecordsTokenMetricsWithRequestScopedMetricGroup() throws Exception { + configureRetryStrategy(0, 0); + FlinkAgentsMetricGroup actionA = mock(FlinkAgentsMetricGroup.class); + FlinkAgentsMetricGroup actionB = mock(FlinkAgentsMetricGroup.class); + when(mockCtx.getActionMetricGroup()).thenReturn(actionA, actionB); + + ChatMessage response = + new ChatMessage( + MessageRole.ASSISTANT, + "hello", + Map.of( + "model_name", "provider-model", + "promptTokens", 100L, + "completionTokens", 50L)); + when(mockChatModel.chat(any(), any(), any())).thenReturn(response); + + ChatModelAction.chat( + UUID.randomUUID(), + "test-model", + List.of(new ChatMessage(MessageRole.USER, "hi")), + Map.of(), + null, + mockCtx); + + verify(mockChatModel).recordTokenMetrics(actionA, "provider-model", 100L, 50L); + verify(mockChatModel, never()).recordTokenMetrics(actionB, "provider-model", 100L, 50L); + } + @Test void chatRetriesWithExponentialBackoff() throws Exception { // 1 second base interval; fail once then succeed -> wait 1s (1 * 2^0) diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java index 85c263a66..94668f3df 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java @@ -20,12 +20,14 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; @@ -42,71 +44,82 @@ private static ChatMessage responseWith(Map extraArgs) { @Test void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() { BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class); Map extraArgs = new HashMap<>(); extraArgs.put("model_name", "m"); extraArgs.put("promptTokens", 100L); extraArgs.put("completionTokens", 50L); - ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup); - verify(setup).recordTokenMetrics("m", 100L, 50L); + verify(setup).recordTokenMetrics(requestMetricGroup, "m", 100L, 50L); } @Test void testRecordChatTokenMetricsHandlesIntegerTokenValues() { BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class); Map extraArgs = new HashMap<>(); extraArgs.put("model_name", "m"); extraArgs.put("promptTokens", 100); extraArgs.put("completionTokens", 50); - ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup); - verify(setup).recordTokenMetrics("m", 100L, 50L); + verify(setup).recordTokenMetrics(requestMetricGroup, "m", 100L, 50L); } @Test void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() { BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class); Map extraArgs = new HashMap<>(); extraArgs.put("model_name", "m"); extraArgs.put("promptTokens", "100"); extraArgs.put("completionTokens", 50L); - ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup); - verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + verify(setup, never()) + .recordTokenMetrics( + any(FlinkAgentsMetricGroup.class), anyString(), anyLong(), anyLong()); } @Test void testRecordChatTokenMetricsSkipsWhenKeyMissing() { BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class); Map extraArgs = new HashMap<>(); extraArgs.put("model_name", "m"); extraArgs.put("completionTokens", 50L); - ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup); - verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + verify(setup, never()) + .recordTokenMetrics( + any(FlinkAgentsMetricGroup.class), anyString(), anyLong(), anyLong()); } @Test void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() { BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class); Map zeroPrompt = new HashMap<>(); zeroPrompt.put("model_name", "m"); zeroPrompt.put("promptTokens", 0L); zeroPrompt.put("completionTokens", 50L); - ChatModelAction.recordChatTokenMetrics(setup, responseWith(zeroPrompt)); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(zeroPrompt), requestMetricGroup); Map emptyModel = new HashMap<>(); emptyModel.put("model_name", ""); emptyModel.put("promptTokens", 100L); emptyModel.put("completionTokens", 50L); - ChatModelAction.recordChatTokenMetrics(setup, responseWith(emptyModel)); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(emptyModel), requestMetricGroup); - verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + verify(setup, never()) + .recordTokenMetrics( + any(FlinkAgentsMetricGroup.class), anyString(), anyLong(), anyLong()); } @Test diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index ac4a814aa..ea37e45fa 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -27,6 +27,7 @@ MessageRole, find_first_system_message, ) +from flink_agents.api.metric_group import MetricGroup from flink_agents.api.prompts.prompt import Prompt from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.skills import BASH_TOOL, LOAD_SKILL_TOOL @@ -261,7 +262,11 @@ def chat( ) def _record_token_metrics( - self, model_name: str, prompt_tokens: int, completion_tokens: int + self, + model_name: str, + prompt_tokens: int, + completion_tokens: int, + metric_group: MetricGroup | None = None, ) -> None: """Record token usage metrics for the given model. @@ -273,8 +278,12 @@ def _record_token_metrics( The number of prompt tokens completion_tokens : int The number of completion tokens + metric_group : MetricGroup | None + The metric group captured when the request was initiated. If not provided, + this resource's currently bound metric group is used. """ - metric_group = self.metric_group + if metric_group is None: + metric_group = self.metric_group if metric_group is None: return diff --git a/python/flink_agents/api/chat_models/tests/test_token_metrics.py b/python/flink_agents/api/chat_models/tests/test_token_metrics.py index e81951511..6ac10ee72 100644 --- a/python/flink_agents/api/chat_models/tests/test_token_metrics.py +++ b/python/flink_agents/api/chat_models/tests/test_token_metrics.py @@ -46,10 +46,16 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: return ChatMessage(role=MessageRole.ASSISTANT, content="Test response") def test_record_token_metrics( - self, model_name: str, prompt_tokens: int, completion_tokens: int + self, + model_name: str, + prompt_tokens: int, + completion_tokens: int, + metric_group: MetricGroup | None = None, ) -> None: """Expose protected method for testing.""" - self._record_token_metrics(model_name, prompt_tokens, completion_tokens) + self._record_token_metrics( + model_name, prompt_tokens, completion_tokens, metric_group + ) class _MockCounter(Counter): @@ -124,6 +130,28 @@ def test_record_token_metrics_without_metric_group(self) -> None: chat_model.test_record_token_metrics("gpt-4", 100, 50) # No exception should be raised + def test_record_token_metrics_with_request_scoped_metric_group(self) -> None: + """Token metrics use the metric group captured when the request started.""" + chat_model = TestChatModelSetup(connection="mock", model="mock-model") + action_a_metric_group = _MockMetricGroup() + action_b_metric_group = _MockMetricGroup() + + chat_model.set_metric_group(action_a_metric_group) + request_metric_group = chat_model.metric_group + + chat_model.set_metric_group(action_b_metric_group) + chat_model.test_record_token_metrics( + "gpt-4", 100, 50, metric_group=request_metric_group + ) + + action_a_model_group = action_a_metric_group.get_sub_group("model", "gpt-4") + assert action_a_model_group.get_counter("promptTokens").get_count() == 100 + assert action_a_model_group.get_counter("completionTokens").get_count() == 50 + + action_b_model_group = action_b_metric_group.get_sub_group("model", "gpt-4") + assert action_b_model_group.get_counter("promptTokens").get_count() == 0 + assert action_b_model_group.get_counter("completionTokens").get_count() == 0 + def test_token_metrics_hierarchy(self) -> None: """Test token metrics hierarchy: actionMetricGroup -> modelName -> counters.""" chat_model = TestChatModelSetup(connection="mock", model="mock-model") diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index e4572056e..2918a59bc 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -290,6 +290,7 @@ async def chat( chat_model = cast( "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL) ) + request_metric_group = ctx.action_metric_group chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC) @@ -334,6 +335,7 @@ async def chat( response.extra_args["model_name"], response.extra_args["promptTokens"], response.extra_args["completionTokens"], + request_metric_group, ) if output_schema is not None and len(response.tool_calls) == 0: response = _generate_structured_output(response, output_schema)