diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java index 0b394baab..088f3efd0 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -60,6 +60,7 @@ public class ReActAgent extends Agent { public ReActAgent( ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable Object outputSchema) { this.addResource(DEFAULT_CHAT_MODEL, ResourceType.CHAT_MODEL, descriptor); + Map actionConfig = new HashMap<>(); if (outputSchema != null) { String jsonSchema; @@ -82,15 +83,13 @@ public ReActAgent( "The final response should be json format, and match the schema %s", jsonSchema)); this.addResource(DEFAULT_SCHEMA_PROMPT, ResourceType.PROMPT, schemaPrompt); + actionConfig.put("output_schema", outputSchema); } if (prompt != null) { this.addResource(DEFAULT_USER_PROMPT, ResourceType.PROMPT, prompt); } - Map actionConfig = new HashMap<>(); - actionConfig.put("output_schema", outputSchema); - try { Method method = this.getClass().getMethod("startAction", Event.class, RunnerContext.class); diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java index 1baa40d17..7bff4550b 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java @@ -34,6 +34,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.Schema; @@ -46,6 +47,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY; @@ -114,7 +116,7 @@ public void testReActAgent() throws Exception { agentsEnv.getConfig().set(MAX_RETRIES, 3); // Declare the ReAct agent. - Agent agent = getAgent(); + Agent agent = getAgent(true); // Create input table from sample data Table inputTable = @@ -152,8 +154,74 @@ public void testReActAgent() throws Exception { checkResult(results); } - // create ReAct agent. - private static Agent getAgent() { + @Test + public void testReActAgentNoOutputSchema() throws Exception { + Assumptions.assumeTrue(ollamaReady, String.format("%s is not ready", OLLAMA_MODEL)); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + // Create the table environment + StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env); + tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100"); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv); + + // register resource to agents execution environment. + agentsEnv + .addResource( + "ollama", + ResourceType.CHAT_MODEL_CONNECTION, + ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OLLAMA_CONNECTION) + .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) + .build()) + .addResource( + "add", + ResourceType.TOOL, + Tool.fromMethod( + ReActAgentTest.class.getMethod("add", Double.class, Double.class))) + .addResource( + "multiply", + ResourceType.TOOL, + Tool.fromMethod( + ReActAgentTest.class.getMethod( + "multiply", Double.class, Double.class))); + + agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY); + agentsEnv.getConfig().set(MAX_RETRIES, 3); + + // Declare the ReAct agent without an output schema. + Agent agent = getAgent(false); + + // Create input table from sample data + Table inputTable = + tableEnv.fromValues( + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.DOUBLE()), + DataTypes.FIELD("b", DataTypes.DOUBLE()), + DataTypes.FIELD("c", DataTypes.DOUBLE())), + Row.of(2131, 29847, 3)); + + // Apply agent to the Table; without an output schema the result is a string. + DataStream out = + agentsEnv + .fromTable( + inputTable, + (KeySelector) + value -> (Double) ((Row) value).getField("a")) + .apply(agent) + .toDataStream(); + + out.print(); + + env.execute(); + } + + // create ReAct agent; pass false to skip the output schema. + private static Agent getAgent(boolean withSchema) { ResourceDescriptor chatModelDescriptor = ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.OLLAMA_SETUP) .addInitialArgument("connection", "ollama") @@ -162,21 +230,24 @@ private static Agent getAgent() { .addInitialArgument("extract_reasoning", true) .build(); - Prompt prompt = - Prompt.fromMessages( - List.of( - new ChatMessage( - MessageRole.SYSTEM, - "Must call function tool to do the calculate."), - new ChatMessage( - MessageRole.SYSTEM, - "An example of output is {\"result\": 30.32}"), - new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}."))); + List messages = new ArrayList<>(); + messages.add( + new ChatMessage( + MessageRole.SYSTEM, "Must call function tool to do the calculate.")); + if (withSchema) { + messages.add( + new ChatMessage( + MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}")); + } + messages.add(new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}.")); + RowTypeInfo outputTypeInfo = - new RowTypeInfo( - new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO}, - new String[] {"result"}); - return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo); + withSchema + ? new RowTypeInfo( + new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO}, + new String[] {"result"}) + : null; + return new ReActAgent(chatModelDescriptor, Prompt.fromMessages(messages), outputTypeInfo); } private void checkResult(CloseableIterator results) { diff --git a/python/flink_agents/api/agents/react_agent.py b/python/flink_agents/api/agents/react_agent.py index cef651a1d..c898d119b 100644 --- a/python/flink_agents/api/agents/react_agent.py +++ b/python/flink_agents/api/agents/react_agent.py @@ -143,7 +143,7 @@ def __init__( name="start_action", events=[InputEvent.EVENT_TYPE], func=self.start_action, - output_schema=OutputSchema(output_schema=output_schema), + output_schema=OutputSchema(output_schema=output_schema) if output_schema else None, ) @staticmethod diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py index f0c771785..ad689d058 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py @@ -267,3 +267,88 @@ def test_react_agent_on_remote_runner( # through the event-log capture path. invocations = collect_tool_invocations(log_dir) assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312}) + + +@pytest.mark.skipif( + client is None, reason="Ollama client is not available or test model is missing" +) +def test_react_agent_no_output_schema_on_remote_runner( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """ReAct agent without an output_schema should emit a plain string result.""" + monkeypatch.setenv("OLLAMA_CHAT_MODEL", OLLAMA_MODEL) + stream_env = StreamExecutionEnvironment.get_execution_environment() + + stream_env.set_parallelism(1) + + t_env = StreamTableEnvironment.create(stream_execution_environment=stream_env) + + table = t_env.from_elements( + elements=[(2123, 2321, 312)], + schema=DataTypes.ROW( + [ + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.INT()), + DataTypes.FIELD("c", DataTypes.INT()), + ] + ), + ) + + env = AgentsExecutionEnvironment.get_execution_environment( + env=stream_env, t_env=t_env + ) + + env.get_config().set( + AgentExecutionOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY + ) + + env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3) + + log_dir = tmp_path / "event_logs" + log_dir.mkdir(parents=True, exist_ok=True) + env.get_config().set_str("baseLogDir", str(log_dir)) + + # register resource to execution environment + ( + env.add_resource( + "ollama", + ResourceType.CHAT_MODEL_CONNECTION, + ResourceDescriptor( + clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0 + ), + ) + .add_resource("add", ResourceType.TOOL, Tool.from_callable(add)) + .add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply)) + ) + + # prepare prompt + prompt = Prompt.from_messages( + messages=[ + ChatMessage(role=MessageRole.USER, content="What is ({a} + {b}) * {c}"), + ], + ) + + # create ReAct agent without an output schema; result is emitted as a string. + agent = ReActAgent( + chat_model=ResourceDescriptor( + clazz=ResourceName.ChatModel.OLLAMA_SETUP, + connection="ollama", + model=OLLAMA_MODEL, + tools=["add", "multiply"], + ), + prompt=prompt, + ) + + output_stream = ( + env.from_table(input=table, key_selector=MyKeySelector()) + .apply(agent) + .to_datastream() + ) + output_stream.print() + + env.execute() + + # multiply's first arg (4444 = 2123 + 2321) proves the addition was computed + # correctly and threaded into multiply, even without an output schema. + invocations = collect_tool_invocations(log_dir) + assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})