Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,22 @@ public void open() throws Exception {
* @param completionTokens the number of completion tokens
*/
public void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) {
FlinkAgentsMetricGroup metricGroup = getMetricGroup();
recordTokenMetrics(getMetricGroup(), modelName, promptTokens, completionTokens);
}

/**
* Record token usage metrics for the given model on the provided metric group.
*
* @param metricGroup the metric group captured when the request was initiated
* @param modelName the name of the model used
* @param promptTokens the number of prompt tokens
* @param completionTokens the number of completion tokens
*/
public void recordTokenMetrics(
@Nullable FlinkAgentsMetricGroup metricGroup,
String modelName,
long promptTokens,
long completionTokens) {
if (metricGroup == null) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ public TestChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourc
public Map<String, Object> getParameters() {
return Collections.emptyMap();
}

FlinkAgentsMetricGroup captureMetricGroup() {
return getMetricGroup();
}
}

@BeforeEach
Expand Down Expand Up @@ -104,6 +108,27 @@ void testRecordTokenMetricsWithoutMetricGroup() {
verifyNoInteractions(mockMetricGroup);
}

@Test
@DisplayName("Test token metrics use the request-scoped metric group")
void testRecordTokenMetricsWithRequestScopedMetricGroup() {
TestMetricGroup actionA = new TestMetricGroup();
TestMetricGroup actionB = new TestMetricGroup();

setup.setMetricGroup(actionA);
FlinkAgentsMetricGroup requestMetricGroup = setup.captureMetricGroup();

setup.setMetricGroup(actionB);
setup.recordTokenMetrics(requestMetricGroup, "gpt-4", 100, 50);

TestMetricGroup actionAModelGroup = (TestMetricGroup) actionA.getSubGroup("model", "gpt-4");
assertEquals(100, actionAModelGroup.counters.get("promptTokens").getCount());
assertEquals(50, actionAModelGroup.counters.get("completionTokens").getCount());

TestMetricGroup actionBModelGroup = (TestMetricGroup) actionB.getSubGroup("model", "gpt-4");
assertEquals(0, actionBModelGroup.getCounter("promptTokens").getCount());
assertEquals(0, actionBModelGroup.getCounter("completionTokens").getCount());
}

@Test
@DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> counters")
void testTokenMetricsHierarchy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ private static void recordRetryMetrics(
}
}

