diff --git a/1/MultiModalStreamCallTests.java b/1/MultiModalStreamCallTests.java new file mode 100644 index 0000000..80249d7 --- /dev/null +++ b/1/MultiModalStreamCallTests.java @@ -0,0 +1,245 @@ +package com.example.linkcheck4j; + +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; +import com.alibaba.dashscope.common.MultiModalMessage; +import com.alibaba.dashscope.common.Role; +import com.alibaba.dashscope.exception.ApiException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import com.alibaba.dashscope.exception.UploadFileException; +import io.reactivex.Flowable; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@SpringBootTest +class MultiModalStreamCallTests { + + private static final Logger logger = LoggerFactory.getLogger(MultiModalStreamCallTests.class); + + @Value("${spring.ai.dashscope.api-key}") + private String apiKey; + + private static final String IMAGE_URL = + "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg"; + + // ==================== 测试用例 ==================== + + /** + * 场景1: qwen3-vl-plus 开启思考 + 增量输出(true) + * 预期: 正常增量输出,每次回调只返回新增片段 + */ + @Test + void qwen3VlPlus_thinking_incrementalOutputTrue_shouldSucceed() + throws ApiException, NoApiKeyException, UploadFileException { + StreamResult result = doStreamCall("qwen3-vl-plus", true, true, + "qwen3-vl-plus | 开思考 | 增量输出(true)"); + assertSuccessfulOutput(result); + assertIncrementalPattern(result, "qwen3-vl-plus 开思考增量输出"); + } + + /** + * 场景2: qwen3-vl-plus 开启思考 + 全量输出(false)" + * 实测: qwen3 系列开思考时不支持全量,实际仍按增量返回 + */ + @Test + void qwen3VlPlus_thinking_incrementalOutputFalse_shouldFallbackToIncremental() + throws ApiException, NoApiKeyException, UploadFileException { + StreamResult result = doStreamCall("qwen3-vl-plus", true, false, + "qwen3-vl-plus | 开思考 | 全量输出(false) -> 实际仍为增量"); + assertSuccessfulOutput(result); + } + + /** + * 场景3: qwen3-vl-plus 不开思考 + 增量输出(true) + * 预期: 正常增量输出 + */ + @Test + void qwen3VlPlus_noThinking_incrementalOutputTrue_shouldSucceed() + throws ApiException, NoApiKeyException, UploadFileException { + StreamResult result = doStreamCall("qwen3-vl-plus", false, true, + "qwen3-vl-plus | 不开思考 | 增量输出(true)"); + assertSuccessfulOutput(result); + assertIncrementalPattern(result, "qwen3-vl-plus 不开思考增量输出"); + } + + /** + * 场景4: qwen3-vl-plus 不开思考 + 全量输出(false)" + * 实测: 不开启思考时设置全量输出不会生效 + */ + @Test + void qwen3VlPlus_noThinking_incrementalOutputFalse_fullOutputNotEffective() + throws ApiException, NoApiKeyException, UploadFileException { + StreamResult result = doStreamCall("qwen3-vl-plus", false, false, + "qwen3-vl-plus | 不开思考 | 全量输出(false) -> 不生效"); + assertSuccessfulOutput(result); + } + + /** + * 场景5: qwen-vl-plus 增量输出(true) + * 预期: 正常增量输出 + */ + @Test + void qwenVlPlus_incrementalOutputTrue_shouldSucceed() + throws ApiException, NoApiKeyException, UploadFileException { + StreamResult result = doStreamCall("qwen-vl-plus", false, true, + "qwen-vl-plus | 增量输出(true)"); + assertSuccessfulOutput(result); + assertIncrementalPattern(result, "qwen-vl-plus 增量输出"); + } + + /** + * 场景6: qwen-vl-plus 全量输出(false)" + * 实测: 可以成功全量输出,每次回调返回累积完整内容 + */ + @Test + void qwenVlPlus_incrementalOutputFalse_shouldBeFullOutput() + throws ApiException, NoApiKeyException, UploadFileException { + StreamResult result = doStreamCall("qwen-vl-plus", false, false, + "qwen-vl-plus | 全量输出(false) -> 正常"); + assertSuccessfulOutput(result); + assertFullOutputPattern(result, "qwen-vl-plus 全量输出"); + } + + // ==================== 公共方法 ==================== + + /** 构建多模态用户消息(图片 + 文本) */ + private MultiModalMessage buildUserMessage(String text) { + return MultiModalMessage.builder() + .role(Role.USER.getValue()) + .content(Arrays.asList( + Collections.singletonMap("image", IMAGE_URL), + Collections.singletonMap("text", text))) + .build(); + } + + /** + * 执行流式调用并打印结果 + * + * @param model 模型名称 + * @param enableThinking 是否开启思考 + * @param incrementalOutput true=增量输出 false=全量输出 + * @param label 测试场景标签(用于打印区分) + * @return 流式调用结果封装 + */ + private StreamResult doStreamCall(String model, boolean enableThinking, boolean incrementalOutput, String label) + throws ApiException, NoApiKeyException, UploadFileException { + System.out.println("\n=================================================="); + System.out.println("[" + label + "]"); + System.out.println(" model: " + model); + System.out.println(" enableThinking: " + enableThinking); + System.out.println(" incrementalOutput: " + incrementalOutput); + System.out.println("--------------------------------------------------"); + + MultiModalConversation conv = new MultiModalConversation(); + MultiModalMessage userMessage = buildUserMessage("图中描绘的是什么景象?"); + + var builder = MultiModalConversationParam.builder() + .apiKey(apiKey) + .model(model) + .messages(Arrays.asList(userMessage)) + .incrementalOutput(incrementalOutput); + if (enableThinking) { + builder.enableThinking(true); + } + MultiModalConversationParam param = builder.build(); + + List fragments = new ArrayList<>(); + AtomicInteger printedLength = new AtomicInteger(0); + Flowable result = conv.streamCall(param); + result.blockingForEach(item -> { + try { + var content = item.getOutput().getChoices().get(0).getMessage().getContent(); + if (content != null && !content.isEmpty()) { + Object textObj = content.get(0).get("text"); + String text = textObj == null ? "" : textObj.toString(); + if (!text.isEmpty()) { + fragments.add(text); + // 打印原始片段,用于直观判断 SDK 实际返回的是增量还是全量 + String preview = text.length() > 80 + ? text.substring(0, 80).replace("\n", "\\n") + "..." + : text.replace("\n", "\\n"); + System.out.println("[原始片段 " + fragments.size() + "] len=" + text.length() + " -> " + preview); + if (incrementalOutput) { + // 增量模式:直接打印每次返回的片段 + System.out.print(text); + } else { + // 全量模式:只打印相比上次新增的部分 + int prev = printedLength.get(); + if (text.length() > prev) { + System.out.print(text.substring(prev)); + printedLength.set(text.length()); + } + } + } + } + } catch (Exception e) { + logger.warn("Parse item failed: {}", e.getMessage()); + } + }); + System.out.println("\n--------------------------------------------------"); + System.out.println("[汇总] 场景: " + label); + System.out.println(" 原始片段总数: " + fragments.size()); + if (!fragments.isEmpty()) { + System.out.println(" 第一个片段长度: " + fragments.get(0).length()); + System.out.println(" 最后一个片段长度: " + fragments.get(fragments.size() - 1).length()); + // 判断长度是否单调不减(全量特征)或普遍很小(增量特征) + boolean monotonicNonDecreasing = true; + int increases = 0; + for (int i = 1; i < fragments.size(); i++) { + if (fragments.get(i).length() < fragments.get(i - 1).length()) { + monotonicNonDecreasing = false; + } else if (fragments.get(i).length() > fragments.get(i - 1).length()) { + increases++; + } + } + System.out.println(" 长度单调不减: " + monotonicNonDecreasing + " (增长次数=" + increases + ")"); + System.out.println(" 判定: " + (monotonicNonDecreasing && increases > fragments.size() / 2 ? "全量返回" : "增量返回")); + } + System.out.println("==================================================\n"); + + String completeText = incrementalOutput + ? String.join("", fragments) + : (fragments.isEmpty() ? "" : fragments.get(fragments.size() - 1)); + return new StreamResult(fragments, completeText, !fragments.isEmpty(), incrementalOutput); + } + + private void assertSuccessfulOutput(StreamResult result) { + assertTrue(result.hasContent(), "流式调用应返回非空内容"); + assertFalse(result.completeText().isBlank(), "最终输出文本不应为空"); + } + + private void assertIncrementalPattern(StreamResult result, String message) { + assertTrue(result.incrementalOutput(), "该断言仅适用于 incrementalOutput=true 的场景"); + String concatenated = String.join("", result.fragments()); + assertEquals(result.completeText(), concatenated, + message + ":增量片段拼接后应等于完整回复"); + } + + private void assertFullOutputPattern(StreamResult result, String message) { + assertFalse(result.incrementalOutput(), "该断言仅适用于 incrementalOutput=false 的场景"); + List fragments = result.fragments(); + for (int i = 1; i < fragments.size(); i++) { + assertTrue(fragments.get(i).length() >= fragments.get(i - 1).length(), + message + ":全量输出片段长度应单调不减"); + } + assertEquals(fragments.get(fragments.size() - 1), result.completeText(), + message + ":最后一个片段应等于完整回复"); + } + + private record StreamResult(List fragments, String completeText, boolean hasContent, boolean incrementalOutput) { + } +} diff --git a/1/TestQwen3VlPlusFullOutput.class b/1/TestQwen3VlPlusFullOutput.class new file mode 100644 index 0000000..e2e80ea Binary files /dev/null and b/1/TestQwen3VlPlusFullOutput.class differ diff --git a/1/TestQwen3VlPlusFullOutput.java b/1/TestQwen3VlPlusFullOutput.java new file mode 100644 index 0000000..806a6e8 --- /dev/null +++ b/1/TestQwen3VlPlusFullOutput.java @@ -0,0 +1,82 @@ +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; +import com.alibaba.dashscope.common.MultiModalMessage; +import com.alibaba.dashscope.common.Role; +import io.reactivex.Flowable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.ArrayList; + +public class TestQwen3VlPlusFullOutput { + + private static final String IMAGE_URL = + "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg"; + + public static void main(String[] args) throws Exception { + String apiKey = args.length > 0 ? args[0] : System.getenv("DASHSCOPE_API_KEY"); + if (apiKey == null || apiKey.isEmpty()) { + System.err.println("Usage: java TestQwen3VlPlusFullOutput "); + System.err.println("Or set DASHSCOPE_API_KEY environment variable."); + System.exit(1); + } + + MultiModalConversation conv = new MultiModalConversation(); + MultiModalMessage userMessage = MultiModalMessage.builder() + .role(Role.USER.getValue()) + .content(Arrays.asList( + Collections.singletonMap("image", IMAGE_URL), + Collections.singletonMap("text", "图中描绘的是什么景象?"))) + .build(); + + MultiModalConversationParam param = MultiModalConversationParam.builder() + .apiKey(apiKey) + //.model("qwen3-vl-plus") + .model("qwen-vl-max") + .messages(Arrays.asList(userMessage)) + //.incrementalOutput(false) + .build(); + + System.out.println("=== qwen3-vl-plus | incrementalOutput=false | streamCall ===\n"); + + List fragments = new ArrayList<>(); + Flowable result = conv.streamCall(param); + result.blockingForEach(item -> { + try { + List> content = item.getOutput().getChoices().get(0).getMessage().getContent(); + if (content != null && !content.isEmpty()) { + Object textObj = content.get(0).get("text"); + String text = textObj == null ? "" : textObj.toString(); + if (!text.isEmpty()) { + fragments.add(text); + int preview = Math.min(text.length(), 80); + System.out.printf("[chunk %d] len=%d -> %s%n", + fragments.size(), text.length(), + text.substring(0, preview).replace("\n", "\\n")); + } + } + } catch (Exception e) { + System.err.println("Parse error: " + e.getMessage()); + } + }); + + System.out.println("\n--- Summary ---"); + System.out.println("Total chunks: " + fragments.size()); + + if (!fragments.isEmpty()) { + boolean monotonic = true; + for (int i = 1; i < fragments.size(); i++) { + if (fragments.get(i).length() < fragments.get(i - 1).length()) { + monotonic = false; + break; + } + } + System.out.println("Lengths monotonically non-decreasing: " + monotonic); + System.out.println("Output type: " + (monotonic ? "FULL (expected)" : "INCREMENTAL (unexpected)")); + System.out.println("\nFinal text:\n" + fragments.get(fragments.size() - 1)); + } + } +} diff --git a/pom.xml b/pom.xml index 55d9df8..14f0b2a 100644 --- a/pom.xml +++ b/pom.xml @@ -40,7 +40,7 @@ DashScope Java SDK com.alibaba dashscope-sdk-java - 2.22.23 + 2.22.24 8 diff --git a/src/main/java/com/alibaba/dashscope/aigc/conversation/ConversationResult.java b/src/main/java/com/alibaba/dashscope/aigc/conversation/ConversationResult.java index eaafca4..f517363 100644 --- a/src/main/java/com/alibaba/dashscope/aigc/conversation/ConversationResult.java +++ b/src/main/java/com/alibaba/dashscope/aigc/conversation/ConversationResult.java @@ -22,11 +22,16 @@ public static ConversationResult fromDashScopeResult(DashScopeResult dashScopeRe ConversationResult result = new ConversationResult(); result.setRequestId(dashScopeResult.getRequestId()); result.setHeaders(dashScopeResult.getHeaders()); - result.setUsage( - JsonUtils.fromJsonObject( - dashScopeResult.getUsage().getAsJsonObject(), GenerationUsage.class)); - result.setOutput( - JsonUtils.fromJsonObject((JsonObject) dashScopeResult.getOutput(), GenerationOutput.class)); + if (dashScopeResult.getUsage() != null) { + result.setUsage( + JsonUtils.fromJsonObject( + dashScopeResult.getUsage().getAsJsonObject(), GenerationUsage.class)); + } + if (dashScopeResult.getOutput() instanceof JsonObject) { + result.setOutput( + JsonUtils.fromJsonObject( + (JsonObject) dashScopeResult.getOutput(), GenerationOutput.class)); + } return result; } } diff --git a/src/main/java/com/alibaba/dashscope/aigc/generation/GenerationResult.java b/src/main/java/com/alibaba/dashscope/aigc/generation/GenerationResult.java index cc270db..918b8f0 100644 --- a/src/main/java/com/alibaba/dashscope/aigc/generation/GenerationResult.java +++ b/src/main/java/com/alibaba/dashscope/aigc/generation/GenerationResult.java @@ -38,12 +38,14 @@ public static GenerationResult fromDashScopeResult(DashScopeResult dashScopeResu JsonUtils.fromJsonObject( dashScopeResult.getUsage().getAsJsonObject(), GenerationUsage.class)); } - if (dashScopeResult.getOutput() != null) { + if (dashScopeResult.getOutput() instanceof JsonObject) { result.setOutput( JsonUtils.fromJsonObject( (JsonObject) dashScopeResult.getOutput(), GenerationOutput.class)); } else { - log.error(StringUtils.format("Result no output: %s", dashScopeResult)); + log.error( + StringUtils.format( + "Result no output or output is not a JsonObject: %s", dashScopeResult)); } return result; } diff --git a/src/main/java/com/alibaba/dashscope/aigc/imagegeneration/ImageGenerationResult.java b/src/main/java/com/alibaba/dashscope/aigc/imagegeneration/ImageGenerationResult.java index 69093b0..7a90e34 100644 --- a/src/main/java/com/alibaba/dashscope/aigc/imagegeneration/ImageGenerationResult.java +++ b/src/main/java/com/alibaba/dashscope/aigc/imagegeneration/ImageGenerationResult.java @@ -36,12 +36,12 @@ public static ImageGenerationResult fromDashScopeResult(DashScopeResult dashScop JsonUtils.fromJsonObject( dashScopeResult.getUsage().getAsJsonObject(), ImageGenerationUsage.class)); } - if (dashScopeResult.getOutput() != null) { + if (dashScopeResult.getOutput() instanceof JsonObject) { result.setOutput( JsonUtils.fromJsonObject( (JsonObject) dashScopeResult.getOutput(), ImageGenerationOutput.class)); } else { - log.error("Result no output: {}", dashScopeResult); + log.error("Result no output or output is not a JsonObject: {}", dashScopeResult); } return result; } diff --git a/src/main/java/com/alibaba/dashscope/aigc/imagesynthesis/ImageSynthesisResult.java b/src/main/java/com/alibaba/dashscope/aigc/imagesynthesis/ImageSynthesisResult.java index 74a8a4f..76f8229 100644 --- a/src/main/java/com/alibaba/dashscope/aigc/imagesynthesis/ImageSynthesisResult.java +++ b/src/main/java/com/alibaba/dashscope/aigc/imagesynthesis/ImageSynthesisResult.java @@ -41,12 +41,14 @@ public static ImageSynthesisResult fromDashScopeResult(DashScopeResult dashScope JsonUtils.fromJsonObject( dashScopeResult.getUsage().getAsJsonObject(), ImageSynthesisUsage.class)); } - if (dashScopeResult.getOutput() != null) { + if (dashScopeResult.getOutput() instanceof JsonObject) { result.setOutput( JsonUtils.fromJsonObject( (JsonObject) dashScopeResult.getOutput(), ImageSynthesisOutput.class)); } else { - log.error(StringUtils.format("Result no output: %s", dashScopeResult)); + log.error( + StringUtils.format( + "Result no output or output is not a JsonObject: %s", dashScopeResult)); } return result; } diff --git a/src/main/java/com/alibaba/dashscope/aigc/multimodalconversation/MultiModalConversationResult.java b/src/main/java/com/alibaba/dashscope/aigc/multimodalconversation/MultiModalConversationResult.java index 309ddc0..9492535 100644 --- a/src/main/java/com/alibaba/dashscope/aigc/multimodalconversation/MultiModalConversationResult.java +++ b/src/main/java/com/alibaba/dashscope/aigc/multimodalconversation/MultiModalConversationResult.java @@ -36,12 +36,12 @@ public static MultiModalConversationResult fromDashScopeResult(DashScopeResult d JsonUtils.fromJsonObject( dashScopeResult.getUsage().getAsJsonObject(), MultiModalConversationUsage.class)); } - if (dashScopeResult.getOutput() != null) { + if (dashScopeResult.getOutput() instanceof JsonObject) { result.setOutput( JsonUtils.fromJsonObject( (JsonObject) dashScopeResult.getOutput(), MultiModalConversationOutput.class)); } else { - log.error("Result no output: {}", dashScopeResult); + log.error("Result no output or output is not a JsonObject: {}", dashScopeResult); } return result; } diff --git a/src/main/java/com/alibaba/dashscope/app/ApplicationResult.java b/src/main/java/com/alibaba/dashscope/app/ApplicationResult.java index 1d5260d..4d58eea 100644 --- a/src/main/java/com/alibaba/dashscope/app/ApplicationResult.java +++ b/src/main/java/com/alibaba/dashscope/app/ApplicationResult.java @@ -56,12 +56,14 @@ public static ApplicationResult fromDashScopeResult(DashScopeResult dashScopeRes JsonUtils.fromJsonObject( dashScopeResult.getUsage().getAsJsonObject(), ApplicationUsage.class)); } - if (dashScopeResult.getOutput() != null) { + if (dashScopeResult.getOutput() instanceof JsonObject) { result.setOutput( JsonUtils.fromJsonObject( (JsonObject) dashScopeResult.getOutput(), ApplicationOutput.class)); } else { - log.error(StringUtils.format("Result no output: %s", dashScopeResult)); + log.error( + StringUtils.format( + "Result no output or output is not a JsonObject: %s", dashScopeResult)); } return result; diff --git a/src/main/java/com/alibaba/dashscope/audio/omni/OmniRealtimeConversation.java b/src/main/java/com/alibaba/dashscope/audio/omni/OmniRealtimeConversation.java index 2b2dcfd..27e7f3f 100644 --- a/src/main/java/com/alibaba/dashscope/audio/omni/OmniRealtimeConversation.java +++ b/src/main/java/com/alibaba/dashscope/audio/omni/OmniRealtimeConversation.java @@ -39,6 +39,19 @@ public class OmniRealtimeConversation extends WebSocketListener { private long lastFirstAudioDelay = -1; private long lastFirstTextDelay = -1; private AtomicBoolean isClosed = new AtomicBoolean(false); + + /** Immutable holder for WebSocket close code and reason, updated atomically. */ + private static class CloseInfo { + final int code; + final String reason; + + CloseInfo(int code, String reason) { + this.code = code; + this.reason = reason; + } + } + + private final AtomicReference closeInfo = new AtomicReference<>(null); private final AtomicReference disconnectLatch = new AtomicReference<>(null); /** @@ -55,13 +68,26 @@ public OmniRealtimeConversation(OmniRealtimeParam param, OmniRealtimeCallback ca /** Omni APIs */ public void checkStatus() { if (this.isClosed.get()) { - throw new RuntimeException("conversation is already closed!"); + String msg = "conversation is already closed!"; + CloseInfo ci = closeInfo.get(); + if (ci != null && ci.code >= 0) { + msg = msg + " (code=" + ci.code + ", reason=" + ci.reason + ")"; + } + throw new RuntimeException(msg); + } + if (!this.isOpen.get()) { + throw new RuntimeException("conversation is not connected!"); } } /** Connect to server, create session and return default session configuration */ public void connect() throws NoApiKeyException, InterruptedException { - checkStatus(); + if (isClosed.get()) { + throw new RuntimeException("conversation is already closed!"); + } + if (isOpen.get()) { + throw new RuntimeException("conversation is already connected!"); + } Request request = buildConnectionRequest( ApiKey.getApiKey(parameters.getApikey()), @@ -250,9 +276,7 @@ public void cancelResponse() { /** close the connection to server */ public void close() { - checkStatus(); - websocktetClient.close(1000, "bye"); - isClosed.set(true); + close(1000, "bye"); } /** @@ -262,9 +286,13 @@ public void close() { * @param reason websocket close reason */ public void close(int code, String reason) { - checkStatus(); - websocktetClient.close(code, reason); - isClosed.set(true); + if (!isClosed.compareAndSet(false, true)) { + return; + } + if (websocktetClient != null) { + websocktetClient.close(code, reason); + } + isOpen.set(false); } /** @@ -340,6 +368,9 @@ private void sendMessage(String message, boolean enableLog) { log.debug("send message: " + message); } Boolean isOk = websocktetClient.send(message); + if (!isOk) { + throw new RuntimeException("failed to send message"); + } } private void sendMessage(ByteString message) { @@ -410,21 +441,35 @@ public void onMessage(WebSocket webSocket, String text) { @Override public void onClosed(WebSocket webSocket, int code, String reason) { isOpen.set(false); + isClosed.set(true); + closeInfo.set(new CloseInfo(code, reason)); connectLatch.get().countDown(); + CountDownLatch latch = disconnectLatch.get(); + if (latch != null) { + latch.countDown(); + } log.debug("WebSocket closed: " + code + ", " + reason); callback.onClose(code, reason); } @Override public void onFailure(WebSocket webSocket, Throwable t, Response response) { + closeInfo.set(new CloseInfo(-1, "failure: " + t.getMessage())); + isClosed.set(true); + isOpen.set(false); connectLatch.get().countDown(); + CountDownLatch latch = disconnectLatch.get(); + if (latch != null) { + latch.countDown(); + } log.error("WebSocket failed: " + t.getMessage()); + callback.onClose(-1, "failure: " + t.getMessage()); } @Override public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) { - isClosed.set(true); - websocktetClient.close(code, reason); + closeInfo.set(new CloseInfo(code, reason)); + close(code, reason); log.debug("WebSocket closing: " + code + ", " + reason); } } diff --git a/src/main/java/com/alibaba/dashscope/common/DashScopeResult.java b/src/main/java/com/alibaba/dashscope/common/DashScopeResult.java index ad71ae6..0cfe7cd 100644 --- a/src/main/java/com/alibaba/dashscope/common/DashScopeResult.java +++ b/src/main/java/com/alibaba/dashscope/common/DashScopeResult.java @@ -9,6 +9,7 @@ import com.alibaba.dashscope.utils.ApiKeywords; import com.alibaba.dashscope.utils.EncryptionUtils; import com.alibaba.dashscope.utils.JsonUtils; +import com.google.gson.JsonElement; import com.google.gson.JsonObject; import java.nio.ByteBuffer; import java.util.List; @@ -16,7 +17,9 @@ import java.util.stream.Collectors; import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.extern.slf4j.Slf4j; +@Slf4j @Data @EqualsAndHashCode(callSuper = true) public class DashScopeResult extends Result { @@ -27,6 +30,27 @@ public Boolean isBinaryOutput() { return output instanceof ByteBuffer; } + /** + * Parse the output field from a JsonElement in a type-safe manner. + * + *

Returns {@code null} for JsonNull, {@code JsonObject} for object elements, and the raw + * {@code JsonElement} for primitives/arrays to avoid {@code getAsJsonObject()} throwing. + * + * @param outputElement the JSON element representing the output field + * @return parsed output value, or {@code null} if the element is JSON null + */ + private Object parseOutputField(JsonElement outputElement) { + if (outputElement.isJsonNull()) { + return null; + } else if (outputElement.isJsonObject()) { + return outputElement.getAsJsonObject(); + } else { + // JsonPrimitive or JsonArray — return the raw element so downstream code + // can handle it gracefully instead of throwing getAsJsonObject(). + return outputElement; + } + } + @Override @SuppressWarnings("unchecked") protected T fromResponse(Protocol protocol, NetworkResponse response) @@ -72,10 +96,7 @@ protected T fromResponse(Protocol protocol, NetworkResponse r if (jsonObject.has(ApiKeywords.PAYLOAD)) { JsonObject payload = jsonObject.getAsJsonObject(ApiKeywords.PAYLOAD); if (payload.has(ApiKeywords.OUTPUT)) { - this.output = - payload.get(ApiKeywords.OUTPUT).isJsonNull() - ? null - : payload.get(ApiKeywords.OUTPUT); + this.output = parseOutputField(payload.get(ApiKeywords.OUTPUT)); } if (payload.has(ApiKeywords.USAGE)) { this.setUsage( @@ -94,10 +115,7 @@ protected T fromResponse(Protocol protocol, NetworkResponse r this.setStatusCode(response.getHttpStatusCode()); } if (jsonObject.has(ApiKeywords.OUTPUT)) { - this.output = - jsonObject.get(ApiKeywords.OUTPUT).isJsonNull() - ? null - : jsonObject.get(ApiKeywords.OUTPUT).getAsJsonObject(); + this.output = parseOutputField(jsonObject.get(ApiKeywords.OUTPUT)); } if (jsonObject.has(ApiKeywords.USAGE)) { this.setUsage( @@ -161,6 +179,7 @@ public T fromResponse( } } else { // HTTP JsonObject jsonObject = JsonUtils.parse(response.getMessage()); + // Preserve original behavior: the entire JSON object is the output for flatten mode. this.output = jsonObject; this.event = response.getEvent(); } @@ -183,17 +202,19 @@ public T fromResponse( this.setStatusCode(response.getHttpStatusCode()); } JsonObject jsonObject = JsonUtils.parse(response.getMessage()); - String encryptedOutput = - jsonObject.get(ApiKeywords.OUTPUT).isJsonNull() - ? null - : jsonObject.get(ApiKeywords.OUTPUT).getAsString(); - if (encryptedOutput != null) { - String plainOutput = - EncryptionUtils.AESDecrypt( - encryptedOutput, - req.getEncryptionConfig().getAESEncryptKey(), - req.getEncryptionConfig().getIv()); - this.output = JsonUtils.parse(plainOutput); + if (jsonObject.has(ApiKeywords.OUTPUT) && !jsonObject.get(ApiKeywords.OUTPUT).isJsonNull()) { + if (jsonObject.get(ApiKeywords.OUTPUT).isJsonPrimitive() + && jsonObject.get(ApiKeywords.OUTPUT).getAsJsonPrimitive().isString()) { + String encryptedOutput = jsonObject.get(ApiKeywords.OUTPUT).getAsString(); + String plainOutput = + EncryptionUtils.AESDecrypt( + encryptedOutput, + req.getEncryptionConfig().getAESEncryptKey(), + req.getEncryptionConfig().getIv()); + this.output = JsonUtils.parse(plainOutput); + } else { + this.output = parseOutputField(jsonObject.get(ApiKeywords.OUTPUT)); + } } else { this.output = null; } @@ -237,6 +258,43 @@ public T fromResponse( } return (T) this; } + + // Fallback: server encrypted output but did not set X-DashScope-OutputEncrypted header. + // Only attempt fallback decryption when encryption config is available to avoid false + // positives. + if (protocol == Protocol.HTTP && req.getEncryptionConfig() != null) { + try { + JsonObject jsonObject = JsonUtils.parse(response.getMessage()); + if (jsonObject.has(ApiKeywords.OUTPUT) + && !jsonObject.get(ApiKeywords.OUTPUT).isJsonNull() + && jsonObject.get(ApiKeywords.OUTPUT).isJsonPrimitive() + && jsonObject.get(ApiKeywords.OUTPUT).getAsJsonPrimitive().isString()) { + String encryptedOutput = jsonObject.get(ApiKeywords.OUTPUT).getAsString(); + String plainOutput = + EncryptionUtils.AESDecrypt( + encryptedOutput, + req.getEncryptionConfig().getAESEncryptKey(), + req.getEncryptionConfig().getIv()); + this.output = JsonUtils.parse(plainOutput); + if (response.getHttpStatusCode() != null) { + this.setStatusCode(response.getHttpStatusCode()); + } + if (jsonObject.has(ApiKeywords.USAGE)) { + this.setUsage( + jsonObject.get(ApiKeywords.USAGE).isJsonNull() + ? null + : jsonObject.get(ApiKeywords.USAGE).getAsJsonObject()); + } + if (jsonObject.has(ApiKeywords.REQUEST_ID)) { + this.setRequestId(jsonObject.get(ApiKeywords.REQUEST_ID).getAsString()); + } + return (T) this; + } + } catch (Exception e) { + log.debug("Fallback decryption failed, proceeding with normal parsing: {}", e.getMessage()); + } + } + return fromResponse(protocol, response, isFlattenResult); } diff --git a/src/test/java/com/alibaba/dashscope/audio/omni/TestOmniRealtimeConversation.java b/src/test/java/com/alibaba/dashscope/audio/omni/TestOmniRealtimeConversation.java new file mode 100644 index 0000000..51d80f9 --- /dev/null +++ b/src/test/java/com/alibaba/dashscope/audio/omni/TestOmniRealtimeConversation.java @@ -0,0 +1,228 @@ +// Copyright (c) Alibaba, Inc. and its affiliates. +package com.alibaba.dashscope.audio.omni; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import com.google.gson.JsonObject; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link OmniRealtimeConversation} focusing on WebSocket lifecycle: close() + * idempotency, onFailure callback notification, and onClosing → checkStatus exception information. + */ +public class TestOmniRealtimeConversation { + + private static class RecordingCallback extends OmniRealtimeCallback { + final List events = new ArrayList<>(); + volatile int closeCode = Integer.MIN_VALUE; + volatile String closeReason = null; + volatile boolean openCalled = false; + final AtomicBoolean closeCalled = new AtomicBoolean(false); + + @Override + public void onOpen() { + openCalled = true; + } + + @Override + public void onEvent(JsonObject message) { + events.add(message); + } + + @Override + public void onClose(int code, String reason) { + closeCode = code; + closeReason = reason; + closeCalled.set(true); + } + } + + private OmniRealtimeConversation createConversation(RecordingCallback callback) { + OmniRealtimeParam param = + OmniRealtimeParam.builder().model("test-model").apikey("test-key").build(); + return new OmniRealtimeConversation(param, callback); + } + + private void setConnectLatch(OmniRealtimeConversation conv, CountDownLatch latch) + throws Exception { + Field f = OmniRealtimeConversation.class.getDeclaredField("connectLatch"); + f.setAccessible(true); + f.set(conv, new AtomicReference<>(latch)); + } + + private void setIsClosed(OmniRealtimeConversation conv, boolean value) throws Exception { + Field f = OmniRealtimeConversation.class.getDeclaredField("isClosed"); + f.setAccessible(true); + ((AtomicBoolean) f.get(conv)).set(value); + } + + private void setIsOpen(OmniRealtimeConversation conv, boolean value) throws Exception { + Field f = OmniRealtimeConversation.class.getDeclaredField("isOpen"); + f.setAccessible(true); + ((AtomicBoolean) f.get(conv)).set(value); + } + + private boolean getIsClosed(OmniRealtimeConversation conv) throws Exception { + Field f = OmniRealtimeConversation.class.getDeclaredField("isClosed"); + f.setAccessible(true); + return ((AtomicBoolean) f.get(conv)).get(); + } + + private boolean getIsOpen(OmniRealtimeConversation conv) throws Exception { + Field f = OmniRealtimeConversation.class.getDeclaredField("isOpen"); + f.setAccessible(true); + return ((AtomicBoolean) f.get(conv)).get(); + } + + @Test + public void testCloseIdempotent() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + + conv.close(1000, "bye"); + assertTrue(getIsClosed(conv)); + assertFalse(getIsOpen(conv)); + + // Second close — should be a no-op + conv.close(1001, "second"); + assertTrue(getIsClosed(conv)); + assertFalse(getIsOpen(conv)); + } + + @Test + public void testOnFailureCallsCallback() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setConnectLatch(conv, new CountDownLatch(1)); + + Throwable testError = new RuntimeException("connection reset"); + conv.onFailure(null, testError, null); + + assertTrue(callback.closeCalled.get()); + assertEquals(-1, callback.closeCode); + assertEquals("failure: connection reset", callback.closeReason); + assertTrue(getIsClosed(conv)); + assertFalse(getIsOpen(conv)); + } + + @Test + public void testOnClosingCheckStatusThrowsWithInfo() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setConnectLatch(conv, new CountDownLatch(1)); + setIsOpen(conv, true); + + conv.onClosing(null, 1011, "server error"); + + try { + conv.checkStatus(); + fail("checkStatus should throw RuntimeException after onClosing"); + } catch (RuntimeException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("already closed")); + assertTrue(msg.contains("1011")); + assertTrue(msg.contains("server error")); + } + } + + @Test + public void testCheckStatusNotConnected() { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + + try { + conv.checkStatus(); + fail("checkStatus should throw when not connected"); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("not connected")); + } + } + + @Test + public void testCheckStatusClosedNoInfo() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setIsClosed(conv, true); + + try { + conv.checkStatus(); + fail("checkStatus should throw when closed"); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("already closed")); + assertFalse(e.getMessage().contains("code=")); + } + } + + @Test + public void testConnectThrowsWhenClosed() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setIsClosed(conv, true); + + try { + conv.connect(); + fail("connect() should throw when already closed"); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("already closed")); + } catch (Exception e) { + fail("Expected RuntimeException, got: " + e.getClass().getName()); + } + } + + @Test + public void testConnectThrowsWhenAlreadyOpen() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setIsOpen(conv, true); + + try { + conv.connect(); + fail("connect() should throw when already connected"); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("already connected")); + } catch (Exception e) { + fail("Expected RuntimeException, got: " + e.getClass().getName()); + } + } + + @Test + public void testOnClosedCallsCallback() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setConnectLatch(conv, new CountDownLatch(1)); + setIsOpen(conv, true); + + conv.onClosed(null, 1000, "normal closure"); + + assertTrue(callback.closeCalled.get()); + assertEquals(1000, callback.closeCode); + assertEquals("normal closure", callback.closeReason); + assertTrue(getIsClosed(conv)); + assertFalse(getIsOpen(conv)); + } + + @Test + public void testOnFailureReleasesDisconnectLatch() throws Exception { + RecordingCallback callback = new RecordingCallback(); + OmniRealtimeConversation conv = createConversation(callback); + setConnectLatch(conv, new CountDownLatch(1)); + + CountDownLatch disconnectLatch = new CountDownLatch(1); + Field f = OmniRealtimeConversation.class.getDeclaredField("disconnectLatch"); + f.setAccessible(true); + f.set(conv, new AtomicReference<>(disconnectLatch)); + + conv.onFailure(null, new RuntimeException("test failure"), null); + + assertEquals(0, disconnectLatch.getCount()); + } +} diff --git a/src/test/java/com/alibaba/dashscope/common/TestDashScopeResult.java b/src/test/java/com/alibaba/dashscope/common/TestDashScopeResult.java new file mode 100644 index 0000000..26ae3b2 --- /dev/null +++ b/src/test/java/com/alibaba/dashscope/common/TestDashScopeResult.java @@ -0,0 +1,320 @@ +// Copyright (c) Alibaba, Inc. and its affiliates. +package com.alibaba.dashscope.common; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.alibaba.dashscope.base.HalfDuplexParamBase; +import com.alibaba.dashscope.exception.ApiException; +import com.alibaba.dashscope.protocol.HalfDuplexRequest; +import com.alibaba.dashscope.protocol.HttpMethod; +import com.alibaba.dashscope.protocol.NetworkResponse; +import com.alibaba.dashscope.protocol.Protocol; +import com.alibaba.dashscope.protocol.ServiceOption; +import com.alibaba.dashscope.protocol.StreamingMode; +import com.alibaba.dashscope.utils.EncryptionConfig; +import com.alibaba.dashscope.utils.EncryptionUtils; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.crypto.SecretKey; +import lombok.experimental.SuperBuilder; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link DashScopeResult} focusing on output field parsing type-safety, flatten mode + * behavior, and encryption fallback decryption. + */ +public class TestDashScopeResult { + + private NetworkResponse buildHttpResponse(String body) { + return NetworkResponse.builder() + .message(body) + .headers(new HashMap<>()) + .httpStatusCode(200) + .build(); + } + + private NetworkResponse buildHttpResponse(String body, Map> headers) { + return NetworkResponse.builder().message(body).headers(headers).httpStatusCode(200).build(); + } + + @Test + public void testOutputAsJsonObject() throws ApiException { + String json = "{\"output\":{\"text\":\"hello\"},\"request_id\":\"req-1\"}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + JsonObject output = (JsonObject) result.getOutput(); + assertEquals("hello", output.get("text").getAsString()); + assertEquals("req-1", result.getRequestId()); + } + + @Test + public void testOutputAsJsonPrimitiveNoThrow() throws ApiException { + String json = "{\"output\":\"some-plain-string\",\"request_id\":\"req-2\"}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonElement); + assertEquals("some-plain-string", ((JsonElement) result.getOutput()).getAsString()); + } + + @Test + public void testOutputAsJsonNull() throws ApiException { + String json = "{\"output\":null,\"request_id\":\"req-3\"}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp); + + assertNull(result.getOutput()); + } + + @Test + public void testOutputFieldAbsent() throws ApiException { + String json = "{\"request_id\":\"req-4\",\"code\":\"0\",\"message\":\"ok\"}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp); + + assertNull(result.getOutput()); + assertEquals("req-4", result.getRequestId()); + assertEquals("0", result.getCode()); + } + + @Test + public void testIsFlattenHttpReturnsEntireJson() throws ApiException { + String json = + "{\"output\":{\"text\":\"hello\"},\"request_id\":\"req-5\",\"usage\":{\"total\":10}}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp, true); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + JsonObject output = (JsonObject) result.getOutput(); + assertTrue(output.has("output")); + assertTrue(output.has("request_id")); + assertTrue(output.has("usage")); + assertEquals("hello", output.getAsJsonObject("output").get("text").getAsString()); + } + + @Test + public void testIsFlattenWebSocketReturnsEntireJson() throws ApiException { + String json = + "{\"header\":{\"task_id\":\"task-1\"},\"payload\":{\"output\":{\"text\":\"hi\"}}}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.WEBSOCKET, resp, true); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + JsonObject output = (JsonObject) result.getOutput(); + assertTrue(output.has("header")); + assertTrue(output.has("payload")); + } + + @Test + public void testWebSocketNonFlattenOutput() throws ApiException { + String json = + "{\"header\":{\"task_id\":\"task-2\",\"status_code\":200}," + + "\"payload\":{\"output\":{\"text\":\"ws-hello\"}}}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.WEBSOCKET, resp); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + assertEquals("ws-hello", ((JsonObject) result.getOutput()).get("text").getAsString()); + assertEquals("task-2", result.getRequestId()); + assertEquals(Integer.valueOf(200), result.getStatusCode()); + } + + @Test + public void testOutputWithDataField() throws ApiException { + String json = "{\"data\":{\"key\":\"val\"},\"request_id\":\"req-8\"}"; + NetworkResponse resp = buildHttpResponse(json); + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + JsonObject output = (JsonObject) result.getOutput(); + assertTrue(output.has("data")); + assertFalse(output.has("request_id")); + } + + @Test + public void testEncryptionFallbackDecryption() throws Exception { + SecretKey aesKey = EncryptionUtils.generateAESKey(); + byte[] iv = new byte[12]; + new java.security.SecureRandom().nextBytes(iv); + + String plainOutput = "{\"text\":\"decrypted-content\"}"; + String encryptedOutput = EncryptionUtils.AESEncrypt(plainOutput, aesKey, iv); + + String json = "{\"output\":\"" + encryptedOutput + "\",\"request_id\":\"req-9\"}"; + NetworkResponse resp = buildHttpResponse(json); + + HalfDuplexRequest req = buildTestHalfDuplexRequest(false, aesKey, iv); + + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp, false, req); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + assertEquals("decrypted-content", ((JsonObject) result.getOutput()).get("text").getAsString()); + assertEquals("req-9", result.getRequestId()); + } + + @Test + public void testNoFallbackWhenConfigNull() throws Exception { + String base64LikeString = + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; + String json = "{\"output\":\"" + base64LikeString + "\",\"request_id\":\"req-10\"}"; + NetworkResponse resp = buildHttpResponse(json); + + HalfDuplexRequest req = buildTestHalfDuplexRequest(false, null, null); + + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp, false, req); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonElement); + assertEquals(base64LikeString, ((JsonElement) result.getOutput()).getAsString()); + } + + @Test + public void testEncryptionWithHeader() throws Exception { + SecretKey aesKey = EncryptionUtils.generateAESKey(); + byte[] iv = new byte[12]; + new java.security.SecureRandom().nextBytes(iv); + + String plainOutput = "{\"text\":\"header-decrypted\"}"; + String encryptedOutput = EncryptionUtils.AESEncrypt(plainOutput, aesKey, iv); + + String json = "{\"output\":\"" + encryptedOutput + "\",\"request_id\":\"req-11\"}"; + + Map> headers = new HashMap<>(); + headers.put("x-dashscope-outputencrypted", Arrays.asList("true")); + NetworkResponse resp = buildHttpResponse(json, headers); + + HalfDuplexRequest req = buildTestHalfDuplexRequest(true, aesKey, iv); + + DashScopeResult result = new DashScopeResult(); + result.fromResponse(Protocol.HTTP, resp, false, req); + + assertNotNull(result.getOutput()); + assertTrue(result.getOutput() instanceof JsonObject); + assertEquals("header-decrypted", ((JsonObject) result.getOutput()).get("text").getAsString()); + } + + // ---- Helpers ---- + + @SuperBuilder + private static class TestParamBase extends HalfDuplexParamBase { + @Override + public String getModel() { + return "test-model"; + } + + @Override + public Map getParameters() { + return new HashMap<>(); + } + + @Override + public Map getHeaders() { + return new HashMap<>(); + } + + @Override + public JsonObject getHttpBody() { + return new JsonObject(); + } + + @Override + public Object getInput() { + return null; + } + + @Override + public Object getResources() { + return null; + } + + @Override + public ByteBuffer getBinaryData() { + return null; + } + + @Override + public void validate() {} + } + + private static class TestServiceOption implements ServiceOption { + @Override + public StreamingMode getStreamingMode() { + return null; + } + + @Override + public Protocol getProtocol() { + return Protocol.HTTP; + } + + @Override + public HttpMethod getHttpMethod() { + return HttpMethod.POST; + } + + @Override + public String httpUrl() { + return "/test"; + } + + @Override + public String getBaseHttpUrl() { + return null; + } + + @Override + public String getBaseWebSocketUrl() { + return null; + } + } + + private HalfDuplexRequest buildTestHalfDuplexRequest( + boolean enableEncrypt, SecretKey aesKey, byte[] iv) throws Exception { + TestParamBase param = TestParamBase.builder().enableEncrypt(enableEncrypt).build(); + + HalfDuplexRequest req = new HalfDuplexRequest(param, new TestServiceOption()); + + if (aesKey != null) { + EncryptionConfig config = + EncryptionConfig.builder() + .publicKeyId("test-key-id") + .base64PublicKey("test-public-key") + .AESEncryptKey(aesKey) + .iv(iv) + .build(); + Field f = HalfDuplexRequest.class.getDeclaredField("encryptionConfig"); + f.setAccessible(true); + f.set(req, config); + } + return req; + } +} diff --git a/src/test/java/com/alibaba/dashscope/common/TestResultTypeSafety.java b/src/test/java/com/alibaba/dashscope/common/TestResultTypeSafety.java new file mode 100644 index 0000000..2a48894 --- /dev/null +++ b/src/test/java/com/alibaba/dashscope/common/TestResultTypeSafety.java @@ -0,0 +1,56 @@ +// Copyright (c) Alibaba, Inc. and its affiliates. +package com.alibaba.dashscope.common; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.alibaba.dashscope.aigc.conversation.ConversationResult; +import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.utils.JsonUtils; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; +import org.junit.jupiter.api.Test; + +/** + * Tests that Result subclass {@code fromDashScopeResult} methods handle non-JsonObject output (e.g. + * JsonPrimitive from encrypted or malformed responses) gracefully via the {@code instanceof + * JsonObject} defensive check. + */ +public class TestResultTypeSafety { + + private DashScopeResult buildResultWithPrimitiveOutput() { + DashScopeResult dsr = new DashScopeResult(); + dsr.setRequestId("req-ts-1"); + dsr.setOutput(new JsonPrimitive("not-a-json-object")); + return dsr; + } + + @Test + public void testGenerationResultWithPrimitiveOutput() { + DashScopeResult dsr = buildResultWithPrimitiveOutput(); + GenerationResult result = GenerationResult.fromDashScopeResult(dsr); + assertNull(result.getOutput()); + assertEquals("req-ts-1", result.getRequestId()); + } + + @Test + public void testConversationResultWithPrimitiveOutput() { + DashScopeResult dsr = buildResultWithPrimitiveOutput(); + ConversationResult result = ConversationResult.fromDashScopeResult(dsr); + assertNull(result.getOutput()); + assertEquals("req-ts-1", result.getRequestId()); + } + + @Test + public void testGenerationResultWithJsonObjectOutput() { + DashScopeResult dsr = new DashScopeResult(); + dsr.setRequestId("req-ts-2"); + JsonObject outputJson = JsonUtils.parse("{\"choices\":[],\"text\":\"hello\"}"); + dsr.setOutput(outputJson); + + GenerationResult result = GenerationResult.fromDashScopeResult(dsr); + assertTrue(result.getOutput() != null); + assertEquals("req-ts-2", result.getRequestId()); + } +}