diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java index 278e356ba..33eedbecd 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -59,6 +59,7 @@ public class ReActAgent extends Agent { public ReActAgent( ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable Object outputSchema) { this.addResource(DEFAULT_CHAT_MODEL, ResourceType.CHAT_MODEL, descriptor); + Map actionConfig = new HashMap<>(); if (outputSchema != null) { String jsonSchema; @@ -81,15 +82,13 @@ public ReActAgent( "The final response should be json format, and match the schema %s", jsonSchema)); this.addResource(DEFAULT_SCHEMA_PROMPT, ResourceType.PROMPT, schemaPrompt); + actionConfig.put("output_schema", outputSchema); } if (prompt != null) { this.addResource(DEFAULT_USER_PROMPT, ResourceType.PROMPT, prompt); } - Map actionConfig = new HashMap<>(); - actionConfig.put("output_schema", outputSchema); - try { Method method = this.getClass().getMethod("startAction", InputEvent.class, RunnerContext.class); diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java index 1baa40d17..7bff4550b 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java @@ -34,6 +34,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.Schema; @@ -46,6 +47,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY; @@ -114,7 +116,7 @@ public void testReActAgent() throws Exception { agentsEnv.getConfig().set(MAX_RETRIES, 3); // Declare the ReAct agent. - Agent agent = getAgent(); + Agent agent = getAgent(true); // Create input table from sample data Table inputTable = @@ -152,8 +154,74 @@ public void testReActAgent() throws Exception { checkResult(results); } - // create ReAct agent. - private static Agent getAgent() { + @Test + public void testReActAgentNoOutputSchema() throws Exception { + Assumptions.assumeTrue(ollamaReady, String.format("%s is not ready", OLLAMA_MODEL)); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + // Create the table environment + StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env); + tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100"); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv); + + // register resource to agents execution environment. + agentsEnv + .addResource( + "ollama", + ResourceType.CHAT_MODEL_CONNECTION, + ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OLLAMA_CONNECTION) + .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) + .build()) + .addResource( + "add", + ResourceType.TOOL, + Tool.fromMethod( + ReActAgentTest.class.getMethod("add", Double.class, Double.class))) + .addResource( + "multiply", + ResourceType.TOOL, + Tool.fromMethod( + ReActAgentTest.class.getMethod( + "multiply", Double.class, Double.class))); + + agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY); + agentsEnv.getConfig().set(MAX_RETRIES, 3); + + // Declare the ReAct agent without an output schema. + Agent agent = getAgent(false); + + // Create input table from sample data + Table inputTable = + tableEnv.fromValues( + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.DOUBLE()), + DataTypes.FIELD("b", DataTypes.DOUBLE()), + DataTypes.FIELD("c", DataTypes.DOUBLE())), + Row.of(2131, 29847, 3)); + + // Apply agent to the Table; without an output schema the result is a string. + DataStream out = + agentsEnv + .fromTable( + inputTable, + (KeySelector) + value -> (Double) ((Row) value).getField("a")) + .apply(agent) + .toDataStream(); + + out.print(); + + env.execute(); + } + + // create ReAct agent; pass false to skip the output schema. + private static Agent getAgent(boolean withSchema) { ResourceDescriptor chatModelDescriptor = ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.OLLAMA_SETUP) .addInitialArgument("connection", "ollama") @@ -162,21 +230,24 @@ private static Agent getAgent() { .addInitialArgument("extract_reasoning", true) .build(); - Prompt prompt = - Prompt.fromMessages( - List.of( - new ChatMessage( - MessageRole.SYSTEM, - "Must call function tool to do the calculate."), - new ChatMessage( - MessageRole.SYSTEM, - "An example of output is {\"result\": 30.32}"), - new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}."))); + List messages = new ArrayList<>(); + messages.add( + new ChatMessage( + MessageRole.SYSTEM, "Must call function tool to do the calculate.")); + if (withSchema) { + messages.add( + new ChatMessage( + MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}")); + } + messages.add(new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}.")); + RowTypeInfo outputTypeInfo = - new RowTypeInfo( - new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO}, - new String[] {"result"}); - return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo); + withSchema + ? new RowTypeInfo( + new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO}, + new String[] {"result"}) + : null; + return new ReActAgent(chatModelDescriptor, Prompt.fromMessages(messages), outputTypeInfo); } private void checkResult(CloseableIterator results) { diff --git a/python/flink_agents/api/agents/react_agent.py b/python/flink_agents/api/agents/react_agent.py index fbb4b93f3..bdd7414dd 100644 --- a/python/flink_agents/api/agents/react_agent.py +++ b/python/flink_agents/api/agents/react_agent.py @@ -138,7 +138,9 @@ def __init__( name="start_action", events=[InputEvent], func=self.start_action, - output_schema=OutputSchema(output_schema=output_schema), + output_schema=OutputSchema(output_schema=output_schema) + if output_schema + else None, ) @staticmethod diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py index 5e4880dfe..31d4b4472 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py @@ -133,6 +133,59 @@ def test_react_agent_on_local_runner() -> None: # noqa: D103 assert output_list[0]["0001"].result == 1386528 +@pytest.mark.skipif( + client is None, reason="Ollama client is not available or test model is missing" +) +def test_react_agent_no_output_schema_on_local_runner() -> None: # noqa: D103 + env = AgentsExecutionEnvironment.get_execution_environment() + env.get_config().set( + AgentExecutionOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY + ) + env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3) + + # register resource to execution environment + ( + env.add_resource( + "ollama", + ResourceType.CHAT_MODEL_CONNECTION, + ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0), + ) + .add_resource("add", ResourceType.TOOL, Tool.from_callable(add)) + .add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply)) + ) + + # prepare prompt + prompt = Prompt.from_messages( + messages=[ + ChatMessage(role=MessageRole.USER, content="What is ({a} + {b}) * {c}"), + ], + ) + + # create ReAct agent without an output schema; result is emitted as a string. + agent = ReActAgent( + chat_model=ResourceDescriptor( + clazz=ResourceName.ChatModel.OLLAMA_SETUP, + connection="ollama", + model=OLLAMA_MODEL, + tools=["add", "multiply"], + ), + prompt=prompt, + ) + + # execute agent + input_list = [] + + output_list = env.from_list(input_list).apply(agent).to_list() + input_list.append({"key": "0001", "value": InputData(a=2123, b=2321, c=312)}) + + env.execute() + + assert len(output_list) == 1, ( + "This may be caused by the LLM response failing to produce a result, you can rerun this case." + ) + assert isinstance(output_list[0]["0001"], str) + + @pytest.mark.skipif( client is None, reason="Ollama client is not available or test model is missing" )