From c312e355a219c165564a230434e7af3d0112eaf3 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 11 Jun 2026 23:49:31 -0700 Subject: [PATCH] [api][integrations] Support models' native structured output (foundation + OpenAI) Add the foundation for using a model provider's native structured-output capability at the chat-model connection layer, plus the OpenAI implementation, in both Java and Python. Previously output_schema was honored only by prompt-engineering the request and parsing the response text. The request's output schema is carried to the connection through a reserved key in the existing modelParams/kwargs map, so the abstract chat() signature is unchanged. Each connection declares a boolean native-structured-output capability (default false). A connection applies the native API only when a schema is present, no tools are bound on the call, the schema is a POJO (Java) / BaseModel (Python) rather than a RowTypeInfo, and the setup is same-language. The reserved key is always removed before the SDK call so it cannot leak into a provider request. The prompt-engineered path is retained as the fallback and is unaffected: in the ReAct loop tools are always bound, so the native path stays dormant there. OpenAI applies response_format json_schema with strict validation. Other connections only strip the reserved key; their native paths and the ReActAgent final-output wiring follow in later changes. --- .../chat/model/BaseChatModelConnection.java | 35 ++++ .../api/chat/model/BaseChatModelTest.java | 27 +++ .../openai/OpenAICompletionsConnection.java | 43 ++++- .../OpenAICompletionsConnectionTest.java | 154 ++++++++++++++++++ .../agents/plan/actions/ChatModelAction.java | 15 +- .../plan/actions/ChatModelActionTest.java | 78 +++++++++ .../api/chat_models/chat_model.py | 33 ++++ .../anthropic/anthropic_chat_model.py | 4 + .../azure/azure_openai_chat_model.py | 4 + .../chat_models/ollama_chat_model.py | 4 + .../chat_models/openai/openai_chat_model.py | 46 +++++- .../test_openai_native_structured_output.py | 127 +++++++++++++++ .../tests/test_reserved_key_no_leak.py | 62 +++++++ .../chat_models/tongyi_chat_model.py | 4 + .../plan/actions/chat_model_action.py | 17 +- .../actions/test_chat_model_action_retry.py | 66 ++++++++ 16 files changed, 711 insertions(+), 8 deletions(-) create mode 100644 integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnectionTest.java create mode 100644 python/flink_agents/integrations/chat_models/openai/tests/test_openai_native_structured_output.py create mode 100644 python/flink_agents/integrations/chat_models/tests/test_reserved_key_no_leak.py diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java index 5181073a7..7dc8afc4f 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java @@ -36,6 +36,14 @@ */ public abstract class BaseChatModelConnection extends Resource { + /** + * Reserved {@code modelParams} key carrying the raw output schema (a POJO {@link Class} or an + * {@link org.apache.flink.agents.api.agents.OutputSchema}) down to the connection so it can + * apply the provider's native structured-output mechanism. The key is intra-language only and + * must be removed before the provider SDK call so it never reaches the request body. + */ + public static final String STRUCTURED_OUTPUT_SCHEMA_KEY = "__structured_output_schema__"; + public BaseChatModelConnection(ResourceDescriptor descriptor, ResourceContext resourceContext) { super(descriptor, resourceContext); } @@ -45,6 +53,33 @@ public ResourceType getResourceType() { return ResourceType.CHAT_MODEL_CONNECTION; } + /** + * Whether this connection applies the provider's native structured-output API when an output + * schema is supplied. Connections that translate a schema into a native provider parameter + * override this to return {@code true}; the default false keeps non-native connections on the + * prompt-engineering fallback. + * + * @return true if this connection supports native structured output + */ + protected boolean supportsNativeStructuredOutput() { + return false; + } + + /** + * Removes and returns the reserved structured-output schema from {@code modelParams}. Every + * connection must call this so the reserved key never leaks into the provider SDK request; + * native connections additionally use the returned value to build the native parameter. + * + * @param modelParams the mutable model parameters map (may be null) + * @return the raw output schema if present, otherwise null + */ + protected static Object popStructuredOutputSchema(Map modelParams) { + if (modelParams == null) { + return null; + } + return modelParams.remove(STRUCTURED_OUTPUT_SCHEMA_KEY); + } + /** * Process a chat request and return a chat response. * diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java index 43f8c8b05..d5a962e1a 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java @@ -320,6 +320,33 @@ void testChatRefillsTemplateOnSubsequentInvocations() { assertEquals("tool result", connection.capturedMessages.get(1).getContent()); } + @Test + @DisplayName("popStructuredOutputSchema removes the reserved key and returns its value") + void testPopStructuredOutputSchemaRemovesAndReturns() { + Object schema = new Object(); + Map modelParams = new HashMap<>(); + modelParams.put(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY, schema); + modelParams.put("temperature", 0.5); + + Object popped = BaseChatModelConnection.popStructuredOutputSchema(modelParams); + + assertSame(schema, popped); + assertFalse(modelParams.containsKey(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY)); + assertTrue(modelParams.containsKey("temperature")); + } + + @Test + @DisplayName( + "popStructuredOutputSchema returns null when the reserved key is absent or map is null") + void testPopStructuredOutputSchemaNoKey() { + Map modelParams = new HashMap<>(); + modelParams.put("temperature", 0.5); + + assertNull(BaseChatModelConnection.popStructuredOutputSchema(modelParams)); + assertEquals(1, modelParams.size()); + assertNull(BaseChatModelConnection.popStructuredOutputSchema(null)); + } + @Test @DisplayName("Test chat with long input") void testChatWithLongInput() { diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java index 29d0dcf78..3ab5a8a86 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java @@ -22,11 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; import com.openai.core.JsonValue; import com.openai.models.ChatModel; import com.openai.models.FunctionDefinition; import com.openai.models.FunctionParameters; import com.openai.models.ReasoningEffort; +import com.openai.models.ResponseFormatJsonSchema; import com.openai.models.chat.completions.ChatCompletion; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.ChatCompletionFunctionTool; @@ -119,6 +121,11 @@ public OpenAICompletionsConnection( this.client = builder.build(); } + @Override + protected boolean supportsNativeStructuredOutput() { + return true; + } + @Override public ChatMessage chat( List messages, List tools, Map modelParams) { @@ -150,7 +157,9 @@ public ChatMessage chat( } } - private ChatCompletionCreateParams buildRequest( + // Package-private so the request body (including the native response_format) can be asserted + // without issuing a live API call through the final OpenAI client. + ChatCompletionCreateParams buildRequest( List messages, List tools, Map rawModelParams) { Map modelParams = rawModelParams != null ? new HashMap<>(rawModelParams) : new HashMap<>(); @@ -161,15 +170,25 @@ private ChatCompletionCreateParams buildRequest( modelName = this.defaultModel; } + // Always pop the reserved schema so it never leaks into the SDK request body. + Object outputSchema = popStructuredOutputSchema(modelParams); + ChatCompletionCreateParams.Builder builder = ChatCompletionCreateParams.builder() .model(ChatModel.of(modelName)) .messages(OpenAIChatCompletionsUtils.convertToOpenAIMessages(messages)); - if (tools != null && !tools.isEmpty()) { + boolean hasTools = tools != null && !tools.isEmpty(); + if (hasTools) { builder.tools(convertTools(tools, strictMode)); } + // Native structured output applies only for a POJO Class schema when no tools are bound; + // a RowTypeInfo (wrapped in OutputSchema) keeps the prompt-engineering fallback. + if (outputSchema instanceof Class && !hasTools) { + builder.responseFormat(toNativeResponseFormat((Class) outputSchema)); + } + Object temperature = modelParams.remove("temperature"); if (temperature instanceof Number) { builder.temperature(((Number) temperature).doubleValue()); @@ -208,6 +227,26 @@ private ChatCompletionCreateParams buildRequest( return builder.build(); } + // Derives the strict json_schema response format from a POJO class via the SDK's typed + // structured-output builder. The Kotlin-facade StructuredOutputsKt.responseFormatFromClass is + // not callable from Java, so the response format is extracted through the typed builder, which + // generates the same strict draft-2020-12 schema, and then reattached to the standard builder. + private static ResponseFormatJsonSchema toNativeResponseFormat(Class schemaClass) { + return ChatCompletionCreateParams.builder() + .model(ChatModel.of("")) + .addUserMessage("") + .responseFormat(schemaClass, JsonSchemaLocalValidation.NO) + .build() + .rawParams() + .responseFormat() + .orElseThrow( + () -> + new IllegalStateException( + "OpenAI SDK did not produce a response_format for schema " + + schemaClass.getName())) + .asJsonSchema(); + } + private List convertTools(List tools, boolean strictMode) { List openaiTools = new ArrayList<>(tools.size()); for (Tool tool : tools) { diff --git a/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnectionTest.java b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnectionTest.java new file mode 100644 index 000000000..767e36713 --- /dev/null +++ b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnectionTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.openai; + +import com.openai.models.ResponseFormatJsonSchema; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link OpenAICompletionsConnection}'s native structured-output behavior. These + * assert the built request body without a live API call by inspecting {@code buildRequest}. + */ +class OpenAICompletionsConnectionTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + /** A representative POJO output schema. */ + public static class Person { + public String name; + public int age; + } + + private static OpenAICompletionsConnection connection() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OpenAICompletionsConnection.class.getName()) + .addInitialArgument("api_key", "test-key") + .addInitialArgument("model", "gpt-4o") + .build(); + return new OpenAICompletionsConnection(desc, NOOP); + } + + private static Map paramsWithSchema(Object schema) { + Map params = new HashMap<>(); + params.put("model", "gpt-4o"); + if (schema != null) { + params.put(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY, schema); + } + return params; + } + + private static List userMessage() { + return List.of(new ChatMessage(MessageRole.USER, "hi")); + } + + @Test + @DisplayName("Native response_format json_schema strict applied for a POJO and no tools") + void testNativeAppliedForPojoNoTools() { + ChatCompletionCreateParams params = + connection().buildRequest(userMessage(), List.of(), paramsWithSchema(Person.class)); + + assertThat(params.responseFormat()).isPresent(); + ResponseFormatJsonSchema jsonSchema = params.responseFormat().get().asJsonSchema(); + assertThat(jsonSchema.jsonSchema().strict()).contains(true); + } + + /** Minimal tool stub; only its presence in the list matters for the empty-tools gate. */ + private static class StubTool extends Tool { + StubTool() { + super(new ToolMetadata("add", "adds", "{\"type\":\"object\"}")); + } + + @Override + public ToolType getToolType() { + return ToolType.FUNCTION; + } + + @Override + public ToolResponse call(ToolParameters parameters) { + return ToolResponse.success(null); + } + } + + @Test + @DisplayName("Native NOT applied when tools are bound (empty-tools gate)") + void testNativeNotAppliedWithTools() { + ChatCompletionCreateParams params = + connection() + .buildRequest( + userMessage(), + List.of(new StubTool()), + paramsWithSchema(Person.class)); + + assertThat(params.responseFormat()).isEmpty(); + } + + @Test + @DisplayName("Native NOT applied for a non-POJO schema form (BaseModel/POJO-only scope)") + void testNativeNotAppliedForNonClassSchema() { + // A RowTypeInfo schema arrives wrapped (not a bare POJO Class), so it must not activate + // native structured output; any non-Class schema object exercises the same gate. + Object nonClassSchema = "row"; + + ChatCompletionCreateParams params = + connection() + .buildRequest(userMessage(), List.of(), paramsWithSchema(nonClassSchema)); + + assertThat(params.responseFormat()).isEmpty(); + } + + @Test + @DisplayName( + "Reserved schema key is consumed as response_format, not passed through as a body property") + void testReservedKeyConsumedNotLeaked() { + // The reserved key is consumed by the native path into response_format rather than left + // in the modelParams to leak. The pop-helper's remove-and-return contract (which makes + // this possible) is exercised directly in BaseChatModelTest; this case pins that for the + // OpenAI connection the reserved key drives response_format and is absent from the body. + ChatCompletionCreateParams params = + connection().buildRequest(userMessage(), List.of(), paramsWithSchema(Person.class)); + + assertThat(params.responseFormat()).isPresent(); + assertThat(params._additionalBodyProperties()) + .doesNotContainKey(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY); + } + + @Test + @DisplayName("OpenAI completions declares native structured-output support") + void testDeclaresNativeCapability() { + assertThat(connection().supportsNativeStructuredOutput()).isTrue(); + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index df28c1d41..972472dcb 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -25,6 +25,7 @@ import org.apache.flink.agents.api.agents.OutputSchema; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; import org.apache.flink.agents.api.context.DurableCallable; @@ -348,6 +349,18 @@ public static void chat( int actualRetryCount = 0; int totalWaitTimeSec = 0; + // Thread the output schema to the connection via a reserved modelParams key so a + // native-capable connection can apply the provider's structured-output API. The + // connection pops the key before its SDK call (see BaseChatModelConnection). Only + // thread it for a same-language (Java) setup: native structured output cannot work + // across the Pemja bridge because a Java schema object is not consumable by a Python + // connection, so a Python-backed setup keeps the prior empty-map behavior. + final Map modelParams = + outputSchema != null && !(chatModel instanceof PythonChatModelSetup) + ? Collections.singletonMap( + BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY, outputSchema) + : Map.of(); + DurableCallable callable = new DurableCallable<>() { @Override @@ -362,7 +375,7 @@ public Class getResultClass() { @Override public ChatMessage call() throws Exception { - return chatModel.chat(messages, promptArgs, Map.of()); + return chatModel.chat(messages, promptArgs, modelParams); } }; diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java index 85c263a66..8cea46a34 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java @@ -17,20 +17,36 @@ */ package org.apache.flink.agents.plan.actions; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; +import org.apache.flink.agents.api.configuration.ReadableConfiguration; +import org.apache.flink.agents.api.context.DurableCallable; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.resource.ResourceType; import org.junit.jupiter.api.Test; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** Tests for {@link ChatModelAction}. */ class ChatModelActionTest { @@ -152,4 +168,66 @@ void testCleanLlmResponseWithMultipleLinesInBlock() { String expected = "{\n \"key\": \"value\"\n}"; assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); } + + /** A representative POJO output schema. */ + public static class Person { + public String name; + } + + /** + * Invokes {@link ChatModelAction#chat} once with an output schema and captures the modelParams + * the resolved setup's {@code chat} receives. The setup throws so the IGNORE strategy returns + * before any downstream serialization runs. + */ + @SuppressWarnings("unchecked") + private static Map captureModelParams(BaseChatModelSetup chatModel) + throws Exception { + AtomicReference> captured = new AtomicReference<>(); + when(chatModel.chat(any(), any(), any())) + .thenAnswer( + inv -> { + captured.set(inv.getArgument(2)); + throw new RuntimeException("stop after capture"); + }); + + ReadableConfiguration config = mock(ReadableConfiguration.class); + when(config.get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY)) + .thenReturn(Agent.ErrorHandlingStrategy.IGNORE); + when(config.get(AgentExecutionOptions.CHAT_ASYNC)).thenReturn(false); + + RunnerContext ctx = mock(RunnerContext.class); + when(ctx.getConfig()).thenReturn(config); + when(ctx.getResource("model", ResourceType.CHAT_MODEL)).thenReturn(chatModel); + when(ctx.durableExecute(any())) + .thenAnswer(inv -> ((DurableCallable) inv.getArgument(0)).call()); + + ChatModelAction.chat( + UUID.randomUUID(), + "model", + List.of(new ChatMessage(MessageRole.USER, "hi")), + Map.of(), + Person.class, + ctx); + return captured.get(); + } + + @Test + void testOutputSchemaThreadedForSameLanguageSetup() throws Exception { + Map modelParams = captureModelParams(mock(BaseChatModelSetup.class)); + + assertNotNull(modelParams); + assertTrue( + modelParams.containsKey(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY), + "A Java setup must receive the reserved output-schema key"); + } + + @Test + void testOutputSchemaNotThreadedForPythonSetup() throws Exception { + Map modelParams = captureModelParams(mock(PythonChatModelSetup.class)); + + assertNotNull(modelParams); + assertFalse( + modelParams.containsKey(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY), + "A Python-backed setup must not receive a Java schema across the bridge"); + } } diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index ac4a814aa..7babe774e 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -32,6 +32,13 @@ from flink_agents.api.skills import BASH_TOOL, LOAD_SKILL_TOOL from flink_agents.api.tools.tool import Tool +STRUCTURED_OUTPUT_SCHEMA_KEY = "__structured_output_schema__" +"""Reserved ``kwargs`` key carrying the raw output schema (an ``OutputSchema`` +wrapping a ``BaseModel`` subclass or ``RowTypeInfo``) down to the connection so it +can apply the provider's native structured-output mechanism. The key is +intra-language only and must be removed before the provider SDK call so it never +reaches the request.""" + class BaseChatModelConnection(Resource, ABC): """Base abstract class for chat model connection. @@ -54,6 +61,32 @@ def resource_type(cls) -> ResourceType: """Return resource type of class.""" return ResourceType.CHAT_MODEL_CONNECTION + supports_native_structured_output: ClassVar[bool] = False + """Whether this connection applies the provider's native structured-output API + when an output schema is supplied. Connections that translate a schema into a + native provider parameter override this to ``True``; the default keeps non-native + connections on the prompt-engineering fallback.""" + + @staticmethod + def _pop_structured_output_schema(kwargs: Dict[str, Any]) -> Any: + """Remove and return the reserved structured-output schema from ``kwargs``. + + Every connection must call this so the reserved key never leaks into the + provider SDK request; native connections additionally use the returned value + to build the native parameter. + + Parameters + ---------- + kwargs : Dict[str, Any] + The mutable keyword arguments passed to ``chat``. + + Returns: + ------- + Any + The raw output schema if present, otherwise ``None``. + """ + return kwargs.pop(STRUCTURED_OUTPUT_SCHEMA_KEY, None) + DEFAULT_REASONING_PATTERNS: ClassVar[Tuple[re.Pattern[str], ...]] = ( re.compile(r"(.*?)", re.DOTALL | re.IGNORECASE), re.compile(r"(.*?)", re.DOTALL | re.IGNORECASE), diff --git a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py index c077c6c8e..478d07d45 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py @@ -171,6 +171,10 @@ def chat( **kwargs: Any, ) -> ChatMessage: """Direct communication with Anthropic model service for chat conversation.""" + # Strip the reserved schema so it never leaks into the SDK request; this + # connection does not apply native structured output, so it is discarded. + self._pop_structured_output_schema(kwargs) + anthropic_tools = None if tools is not None: anthropic_tools = [ diff --git a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py index 18a092128..d97c38a76 100644 --- a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py @@ -136,6 +136,10 @@ def chat( ChatMessage Model response message """ + # Strip the reserved schema so it never leaks into the SDK request; this + # connection does not apply native structured output, so it is discarded. + self._pop_structured_output_schema(kwargs) + tool_specs = None if tools is not None: tool_specs = [to_openai_tool(metadata=tool.metadata) for tool in tools] diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py b/python/flink_agents/integrations/chat_models/ollama_chat_model.py index 7c36ec38e..2101755c3 100644 --- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py @@ -88,6 +88,10 @@ def chat( **kwargs: Any, ) -> ChatMessage: """Process a sequence of messages, and return a response.""" + # Strip the reserved schema so it never leaks into the SDK request; this + # connection does not apply native structured output, so it is discarded. + self._pop_structured_output_schema(kwargs) + ollama_messages = self.__convert_to_ollama_messages(messages) # Convert tool format diff --git a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py index 2e5fe720a..4371b3704 100644 --- a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py @@ -15,13 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Dict, List, Literal, Sequence +from typing import Any, ClassVar, Dict, List, Literal, Sequence import httpx from openai import NOT_GIVEN, OpenAI -from pydantic import Field, PrivateAttr + +# Private SDK module (leading underscore): the openai client itself uses this helper to +# build the strict json_schema for response_format, and there is no public re-export. It +# has existed at this path since the structured-output support in openai 1.66.3 (the +# pinned minimum). A future openai bump that moves it will fail loudly on import here. +from openai.lib._pydantic import to_strict_json_schema +from pydantic import BaseModel, Field, PrivateAttr from typing_extensions import override +from flink_agents.api.agents.types import OutputSchema from flink_agents.api.chat_message import ChatMessage from flink_agents.api.chat_models.chat_model import ( BaseChatModelConnection, @@ -38,6 +45,30 @@ DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo" +def _native_response_format( + output_schema: Any, tool_specs: List[Dict[str, Any]] | None +) -> Dict[str, Any] | None: + """Build the OpenAI ``response_format`` for a native structured-output request. + + Returns ``None`` (leaving behavior unchanged) unless an output schema is present, + no tools are bound, and the schema is a ``BaseModel`` subclass. A ``RowTypeInfo`` + schema is skipped so it keeps the prompt-engineering fallback. + """ + if output_schema is None or tool_specs: + return None + model = output_schema.output_schema if isinstance(output_schema, OutputSchema) else None + if not (isinstance(model, type) and issubclass(model, BaseModel)): + return None + return { + "type": "json_schema", + "json_schema": { + "name": model.__name__, + "schema": to_strict_json_schema(model), + "strict": True, + }, + } + + class OpenAIChatModelConnection(BaseChatModelConnection): """The connection to the OpenAI LLM. @@ -57,6 +88,8 @@ class OpenAIChatModelConnection(BaseChatModelConnection): Whether to reuse the OpenAI client between requests. """ + supports_native_structured_output: ClassVar[bool] = True + api_key: str = Field(default=None, description="The OpenAI API key.") api_base_url: str = Field(description="The base URL for OpenAI API.") max_retries: int = Field( @@ -157,6 +190,9 @@ def chat( ChatMessage Model response message """ + # Always pop the reserved schema so it never leaks into the SDK request. + output_schema = self._pop_structured_output_schema(kwargs) + tool_specs = None if tools is not None: tool_specs = [to_openai_tool(metadata=tool.metadata) for tool in tools] @@ -166,6 +202,12 @@ def chat( tool_spec["function"]["strict"] = strict tool_spec["function"]["parameters"]["additionalProperties"] = False + # Native structured output applies only for a BaseModel schema when no tools + # are bound; a RowTypeInfo schema keeps the prompt-engineering fallback. + response_format = _native_response_format(output_schema, tool_specs) + if response_format is not None: + kwargs["response_format"] = response_format + response = self.client.chat.completions.create( messages=convert_to_openai_messages(messages), tools=tool_specs or NOT_GIVEN, diff --git a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_native_structured_output.py b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_native_structured_output.py new file mode 100644 index 000000000..4b358c795 --- /dev/null +++ b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_native_structured_output.py @@ -0,0 +1,127 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import Any +from unittest.mock import MagicMock + +from pydantic import BaseModel +from pyflink.common.typeinfo import Types + +from flink_agents.api.agents.types import OutputSchema +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import STRUCTURED_OUTPUT_SCHEMA_KEY +from flink_agents.integrations.chat_models.openai.openai_chat_model import ( + OpenAIChatModelConnection, +) +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool + + +class Person(BaseModel): + """A representative BaseModel output schema.""" + + name: str + age: int + + +def _connection() -> OpenAIChatModelConnection: + conn = OpenAIChatModelConnection( + name="openai", api_key="test-key", api_base_url="http://localhost" + ) + mock_client = MagicMock() + mock_message = MagicMock() + mock_message.role = "assistant" + mock_message.content = "ok" + mock_message.tool_calls = None + mock_client.chat.completions.create.return_value.choices = [ + MagicMock(message=mock_message) + ] + mock_client.chat.completions.create.return_value.usage = None + conn._client = mock_client + return conn + + +def _create_call_kwargs(conn: OpenAIChatModelConnection) -> dict[str, Any]: + return conn.client.chat.completions.create.call_args.kwargs + + +def _add(a: int, b: int) -> int: + """Add two integers. + + Parameters + ---------- + a : int + first + b : int + second + + Returns: + ------- + int + sum + """ + return a + b + + +def test_native_applied_for_basemodel_no_tools() -> None: + """Native response_format json_schema strict applied for a BaseModel and no tools.""" + conn = _connection() + conn.chat( + [ChatMessage(role=MessageRole.USER, content="hi")], + model="gpt-4o", + **{STRUCTURED_OUTPUT_SCHEMA_KEY: OutputSchema(output_schema=Person)}, + ) + response_format = _create_call_kwargs(conn)["response_format"] + assert response_format["type"] == "json_schema" + assert response_format["json_schema"]["strict"] is True + assert response_format["json_schema"]["schema"]["additionalProperties"] is False + + +def test_native_not_applied_with_tools() -> None: + """Native NOT applied when tools are bound (empty-tools gate).""" + conn = _connection() + tool = FunctionTool(func=PythonFunction.from_callable(_add)) + conn.chat( + [ChatMessage(role=MessageRole.USER, content="hi")], + tools=[tool], + model="gpt-4o", + **{STRUCTURED_OUTPUT_SCHEMA_KEY: OutputSchema(output_schema=Person)}, + ) + assert "response_format" not in _create_call_kwargs(conn) + + +def test_native_not_applied_for_row_type_info() -> None: + """Native NOT applied for a RowTypeInfo schema (BaseModel/POJO-only scope).""" + conn = _connection() + row_type = Types.ROW_NAMED(["name"], [Types.STRING()]) + conn.chat( + [ChatMessage(role=MessageRole.USER, content="hi")], + model="gpt-4o", + **{STRUCTURED_OUTPUT_SCHEMA_KEY: OutputSchema(output_schema=row_type)}, + ) + assert "response_format" not in _create_call_kwargs(conn) + + +def test_reserved_key_never_leaks_to_sdk() -> None: + """The reserved schema key is never forwarded to the SDK call.""" + conn = _connection() + conn.chat( + [ChatMessage(role=MessageRole.USER, content="hi")], + model="gpt-4o", + **{STRUCTURED_OUTPUT_SCHEMA_KEY: OutputSchema(output_schema=Person)}, + ) + assert STRUCTURED_OUTPUT_SCHEMA_KEY not in _create_call_kwargs(conn) diff --git a/python/flink_agents/integrations/chat_models/tests/test_reserved_key_no_leak.py b/python/flink_agents/integrations/chat_models/tests/test_reserved_key_no_leak.py new file mode 100644 index 000000000..b449e3397 --- /dev/null +++ b/python/flink_agents/integrations/chat_models/tests/test_reserved_key_no_leak.py @@ -0,0 +1,62 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from unittest.mock import MagicMock, PropertyMock, patch + +from pydantic import BaseModel + +from flink_agents.api.agents.types import OutputSchema +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import STRUCTURED_OUTPUT_SCHEMA_KEY +from flink_agents.integrations.chat_models.ollama_chat_model import ( + OllamaChatModelConnection, +) + + +class Person(BaseModel): + """A representative BaseModel output schema.""" + + name: str + + +def test_non_native_connection_does_not_forward_reserved_key() -> None: + """A non-native connection given a schema must not pass the reserved key to its SDK.""" + conn = OllamaChatModelConnection(name="ollama") + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.message.role = "assistant" + mock_response.message.content = "ok" + mock_response.message.tool_calls = None + mock_response.prompt_eval_count = None + mock_response.eval_count = None + mock_client.chat.return_value = mock_response + + with patch.object( + OllamaChatModelConnection, "client", new_callable=PropertyMock + ) as client_prop: + client_prop.return_value = mock_client + conn.chat( + [ChatMessage(role=MessageRole.USER, content="hi")], + model="qwen3:1.7b", + **{STRUCTURED_OUTPUT_SCHEMA_KEY: OutputSchema(output_schema=Person)}, + ) + + call_kwargs = mock_client.chat.call_args.kwargs + # The reserved key must not appear anywhere the SDK receives it. + assert STRUCTURED_OUTPUT_SCHEMA_KEY not in call_kwargs + assert STRUCTURED_OUTPUT_SCHEMA_KEY not in call_kwargs.get("options", {}) diff --git a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py index 6587a8cba..bd5b13339 100644 --- a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py @@ -104,6 +104,10 @@ def chat( **kwargs: Any, ) -> ChatMessage: """Process a sequence of messages, and return a response.""" + # Strip the reserved schema so it never leaks into the SDK request; this + # connection does not apply native structured output, so it is discarded. + self._pop_structured_output_schema(kwargs) + tongyi_messages = self.__convert_to_tongyi_messages(messages) tongyi_tools: List[Dict[str, Any]] | None = ( diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index e4572056e..82138d4b3 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -19,7 +19,7 @@ import logging import re import time -from typing import TYPE_CHECKING, Dict, List, cast +from typing import TYPE_CHECKING, Any, Dict, List, cast from uuid import UUID from pydantic import BaseModel @@ -29,6 +29,7 @@ from flink_agents.api.agents.agent import STRUCTURED_OUTPUT from flink_agents.api.agents.react_agent import OutputSchema from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import STRUCTURED_OUTPUT_SCHEMA_KEY from flink_agents.api.chat_models.java_chat_model import JavaChatModelSetup from flink_agents.api.core_options import ( AgentExecutionOptions, @@ -314,15 +315,25 @@ async def chat( actual_retry_count = 0 total_wait_time_sec = 0 + # Thread the output schema to the connection via a reserved kwarg so a + # native-capable connection can apply the provider's structured-output API. The + # connection pops the key before its SDK call (see BaseChatModelConnection). Only + # thread it for a same-language (Python) setup: native structured output cannot work + # across the Pemja bridge because a Python schema object is not consumable by a Java + # connection, so a Java-backed setup keeps the prior behavior (no reserved kwarg). + chat_kwargs: Dict[str, Any] = {"prompt_args": prompt_args} + if output_schema is not None and not isinstance(chat_model, JavaChatModelSetup): + chat_kwargs[STRUCTURED_OUTPUT_SCHEMA_KEY] = output_schema + for attempt in range(num_retries + 1): try: if chat_async: response = await ctx.durable_execute_async( - chat_model.chat, messages, prompt_args=prompt_args + chat_model.chat, messages, **chat_kwargs ) else: response = ctx.durable_execute( - chat_model.chat, messages, prompt_args=prompt_args + chat_model.chat, messages, **chat_kwargs ) if ( diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py index e9e799da3..c8f9e3245 100644 --- a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py +++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py @@ -24,8 +24,12 @@ from uuid import uuid4 import pytest +from pydantic import BaseModel +from flink_agents.api.agents.react_agent import OutputSchema from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import STRUCTURED_OUTPUT_SCHEMA_KEY +from flink_agents.api.chat_models.java_chat_model import JavaChatModelSetup from flink_agents.api.core_options import ( AgentExecutionOptions, ErrorHandlingStrategy, @@ -344,3 +348,65 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: assert captured_prompt_args[0] == saved_prompt_args assert len(sent_events) == 1 assert isinstance(sent_events[0], ChatResponseEvent) + + +class TestOutputSchemaThreadingIsSameLanguageOnly: + """The reserved output-schema kwarg is threaded only when the resolved setup + is a same-language (Python) chat model. A Java-backed setup receives no + reserved kwarg, because a Python schema object cannot drive native structured + output across the Pemja bridge. + """ + + def _run_chat_and_capture_kwargs(self, chat_model: Any) -> dict: + captured: dict = {} + + def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + captured.update(kwargs) + # Stop after capturing; IGNORE strategy makes chat() return before any + # downstream structured-output parsing runs. + err_msg = "stop after capture" + raise RuntimeError(err_msg) + + chat_model.chat = mock_chat + + config = MagicMock() + option_values = { + id(AgentExecutionOptions.ERROR_HANDLING_STRATEGY): ErrorHandlingStrategy.IGNORE, + id(AgentExecutionOptions.CHAT_ASYNC): False, + } + config.get = MagicMock( + side_effect=lambda option: option_values.get( + id(option), option.get_default_value() + ) + ) + ctx = MagicMock() + ctx.config = config + ctx.get_resource = MagicMock(return_value=chat_model) + ctx.durable_execute = MagicMock( + side_effect=lambda fn, *args, **kwargs: fn(*args, **kwargs) + ) + + class _Result(BaseModel): + result: int + + asyncio.run( + chat( + uuid4(), + "test-model", + [ChatMessage(role=MessageRole.USER, content="hi")], + {}, + OutputSchema(output_schema=_Result), + ctx, + ) + ) + return captured + + def test_python_setup_receives_reserved_kwarg(self) -> None: + captured = self._run_chat_and_capture_kwargs(MagicMock()) + assert STRUCTURED_OUTPUT_SCHEMA_KEY in captured + + def test_java_setup_does_not_receive_reserved_kwarg(self) -> None: + captured = self._run_chat_and_capture_kwargs( + MagicMock(spec=JavaChatModelSetup) + ) + assert STRUCTURED_OUTPUT_SCHEMA_KEY not in captured