Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,16 @@ public class OpenAIMultiModalTool {
/** Base URL for OpenAI API (defaults to https://api.openai.com). */
private final String baseUrl;

/** Default vision model used when the caller does not specify one. */
private final String defaultModelName;

/**
* Create a new OpenAIMultiModalTool with default base URL.
*
* @param apiKey the OpenAI API key
*/
public OpenAIMultiModalTool(String apiKey) {
    // Only one explicit constructor invocation is allowed, and it must be the first
    // statement. Delegate directly to the three-argument constructor: no custom base URL
    // (null) and "gpt-4o" as the default vision model.
    this(apiKey, null, "gpt-4o");
}

/**
Expand All @@ -83,11 +86,27 @@ public OpenAIMultiModalTool(String apiKey) {
* @param baseUrl the base URL (null for default https://api.openai.com)
*/
public OpenAIMultiModalTool(String apiKey, String baseUrl) {
// Delegates to the three-argument constructor, passing "gpt-4o" as the default vision model.
this(apiKey, baseUrl, "gpt-4o");
}

/**
* Create a new OpenAIMultiModalTool with custom base URL and default vision model.
*
* @param apiKey the OpenAI API key
* @param baseUrl the base URL (null for default https://api.openai.com)
* @param defaultModelName the default vision model name used when the caller omits the model
* parameter (e.g., "gpt-4o" for OpenAI, or your backend's vision model name)
*/
public OpenAIMultiModalTool(String apiKey, String baseUrl, String defaultModelName) {
if (apiKey == null || apiKey.trim().isEmpty()) {
throw new IllegalArgumentException("OpenAI API key cannot be empty.");
}
if (defaultModelName == null || defaultModelName.trim().isEmpty()) {
throw new IllegalArgumentException("defaultModelName cannot be empty.");
}
this.apiKey = apiKey;
this.baseUrl = baseUrl;
this.defaultModelName = defaultModelName;
this.client = new OpenAIClient();
}

Expand All @@ -97,8 +116,22 @@ public OpenAIMultiModalTool(String apiKey, String baseUrl) {
* @param client the OpenAI client
*/
protected OpenAIMultiModalTool(OpenAIClient client) {
// Delegates to the (client, defaultModelName) constructor with "gpt-4o" as the default model.
this(client, "gpt-4o");
}

/**
* Create a new OpenAIMultiModalTool with custom client and default model (for testing).
*
* @param client the OpenAI client
* @param defaultModelName the default vision model name
*/
protected OpenAIMultiModalTool(OpenAIClient client, String defaultModelName) {
if (defaultModelName == null || defaultModelName.trim().isEmpty()) {
throw new IllegalArgumentException("defaultModelName cannot be empty.");
}
this.apiKey = "test-key";
this.baseUrl = null;
this.defaultModelName = defaultModelName;
this.client = client;
}

Expand Down Expand Up @@ -249,7 +282,7 @@ public Mono<ToolResultBlock> openaiTextToImage(
*
* @param imageUrls the URLs of the images to analyze
* @param prompt the text prompt describing what to extract from the images
* @param model the vision model to use, e.g., "gpt-4o" (leave empty to use the configured default)
* @param maxTokens the maximum number of tokens in the response
* @return a ToolResultBlock containing the text description of the images
*/
Expand All @@ -272,8 +305,8 @@ public Mono<ToolResultBlock> openaiImageToText(
@ToolParam(
name = "model",
description =
"The vision model to use, e.g., 'gpt-4o',"
+ " 'gpt-4-vision-preview'",
"The vision model to use (leave empty to use the configured"
+ " default)",
required = false)
String model,
@ToolParam(
Expand All @@ -283,7 +316,9 @@ public Mono<ToolResultBlock> openaiImageToText(
Integer maxTokens) {

String finalModel =
Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("gpt-4o");
Optional.ofNullable(model)
.filter(s -> !s.trim().isEmpty())
.orElse(this.defaultModelName);
String finalPrompt =
Optional.ofNullable(prompt)
.filter(s -> !s.trim().isEmpty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -65,4 +67,41 @@ void testTextToImage_Url() {
assertTrue(image.getSource() instanceof URLSource);
assertEquals("https://example.com/cat.png", ((URLSource) image.getSource()).getUrl());
}

@Test
void testImageToText_usesCustomDefaultModel() {
    String imageUrl = "https://example.com/image.png";
    String jsonResponse =
            "{\"choices\": [{\"message\": {\"content\": \"A cat sitting on a mat.\"}}]}";

    // True only when the request body is a map whose "model" entry is the custom default.
    // The instanceof guard avoids a ClassCastException inside the matcher for other payloads.
    java.util.function.Predicate<Object> sendsCustomModel =
            req ->
                    req instanceof java.util.Map
                            && "my-custom-vision-model"
                                    .equals(((java.util.Map<?, ?>) req).get("model"));

    when(client.callApi(
                    any(),
                    any(),
                    eq("/v1/chat/completions"),
                    argThat(req -> sendsCustomModel.test(req))))
            .thenReturn(jsonResponse);

    OpenAIMultiModalTool toolWithCustomModel =
            new OpenAIMultiModalTool(client, "my-custom-vision-model");

    // The model argument is null, so the tool must fall back to the configured default.
    Mono<ToolResultBlock> resultMono =
            toolWithCustomModel.openaiImageToText(imageUrl, null, null, null);
    ToolResultBlock result = resultMono.block();

    assertNotNull(result);

    // A non-null result alone does not show that the stub matched, since an unmatched stub
    // returns null. Verify that a request carrying the custom model was actually sent.
    org.mockito.Mockito.verify(client, org.mockito.Mockito.atLeastOnce())
            .callApi(
                    any(),
                    any(),
                    eq("/v1/chat/completions"),
                    argThat(req -> sendsCustomModel.test(req)));
}

@Test
void testConstructor_rejectsBlankDefaultModelName() {
    // The constructor rejects a default model name that is null or empty after trimming,
    // so cover the empty, whitespace-only and null cases.
    assertThrows(
            IllegalArgumentException.class, () -> new OpenAIMultiModalTool("key", null, ""));
    assertThrows(
            IllegalArgumentException.class,
            () -> new OpenAIMultiModalTool("key", null, "   "));
    assertThrows(
            IllegalArgumentException.class, () -> new OpenAIMultiModalTool("key", null, null));
}
}
Loading