static void recordChatTokenMetrics(BaseChatModelSetup chatModel, ChatMessage response) {
static void recordChatTokenMetrics(
BaseChatModelSetup chatModel,
ChatMessage response,
FlinkAgentsMetricGroup requestMetricGroup) {
Map<String, Object> extraArgs = response.getExtraArgs();
Object modelName = extraArgs.get("model_name");
Object promptTokens = extraArgs.get("promptTokens");
Expand All @@ -194,7 +197,8 @@ static void recordChatTokenMetrics(BaseChatModelSetup chatModel, ChatMessage res
long prompt = ((Number) promptTokens).longValue();
long completion = ((Number) completionTokens).longValue();
if (prompt > 0 && completion > 0) {
chatModel.recordTokenMetrics(modelName.toString(), prompt, completion);
chatModel.recordTokenMetrics(
requestMetricGroup, modelName.toString(), prompt, completion);
}
}
}
Expand Down Expand Up @@ -322,6 +326,7 @@ public static void chat(
throws Exception {
BaseChatModelSetup chatModel =
(BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL);
FlinkAgentsMetricGroup requestMetricGroup = ctx.getActionMetricGroup();

boolean chatAsync = ctx.getConfig().get(AgentExecutionOptions.CHAT_ASYNC);

Expand Down Expand Up @@ -372,7 +377,7 @@ public ChatMessage call() throws Exception {
chatAsync
? ctx.durableExecuteAsync(callable)
: ctx.durableExecute(callable);
recordChatTokenMetrics(chatModel, response);
recordChatTokenMetrics(chatModel, response, requestMetricGroup);
// only generate structured output for final response.
if (outputSchema != null && response.getToolCalls().isEmpty()) {
response = generateStructuredOutput(response, outputSchema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,35 @@ void chatSucceedsWithoutRetry_retryCountIsZero() throws Exception {
verify(mockActionMetricGroup, never()).getSubGroup(anyString(), anyString());
}

@Test
void chatRecordsTokenMetricsWithRequestScopedMetricGroup() throws Exception {
configureRetryStrategy(0, 0);
FlinkAgentsMetricGroup actionA = mock(FlinkAgentsMetricGroup.class);
FlinkAgentsMetricGroup actionB = mock(FlinkAgentsMetricGroup.class);
when(mockCtx.getActionMetricGroup()).thenReturn(actionA, actionB);

ChatMessage response =
new ChatMessage(
MessageRole.ASSISTANT,
"hello",
Map.of(
"model_name", "provider-model",
"promptTokens", 100L,
"completionTokens", 50L));
when(mockChatModel.chat(any(), any(), any())).thenReturn(response);

ChatModelAction.chat(
UUID.randomUUID(),
"test-model",
List.of(new ChatMessage(MessageRole.USER, "hi")),
Map.of(),
null,
mockCtx);

verify(mockChatModel).recordTokenMetrics(actionA, "provider-model", 100L, 50L);
verify(mockChatModel, never()).recordTokenMetrics(actionB, "provider-model", 100L, 50L);
}

@Test
void chatRetriesWithExponentialBackoff() throws Exception {
// 1 second base interval; fail once then succeed -> wait 1s (1 * 2^0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
Expand All @@ -42,71 +44,82 @@ private static ChatMessage responseWith(Map<String, Object> extraArgs) {
@Test
void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() {
BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class);
Map<String, Object> extraArgs = new HashMap<>();
extraArgs.put("model_name", "m");
extraArgs.put("promptTokens", 100L);
extraArgs.put("completionTokens", 50L);

ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup);

verify(setup).recordTokenMetrics("m", 100L, 50L);
verify(setup).recordTokenMetrics(requestMetricGroup, "m", 100L, 50L);
}

@Test
void testRecordChatTokenMetricsHandlesIntegerTokenValues() {
BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class);
Map<String, Object> extraArgs = new HashMap<>();
extraArgs.put("model_name", "m");
extraArgs.put("promptTokens", 100);
extraArgs.put("completionTokens", 50);

ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup);

verify(setup).recordTokenMetrics("m", 100L, 50L);
verify(setup).recordTokenMetrics(requestMetricGroup, "m", 100L, 50L);
}

@Test
void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() {
BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class);
Map<String, Object> extraArgs = new HashMap<>();
extraArgs.put("model_name", "m");
extraArgs.put("promptTokens", "100");
extraArgs.put("completionTokens", 50L);

ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup);

verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong());
verify(setup, never())
.recordTokenMetrics(
any(FlinkAgentsMetricGroup.class), anyString(), anyLong(), anyLong());
}

@Test
void testRecordChatTokenMetricsSkipsWhenKeyMissing() {
BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class);
Map<String, Object> extraArgs = new HashMap<>();
extraArgs.put("model_name", "m");
extraArgs.put("completionTokens", 50L);

ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs), requestMetricGroup);

verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong());
verify(setup, never())
.recordTokenMetrics(
any(FlinkAgentsMetricGroup.class), anyString(), anyLong(), anyLong());
}

@Test
void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() {
BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
FlinkAgentsMetricGroup requestMetricGroup = mock(FlinkAgentsMetricGroup.class);

Map<String, Object> zeroPrompt = new HashMap<>();
zeroPrompt.put("model_name", "m");
zeroPrompt.put("promptTokens", 0L);
zeroPrompt.put("completionTokens", 50L);
ChatModelAction.recordChatTokenMetrics(setup, responseWith(zeroPrompt));
ChatModelAction.recordChatTokenMetrics(setup, responseWith(zeroPrompt), requestMetricGroup);

Map<String, Object> emptyModel = new HashMap<>();
emptyModel.put("model_name", "");
emptyModel.put("promptTokens", 100L);
emptyModel.put("completionTokens", 50L);
ChatModelAction.recordChatTokenMetrics(setup, responseWith(emptyModel));
ChatModelAction.recordChatTokenMetrics(setup, responseWith(emptyModel), requestMetricGroup);

verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong());
verify(setup, never())
.recordTokenMetrics(
any(FlinkAgentsMetricGroup.class), anyString(), anyLong(), anyLong());
}

