Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package org.apache.flink.agents.api.event;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
Expand Down Expand Up @@ -62,7 +64,10 @@ public ChatRequestEvent(String model, List<ChatMessage> messages) {
this(model, messages, null, null);
}

public ChatRequestEvent(UUID id, Map<String, Object> attributes) {
@JsonCreator
public ChatRequestEvent(
@JsonProperty("id") UUID id,
@JsonProperty("attributes") Map<String, Object> attributes) {
super(id, EVENT_TYPE, attributes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package org.apache.flink.agents.api.event;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
Expand Down Expand Up @@ -46,7 +48,10 @@ public ChatResponseEvent(
setAttr("total_retry_wait_sec", totalRetryWaitSec);
}

public ChatResponseEvent(UUID id, Map<String, Object> attributes) {
@JsonCreator
public ChatResponseEvent(
@JsonProperty("id") UUID id,
@JsonProperty("attributes") Map<String, Object> attributes) {
super(id, EVENT_TYPE, attributes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package org.apache.flink.agents.api.event;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.agents.api.Event;

import java.util.HashMap;
Expand All @@ -43,7 +45,10 @@ public ContextRetrievalRequestEvent(String query, String vectorStore, int maxRes
setAttr("max_results", maxResults);
}

public ContextRetrievalRequestEvent(UUID id, Map<String, Object> attributes) {
@JsonCreator
public ContextRetrievalRequestEvent(
@JsonProperty("id") UUID id,
@JsonProperty("attributes") Map<String, Object> attributes) {
super(id, EVENT_TYPE, attributes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package org.apache.flink.agents.api.event;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.vectorstores.Document;
Expand All @@ -43,7 +45,10 @@ public ContextRetrievalResponseEvent(UUID requestId, String query, List<Document
setAttr("documents", new ArrayList<>(documents));
}

public ContextRetrievalResponseEvent(UUID id, Map<String, Object> attributes) {
@JsonCreator
public ContextRetrievalResponseEvent(
@JsonProperty("id") UUID id,
@JsonProperty("attributes") Map<String, Object> attributes) {
super(id, EVENT_TYPE, attributes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package org.apache.flink.agents.api.event;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.agents.api.Event;

import java.util.HashMap;
Expand All @@ -37,7 +39,10 @@ public ToolRequestEvent(String model, List<Map<String, Object>> toolCalls) {
setAttr("tool_calls", toolCalls);
}

public ToolRequestEvent(UUID id, Map<String, Object> attributes) {
@JsonCreator
public ToolRequestEvent(
@JsonProperty("id") UUID id,
@JsonProperty("attributes") Map<String, Object> attributes) {
super(id, EVENT_TYPE, attributes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package org.apache.flink.agents.api.event;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.tools.ToolResponse;
Expand Down Expand Up @@ -57,7 +59,10 @@ public ToolResponseEvent(
this(requestId, responses, success, error, Map.of());
}

public ToolResponseEvent(UUID id, Map<String, Object> attributes) {
@JsonCreator
public ToolResponseEvent(
@JsonProperty("id") UUID id,
@JsonProperty("attributes") Map<String, Object> attributes) {
super(id, EVENT_TYPE, attributes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,17 @@
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.InputEvent;
import org.apache.flink.agents.api.OutputEvent;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.context.MemoryUpdate;
import org.apache.flink.agents.api.event.ChatRequestEvent;
import org.apache.flink.agents.api.event.ChatResponseEvent;
import org.apache.flink.agents.api.event.ContextRetrievalRequestEvent;
import org.apache.flink.agents.api.event.ContextRetrievalResponseEvent;
import org.apache.flink.agents.api.event.ToolRequestEvent;
import org.apache.flink.agents.api.event.ToolResponseEvent;
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.api.vectorstores.Document;
import org.junit.jupiter.api.Test;

import java.nio.charset.StandardCharsets;
Expand All @@ -29,6 +39,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -333,4 +344,63 @@ public void testKafkaSederDelegatesToActionStateSerde() throws Exception {
assertNull(kafkaSeder.serialize("test-topic", null));
assertNull(kafkaSeder.deserialize("test-topic", null));
}

@Test
public void testBuiltinChatToolAndContextEventsRoundTripThroughOutputEvents() throws Exception {
ChatMessage msg = new ChatMessage(MessageRole.USER, "hello");
UUID requestId = UUID.randomUUID();
Document doc = new Document("doc content", Map.of("source", "unit-test"), "doc-1");

// Built-in events are persisted both as the triggering taskEvent and as
// outputEvents; cover both paths.
ActionState originalState = new ActionState(new ChatRequestEvent("myModel", List.of(msg)));
originalState.addEvent(new ChatRequestEvent("myModel", List.of(msg)));
originalState.addEvent(new ChatResponseEvent(requestId, msg));
originalState.addEvent(new ToolRequestEvent("myModel", List.of(Map.of("name", "myTool"))));
originalState.addEvent(
new ToolResponseEvent(
requestId,
Map.of("call-1", ToolResponse.success("result")),
Map.of("call-1", true),
Map.of()));
originalState.addEvent(new ContextRetrievalRequestEvent("query text", "myVectorStore", 5));
originalState.addEvent(
new ContextRetrievalResponseEvent(requestId, "query text", List.of(doc)));

byte[] serialized = ActionStateSerde.serialize(originalState);
ActionState deserializedState = ActionStateSerde.deserialize(serialized);

assertEquals(ChatRequestEvent.class, deserializedState.getTaskEvent().getClass());

List<Event> outputEvents = deserializedState.getOutputEvents();
assertEquals(6, outputEvents.size());
assertEquals(ChatRequestEvent.class, outputEvents.get(0).getClass());
assertEquals(ChatResponseEvent.class, outputEvents.get(1).getClass());
assertEquals(ToolRequestEvent.class, outputEvents.get(2).getClass());
assertEquals(ToolResponseEvent.class, outputEvents.get(3).getClass());
assertEquals(ContextRetrievalRequestEvent.class, outputEvents.get(4).getClass());
assertEquals(ContextRetrievalResponseEvent.class, outputEvents.get(5).getClass());

// Replayed events are consumed through fromEvent() (see ChatModelAction,
// ToolCallAction, ContextRetrievalAction), which must restore nested
// attributes degraded to raw maps by the JSON round-trip back to their
// typed forms.
ChatRequestEvent chatRequest = ChatRequestEvent.fromEvent(outputEvents.get(0));
assertEquals("hello", chatRequest.getMessages().get(0).getContent());
assertEquals(MessageRole.USER, chatRequest.getMessages().get(0).getRole());

ChatResponseEvent chatResponse = ChatResponseEvent.fromEvent(outputEvents.get(1));
assertEquals(requestId, chatResponse.getRequestId());
assertEquals("hello", chatResponse.getResponse().getContent());

ToolResponseEvent toolResponse = ToolResponseEvent.fromEvent(outputEvents.get(3));
assertEquals(requestId, toolResponse.getRequestId());
assertEquals("result", toolResponse.getResponses().get("call-1").getResult());

ContextRetrievalResponseEvent retrievalResponse =
ContextRetrievalResponseEvent.fromEvent(outputEvents.get(5));
assertEquals(requestId, retrievalResponse.getRequestId());
assertEquals("doc content", retrievalResponse.getDocuments().get(0).getContent());
assertEquals("doc-1", retrievalResponse.getDocuments().get(0).getId());
}
}
Loading