Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class ReActAgent extends Agent {
public ReActAgent(
ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable Object outputSchema) {
this.addResource(DEFAULT_CHAT_MODEL, ResourceType.CHAT_MODEL, descriptor);
Map<String, Object> actionConfig = new HashMap<>();

if (outputSchema != null) {
String jsonSchema;
Expand All @@ -82,15 +83,13 @@ public ReActAgent(
"The final response should be json format, and match the schema %s",
jsonSchema));
this.addResource(DEFAULT_SCHEMA_PROMPT, ResourceType.PROMPT, schemaPrompt);
actionConfig.put("output_schema", outputSchema);
}

if (prompt != null) {
this.addResource(DEFAULT_USER_PROMPT, ResourceType.PROMPT, prompt);
}

Map<String, Object> actionConfig = new HashMap<>();
actionConfig.put("output_schema", outputSchema);

try {
Method method =
this.getClass().getMethod("startAction", Event.class, RunnerContext.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
Expand All @@ -46,6 +47,7 @@
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY;
Expand Down Expand Up @@ -114,7 +116,7 @@ public void testReActAgent() throws Exception {
agentsEnv.getConfig().set(MAX_RETRIES, 3);

// Declare the ReAct agent.
Agent agent = getAgent();
Agent agent = getAgent(true);

// Create input table from sample data
Table inputTable =
Expand Down Expand Up @@ -152,8 +154,74 @@ public void testReActAgent() throws Exception {
checkResult(results);
}

// create ReAct agent.
private static Agent getAgent() {
@Test
public void testReActAgentNoOutputSchema() throws Exception {
Assumptions.assumeTrue(ollamaReady, String.format("%s is not ready", OLLAMA_MODEL));
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);

// Create the table environment
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100");

// Create agents execution environment
AgentsExecutionEnvironment agentsEnv =
AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv);

// register resource to agents execution environment.
agentsEnv
.addResource(
"ollama",
ResourceType.CHAT_MODEL_CONNECTION,
ResourceDescriptor.Builder.newBuilder(
ResourceName.ChatModel.OLLAMA_CONNECTION)
.addInitialArgument("endpoint", "http://localhost:11434")
.addInitialArgument("requestTimeout", 240)
.build())
.addResource(
"add",
ResourceType.TOOL,
Tool.fromMethod(
ReActAgentTest.class.getMethod("add", Double.class, Double.class)))
.addResource(
"multiply",
ResourceType.TOOL,
Tool.fromMethod(
ReActAgentTest.class.getMethod(
"multiply", Double.class, Double.class)));

agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY);
agentsEnv.getConfig().set(MAX_RETRIES, 3);

// Declare the ReAct agent without an output schema.
Agent agent = getAgent(false);

// Create input table from sample data
Table inputTable =
tableEnv.fromValues(
DataTypes.ROW(
DataTypes.FIELD("a", DataTypes.DOUBLE()),
DataTypes.FIELD("b", DataTypes.DOUBLE()),
DataTypes.FIELD("c", DataTypes.DOUBLE())),
Row.of(2131, 29847, 3));

// Apply agent to the Table; without an output schema the result is a string.
DataStream<Object> out =
agentsEnv
.fromTable(
inputTable,
(KeySelector<Object, Double>)
value -> (Double) ((Row) value).getField("a"))
.apply(agent)
.toDataStream();

out.print();

env.execute();
}

// create ReAct agent; pass false to skip the output schema.
private static Agent getAgent(boolean withSchema) {
ResourceDescriptor chatModelDescriptor =
ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.OLLAMA_SETUP)
.addInitialArgument("connection", "ollama")
Expand All @@ -162,21 +230,24 @@ private static Agent getAgent() {
.addInitialArgument("extract_reasoning", true)
.build();

Prompt prompt =
Prompt.fromMessages(
List.of(
new ChatMessage(
MessageRole.SYSTEM,
"Must call function tool to do the calculate."),
new ChatMessage(
MessageRole.SYSTEM,
"An example of output is {\"result\": 30.32}"),
new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}.")));
List<ChatMessage> messages = new ArrayList<>();
messages.add(
new ChatMessage(
MessageRole.SYSTEM, "Must call function tool to do the calculate."));
if (withSchema) {
messages.add(
new ChatMessage(
MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}"));
}
messages.add(new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}."));

RowTypeInfo outputTypeInfo =
new RowTypeInfo(
new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
new String[] {"result"});
return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo);
withSchema
? new RowTypeInfo(
new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
new String[] {"result"})
: null;
return new ReActAgent(chatModelDescriptor, Prompt.fromMessages(messages), outputTypeInfo);
}

private void checkResult(CloseableIterator<?> results) {
Expand Down
2 changes: 1 addition & 1 deletion python/flink_agents/api/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(
name="start_action",
events=[InputEvent.EVENT_TYPE],
func=self.start_action,
output_schema=OutputSchema(output_schema=output_schema),
output_schema=OutputSchema(output_schema=output_schema) if output_schema else None,
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,88 @@ def test_react_agent_on_remote_runner(
# through the event-log capture path.
invocations = collect_tool_invocations(log_dir)
assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})


@pytest.mark.skipif(
client is None, reason="Ollama client is not available or test model is missing"
)
def test_react_agent_no_output_schema_on_remote_runner(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""ReAct agent without an output_schema should emit a plain string result."""
monkeypatch.setenv("OLLAMA_CHAT_MODEL", OLLAMA_MODEL)
stream_env = StreamExecutionEnvironment.get_execution_environment()

stream_env.set_parallelism(1)

t_env = StreamTableEnvironment.create(stream_execution_environment=stream_env)

table = t_env.from_elements(
elements=[(2123, 2321, 312)],
schema=DataTypes.ROW(
[
DataTypes.FIELD("a", DataTypes.INT()),
DataTypes.FIELD("b", DataTypes.INT()),
DataTypes.FIELD("c", DataTypes.INT()),
]
),
)

env = AgentsExecutionEnvironment.get_execution_environment(
env=stream_env, t_env=t_env
)

env.get_config().set(
AgentExecutionOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY
)

env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3)

log_dir = tmp_path / "event_logs"
log_dir.mkdir(parents=True, exist_ok=True)
env.get_config().set_str("baseLogDir", str(log_dir))

# register resource to execution environment
(
env.add_resource(
"ollama",
ResourceType.CHAT_MODEL_CONNECTION,
ResourceDescriptor(
clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0
),
)
.add_resource("add", ResourceType.TOOL, Tool.from_callable(add))
.add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply))
)

# prepare prompt
prompt = Prompt.from_messages(
messages=[
ChatMessage(role=MessageRole.USER, content="What is ({a} + {b}) * {c}"),
],
)

# create ReAct agent without an output schema; result is emitted as a string.
agent = ReActAgent(
chat_model=ResourceDescriptor(
clazz=ResourceName.ChatModel.OLLAMA_SETUP,
connection="ollama",
model=OLLAMA_MODEL,
tools=["add", "multiply"],
),
prompt=prompt,
)

output_stream = (
env.from_table(input=table, key_selector=MyKeySelector())
.apply(agent)
.to_datastream()
)
output_stream.print()

env.execute()

# multiply's first arg (4444 = 2123 + 2321) proves the addition was computed
# correctly and threaded into multiply, even without an output schema.
invocations = collect_tool_invocations(log_dir)
assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})
Loading