@Test
Expand Down
13 changes: 11 additions & 2 deletions python/flink_agents/api/chat_models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MessageRole,
find_first_system_message,
)
from flink_agents.api.metric_group import MetricGroup
from flink_agents.api.prompts.prompt import Prompt
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.skills import BASH_TOOL, LOAD_SKILL_TOOL
Expand Down Expand Up @@ -261,7 +262,11 @@ def chat(
)

def _record_token_metrics(
self, model_name: str, prompt_tokens: int, completion_tokens: int
self,
model_name: str,
prompt_tokens: int,
completion_tokens: int,
metric_group: MetricGroup | None = None,
) -> None:
"""Record token usage metrics for the given model.

Expand All @@ -273,8 +278,12 @@ def _record_token_metrics(
The number of prompt tokens
completion_tokens : int
The number of completion tokens
metric_group : MetricGroup | None
The metric group captured when the request was initiated. If not provided,
this resource's currently bound metric group is used.
"""
metric_group = self.metric_group
if metric_group is None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two languages now document slightly different contracts for the absent-group case: Java's 4-arg recordTokenMetrics treats a null group as "skip recording" (BaseChatModelSetup.java:135), while here metric_group is None falls back to self.metric_group — the live, possibly-rebound group. Every real call site passes a captured group (chat_model_action.py:334), so there's no impact on today's flow. The part I'm less sure about: a future caller passing None expecting Java parity would record under whatever action is currently bound — the #859 scenario again on the Python side, where Java would no-op instead. Is the asymmetry intentional? If the = None default is only there to keep the param optional for the test helper, dropping it so callers must pass the captured group would match Java's no-fallback contract; alternatively a one-line javadoc note that Java's null means skip would at least make each side self-documenting.

metric_group = self.metric_group
if metric_group is None:
return

Expand Down
32 changes: 30 additions & 2 deletions python/flink_agents/api/chat_models/tests/test_token_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,16 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
return ChatMessage(role=MessageRole.ASSISTANT, content="Test response")

def test_record_token_metrics(
self, model_name: str, prompt_tokens: int, completion_tokens: int
self,
model_name: str,
prompt_tokens: int,
completion_tokens: int,
metric_group: MetricGroup | None = None,
) -> None:
"""Expose protected method for testing."""
self._record_token_metrics(model_name, prompt_tokens, completion_tokens)
self._record_token_metrics(
model_name, prompt_tokens, completion_tokens, metric_group
)


class _MockCounter(Counter):
Expand Down Expand Up @@ -124,6 +130,28 @@ def test_record_token_metrics_without_metric_group(self) -> None:
chat_model.test_record_token_metrics("gpt-4", 100, 50)
# No exception should be raised

def test_record_token_metrics_with_request_scoped_metric_group(self) -> None:
"""Token metrics use the metric group captured when the request started."""
chat_model = TestChatModelSetup(connection="mock", model="mock-model")
action_a_metric_group = _MockMetricGroup()
action_b_metric_group = _MockMetricGroup()

chat_model.set_metric_group(action_a_metric_group)
request_metric_group = chat_model.metric_group

chat_model.set_metric_group(action_b_metric_group)
chat_model.test_record_token_metrics(
"gpt-4", 100, 50, metric_group=request_metric_group
)

action_a_model_group = action_a_metric_group.get_sub_group("model", "gpt-4")
assert action_a_model_group.get_counter("promptTokens").get_count() == 100
assert action_a_model_group.get_counter("completionTokens").get_count() == 50

action_b_model_group = action_b_metric_group.get_sub_group("model", "gpt-4")
assert action_b_model_group.get_counter("promptTokens").get_count() == 0
assert action_b_model_group.get_counter("completionTokens").get_count() == 0

def test_token_metrics_hierarchy(self) -> None:
"""Test token metrics hierarchy: actionMetricGroup -> modelName -> counters."""
chat_model = TestChatModelSetup(connection="mock", model="mock-model")
Expand Down
2 changes: 2 additions & 0 deletions python/flink_agents/plan/actions/chat_model_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ async def chat(
chat_model = cast(
"BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL)
)
request_metric_group = ctx.action_metric_group

chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC)

Expand Down Expand Up @@ -334,6 +335,7 @@ async def chat(
response.extra_args["model_name"],
response.extra_args["promptTokens"],
response.extra_args["completionTokens"],
request_metric_group,
)
if output_schema is not None and len(response.tool_calls) == 0:
response = _generate_structured_output(response, output_schema)
Expand Down
Loading