Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public class ReActAgent extends Agent {
public ReActAgent(
ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable Object outputSchema) {
this.addResource(DEFAULT_CHAT_MODEL, ResourceType.CHAT_MODEL, descriptor);
Map<String, Object> actionConfig = new HashMap<>();

if (outputSchema != null) {
String jsonSchema;
Expand All @@ -81,15 +82,13 @@ public ReActAgent(
"The final response should be json format, and match the schema %s",
jsonSchema));
this.addResource(DEFAULT_SCHEMA_PROMPT, ResourceType.PROMPT, schemaPrompt);
actionConfig.put("output_schema", outputSchema);
}

if (prompt != null) {
this.addResource(DEFAULT_USER_PROMPT, ResourceType.PROMPT, prompt);
}

Map<String, Object> actionConfig = new HashMap<>();
actionConfig.put("output_schema", outputSchema);

try {
Method method =
this.getClass().getMethod("startAction", InputEvent.class, RunnerContext.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
Expand All @@ -46,6 +47,7 @@
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY;
Expand Down Expand Up @@ -114,7 +116,7 @@ public void testReActAgent() throws Exception {
agentsEnv.getConfig().set(MAX_RETRIES, 3);

// Declare the ReAct agent.
Agent agent = getAgent();
Agent agent = getAgent(true);

// Create input table from sample data
Table inputTable =
Expand Down Expand Up @@ -152,8 +154,74 @@ public void testReActAgent() throws Exception {
checkResult(results);
}

// create ReAct agent.
private static Agent getAgent() {
@Test
public void testReActAgentNoOutputSchema() throws Exception {
Assumptions.assumeTrue(ollamaReady, String.format("%s is not ready", OLLAMA_MODEL));
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);

// Create the table environment
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100");

// Create agents execution environment
AgentsExecutionEnvironment agentsEnv =
AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv);

// register resource to agents execution environment.
agentsEnv
.addResource(
"ollama",
ResourceType.CHAT_MODEL_CONNECTION,
ResourceDescriptor.Builder.newBuilder(
ResourceName.ChatModel.OLLAMA_CONNECTION)
.addInitialArgument("endpoint", "http://localhost:11434")
.addInitialArgument("requestTimeout", 240)
.build())
.addResource(
"add",
ResourceType.TOOL,
Tool.fromMethod(
ReActAgentTest.class.getMethod("add", Double.class, Double.class)))
.addResource(
"multiply",
ResourceType.TOOL,
Tool.fromMethod(
ReActAgentTest.class.getMethod(
"multiply", Double.class, Double.class)));

agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY);
agentsEnv.getConfig().set(MAX_RETRIES, 3);

// Declare the ReAct agent without an output schema.
Agent agent = getAgent(false);

// Create input table from sample data
Table inputTable =
tableEnv.fromValues(
DataTypes.ROW(
DataTypes.FIELD("a", DataTypes.DOUBLE()),
DataTypes.FIELD("b", DataTypes.DOUBLE()),
DataTypes.FIELD("c", DataTypes.DOUBLE())),
Row.of(2131, 29847, 3));

// Apply agent to the Table; without an output schema the result is a string.
DataStream<Object> out =
agentsEnv
.fromTable(
inputTable,
(KeySelector<Object, Double>)
value -> (Double) ((Row) value).getField("a"))
.apply(agent)
.toDataStream();

out.print();

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test passes as long as env.execute() doesn't throw, so it covers "the plan serializes and the job runs" — which is the crash this PR fixes — but it doesn't pin the contract the no-schema path establishes: that the output comes back as a String. The Python counterpart goes a step further with assert len(output_list) == 1 and assert isinstance(output_list[0]["0001"], str), and the schema-case Java test routes through checkResult(...) for an exact value check. Would it be worth collecting the output here too — collectAsync() like the schema test, then asserting hasNext() and that the value is a String — so a regression that emitted nothing or wrongly took the STRUCTURED_OUTPUT branch in stopAction would fail on the Java side as well? Or is the intent to keep the Java e2e deliberately lighter than the Python one?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I agree that the Java test also needs to validate the results, and I will update this PR accordingly.

Additionally, since this fix has already been merged into the main branch, I will open a separate PR later to update the Java tests on the main branch as well.


env.execute();
}

// create ReAct agent; pass false to skip the output schema.
private static Agent getAgent(boolean withSchema) {
ResourceDescriptor chatModelDescriptor =
ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.OLLAMA_SETUP)
.addInitialArgument("connection", "ollama")
Expand All @@ -162,21 +230,24 @@ private static Agent getAgent() {
.addInitialArgument("extract_reasoning", true)
.build();

Prompt prompt =
Prompt.fromMessages(
List.of(
new ChatMessage(
MessageRole.SYSTEM,
"Must call function tool to do the calculate."),
new ChatMessage(
MessageRole.SYSTEM,
"An example of output is {\"result\": 30.32}"),
new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}.")));
List<ChatMessage> messages = new ArrayList<>();
messages.add(
new ChatMessage(
MessageRole.SYSTEM, "Must call function tool to do the calculate."));
if (withSchema) {
messages.add(
new ChatMessage(
MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}"));
}
messages.add(new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}."));

RowTypeInfo outputTypeInfo =
new RowTypeInfo(
new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
new String[] {"result"});
return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo);
withSchema
? new RowTypeInfo(
new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
new String[] {"result"})
: null;
return new ReActAgent(chatModelDescriptor, Prompt.fromMessages(messages), outputTypeInfo);
}

private void checkResult(CloseableIterator<?> results) {
Expand Down
4 changes: 3 additions & 1 deletion python/flink_agents/api/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def __init__(
name="start_action",
events=[InputEvent],
func=self.start_action,
output_schema=OutputSchema(output_schema=output_schema),
output_schema=OutputSchema(output_schema=output_schema)
if output_schema
else None,
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,59 @@ def test_react_agent_on_local_runner() -> None: # noqa: D103
assert output_list[0]["0001"].result == 1386528


@pytest.mark.skipif(
client is None, reason="Ollama client is not available or test model is missing"
)
def test_react_agent_no_output_schema_on_local_runner() -> None: # noqa: D103
env = AgentsExecutionEnvironment.get_execution_environment()
env.get_config().set(
AgentExecutionOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY
)
env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3)

# register resource to execution environment
(
env.add_resource(
"ollama",
ResourceType.CHAT_MODEL_CONNECTION,
ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0),
)
.add_resource("add", ResourceType.TOOL, Tool.from_callable(add))
.add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply))
)

# prepare prompt
prompt = Prompt.from_messages(
messages=[
ChatMessage(role=MessageRole.USER, content="What is ({a} + {b}) * {c}"),
],
)

# create ReAct agent without an output schema; result is emitted as a string.
agent = ReActAgent(
chat_model=ResourceDescriptor(
clazz=ResourceName.ChatModel.OLLAMA_SETUP,
connection="ollama",
model=OLLAMA_MODEL,
tools=["add", "multiply"],
),
prompt=prompt,
)

# execute agent
input_list = []

output_list = env.from_list(input_list).apply(agent).to_list()
input_list.append({"key": "0001", "value": InputData(a=2123, b=2321, c=312)})

env.execute()

assert len(output_list) == 1, (
"This may be caused by the LLM response failing to produce a result, you can rerun this case."
)
assert isinstance(output_list[0]["0001"], str)


@pytest.mark.skipif(
client is None, reason="Ollama client is not available or test model is missing"
)
Expand Down
Loading