diff --git a/src/main/java/com/alibaba/dashscope/audio/asr/recognition/timestamp/Sentence.java b/src/main/java/com/alibaba/dashscope/audio/asr/recognition/timestamp/Sentence.java index 3a721a0..d49e795 100644 --- a/src/main/java/com/alibaba/dashscope/audio/asr/recognition/timestamp/Sentence.java +++ b/src/main/java/com/alibaba/dashscope/audio/asr/recognition/timestamp/Sentence.java @@ -56,6 +56,9 @@ public class Sentence { @SerializedName("sentence_end") boolean sentenceEnd; + @SerializedName("speaker_id") + String speakerId; + public static Sentence from(String message) { return JsonUtils.fromJson(message, Sentence.class); } diff --git a/src/main/java/com/alibaba/dashscope/multimodal/MultiModalDialog.java b/src/main/java/com/alibaba/dashscope/multimodal/MultiModalDialog.java index 64c463b..4672c26 100644 --- a/src/main/java/com/alibaba/dashscope/multimodal/MultiModalDialog.java +++ b/src/main/java/com/alibaba/dashscope/multimodal/MultiModalDialog.java @@ -21,6 +21,8 @@ import java.util.Queue; import java.util.UUID; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import lombok.Builder; import lombok.Getter; @@ -64,7 +66,9 @@ private static class AsyncCmdBuffer { // Asynchronous command buffer class private final Queue DialogBuffer = new LinkedList<>(); // Dialogue buffer queue - private AtomicReference stopLatch = + private final AtomicBoolean closed = new AtomicBoolean(false); // Idempotent close guard + + private final AtomicReference stopLatch = new AtomicReference<>(null); // Stop signal latch @SuperBuilder @@ -209,7 +213,8 @@ public void start() { }, BackpressureStrategy.BUFFER); - stopLatch = new AtomicReference<>(new CountDownLatch(1)); // Initializes stop signal latch + closed.set(false); // Reset idempotent close guard for reuse across sessions + stopLatch.set(new CountDownLatch(1)); // Initializes stop signal latch 
String preTaskId = requestParam.getTaskId() != null ? requestParam.getTaskId() : UUID.randomUUID().toString(); @@ -440,17 +445,69 @@ public void updateInfo(MultiModalRequestParam.UpdateParams updateParams) { sendTextFrame("UpdateInfo"); } - /** Stops the MultiModalDialog. */ + /** + * Stops the MultiModalDialog gracefully. Sends a finish message to the server and waits for the + * server to acknowledge. If the server does not respond within the timeout (30s), falls back to + * force close. + * + *
<p>
This method is safe to call after onError — it will not deadlock. + */ public void stop() { - sendFinishTaskMessage(); - if (stopLatch.get() != null) { + boolean sent = sendFinishTaskMessage(); + + if (!sent) { + // emitter was null — finish message not sent, no point waiting; force close directly + close(); + return; + } + + CountDownLatch latch = stopLatch.get(); + if (latch != null) { try { - stopLatch.get().await(); - } catch (InterruptedException ignored) { + // Use timeout to prevent deadlock: if server doesn't respond in 30s, force close + boolean completed = latch.await(30, TimeUnit.SECONDS); + if (!completed) { + log.warn("stop() timed out waiting for server acknowledgement, forcing close"); + close(); + } + } catch (InterruptedException e) { + close(); + Thread.currentThread().interrupt(); } } } + /** + * Force closes the dialog immediately without sending any message to the server. Safe to call + * from any callback thread (including onError/onStopped). Does not block. + * + *
<p>
This method: - Nullifies emitter to prevent further data sending - Disposes upstream + * subscription to stop sendBinaryWithRetry - Forces WebSocket cancel (non-blocking) - Sets + * isClosed to short-circuit any reconnection loops in progress - Releases stopLatch to unblock + * any thread waiting in stop() + */ + public void close() { + if (!closed.compareAndSet(false, true)) { + return; // already closed, skip duplicate work + } + // Nullify emitter to prevent further data sending + synchronized (MultiModalDialog.this) { + conversationEmitter = null; + DialogBuffer.clear(); + } + // Force cancel: disposes upstream subscription + cancels WebSocket (non-blocking) + // Also sets isClosed=true which breaks reconnection loops in + // establishWebSocketClient/sendTextWithRetry/sendBinaryWithRetry + duplexApi.cancel(); + // Release stopLatch to prevent stop() from blocking forever. + // This is critical: if close() is called from onError callback, any thread waiting + // in stop() will be unblocked immediately instead of deadlocking. + CountDownLatch latch = stopLatch.get(); + if (latch != null) { + latch.countDown(); + } + } + /** * Gets current dialogue state. * @@ -532,12 +589,14 @@ private void sendTextFrame( } } - /** Sends stop message. */ - private void sendFinishTaskMessage() { // Instruction type + /** Sends stop message. Returns true if the message was actually sent. 
*/ + private boolean sendFinishTaskMessage() { // Instruction type synchronized (MultiModalDialog.this) { // Synchronized block ensures thread safety if (conversationEmitter != null) { conversationEmitter.onComplete(); // Ends data flow + return true; } + return false; } } } diff --git a/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClient.java b/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClient.java index 157990b..c91bd36 100644 --- a/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClient.java +++ b/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClient.java @@ -17,6 +17,7 @@ import io.reactivex.Flowable; import io.reactivex.FlowableEmitter; import io.reactivex.Observable; +import io.reactivex.disposables.Disposable; import io.reactivex.functions.Action; import java.io.IOException; import java.nio.ByteBuffer; @@ -40,17 +41,20 @@ public class OkHttpWebSocketClient extends WebSocketListener private WebSocket webSocketClient; // indicate the websocket connection is established. private AtomicBoolean isOpen = new AtomicBoolean(false); - private AtomicBoolean isClosed = new AtomicBoolean(false); + protected AtomicBoolean isClosed = new AtomicBoolean(false); // indicate the first response is received. protected AtomicBoolean isFirstMessage = new AtomicBoolean(false); // used for get request response - protected FlowableEmitter responseEmitter; + protected volatile FlowableEmitter responseEmitter; // is the result is flatten format. 
private boolean isFlattenResult; private FlowableEmitter connectionEmitter; private AtomicBoolean passTaskStarted = new AtomicBoolean(false); + // Disposable for the streaming data subscription, used to cancel upstream when stopping + protected volatile Disposable streamingDataDisposable; + public OkHttpWebSocketClient(OkHttpClient client, boolean passTaskStarted) { this.client = client; this.passTaskStarted.set(passTaskStarted); @@ -97,6 +101,13 @@ public boolean close(int code, String reason) { } public void cancel() { + // Set isClosed BEFORE cancel to suppress onFailure error propagation + isClosed.set(true); + // Dispose upstream subscription to stop sending data + Disposable d = streamingDataDisposable; + if (d != null && !d.isDisposed()) { + d.dispose(); + } if (webSocketClient != null) { webSocketClient.cancel(); } @@ -111,6 +122,11 @@ private void establishWebSocketClient( int reconnectionTimes = 0; String errorMessage = ""; while (reconnectionTimes < MAX_CONNECTION_TIMES) { + // Bail out immediately if cancel() has been called + if (isClosed.get()) { + log.debug("Connection cancelled, stop reconnecting."); + return; + } try { Flowable flowable = Flowable.create( @@ -144,9 +160,17 @@ private void establishWebSocketClient( } else if (errorMessage.contains(Constants.NO_API_KEY_ERROR)) { throw ex; } + // Check again before sleeping + if (isClosed.get()) { + log.debug("Connection cancelled during retry, stop reconnecting."); + return; + } try { Thread.sleep(10000); - } catch (InterruptedException e) {; + } catch (InterruptedException e) { + // Respect interruption - exit the loop + Thread.currentThread().interrupt(); + return; } } } @@ -167,7 +191,6 @@ public void onClosed(WebSocket webSocket, int code, String reason) { log.debug( StringUtils.format("WebSocket %s closed: %d, %s", webSocket.toString(), code, reason)); isOpen.set(false); - isClosed.set(false); } @Override @@ -379,13 +402,25 @@ protected void sendTextWithRetry( String workspace, Map 
customHeaders, String baseWebSocketUrl) { + // Guard: skip if already cancelled + if (isClosed.get()) { + log.debug("sendTextWithRetry skipped: connection already closed."); + return; + } // simple retry with fixed delay, no strategy if (!isOpen.get()) { establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl); } + if (isClosed.get()) { + return; + } int maxRetries = 3; if (passTaskStarted.get()) { // when pass througn task started, no need to retry. + if (webSocketClient == null) { + log.warn("webSocketClient is null, cannot send message."); + return; + } log.info("Sending message: " + message); Boolean isOk = webSocketClient.send(message); if (!isOk) { @@ -395,6 +430,13 @@ protected void sendTextWithRetry( } int retryCount = 0; while (retryCount < maxRetries) { + if (isClosed.get()) { + return; + } + if (webSocketClient == null) { + log.warn("webSocketClient is null, cannot send message."); + return; + } log.debug("Sending message: " + message); Boolean isOk = webSocketClient.send(message); if (isOk) { @@ -418,12 +460,26 @@ protected void sendBinaryWithRetry( String workspace, Map customHeaders, String baseWebSocketUrl) { + // Guard: skip if already cancelled + if (isClosed.get()) { + return; + } if (!isOpen.get()) { establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl); } + if (isClosed.get()) { + return; + } int maxRetries = 3; int retryCount = 0; while (retryCount < maxRetries) { + if (isClosed.get()) { + return; + } + if (webSocketClient == null) { + log.warn("webSocketClient is null, cannot send binary message."); + return; + } Boolean isOk = webSocketClient.send(message); if (isOk) { break; @@ -564,92 +620,115 @@ public void run() throws Exception { }); } + /** + * Hook method called before sending the start message. Subclasses may override to add additional + * logging or pre-processing. 
+ */ + protected void onBeforeSendStartMessage(JsonObject startMessage) { + // no-op by default + } + + /** Core streaming request logic. Extracted to allow subclasses to use different executors. */ + protected void executeStreamRequest(FullDuplexRequest req) { + try { + isClosed.set(false); // Reset for reuse across sessions + isFirstMessage.set(false); + + JsonObject startMessage = req.getStartTaskMessage(); + onBeforeSendStartMessage(startMessage); + String taskId = startMessage.get("header").getAsJsonObject().get("task_id").getAsString(); + // send start message out. + sendTextWithRetry( + req.getApiKey(), + req.isSecurityCheck(), + JsonUtils.toJson(startMessage), + req.getWorkspace(), + req.getHeaders(), + req.getBaseWebSocketUrl()); + + Flowable streamingData = req.getStreamingData(); + Disposable d = + streamingData.subscribe( + data -> { + try { + if (data instanceof String) { + JsonObject continueData = req.getContinueMessage((String) data, taskId); + sendTextWithRetry( + req.getApiKey(), + req.isSecurityCheck(), + JsonUtils.toJson(continueData), + req.getWorkspace(), + req.getHeaders(), + req.getBaseWebSocketUrl()); + } else if (data instanceof byte[]) { + sendBinaryWithRetry( + req.getApiKey(), + req.isSecurityCheck(), + ByteString.of((byte[]) data), + req.getWorkspace(), + req.getHeaders(), + req.getBaseWebSocketUrl()); + } else if (data instanceof ByteBuffer) { + sendBinaryWithRetry( + req.getApiKey(), + req.isSecurityCheck(), + ByteString.of((ByteBuffer) data), + req.getWorkspace(), + req.getHeaders(), + req.getBaseWebSocketUrl()); + } else { + JsonObject continueData = req.getContinueMessage(data, taskId); + sendTextWithRetry( + req.getApiKey(), + req.isSecurityCheck(), + JsonUtils.toJson(continueData), + req.getWorkspace(), + req.getHeaders(), + req.getBaseWebSocketUrl()); + } + } catch (Throwable ex) { + log.error(StringUtils.format("sendStreamData exception: %s", ex.getMessage())); + if (responseEmitter != null && !responseEmitter.isCancelled()) { 
+ responseEmitter.onError(ex); + } + } + }, + err -> { + log.error(StringUtils.format("Get stream data error!")); + if (responseEmitter != null && !responseEmitter.isCancelled()) { + responseEmitter.onError(err); + } + }, + new Action() { + @Override + public void run() throws Exception { + log.debug(StringUtils.format("Stream data send completed!")); + sendTextWithRetry( + req.getApiKey(), + req.isSecurityCheck(), + JsonUtils.toJson(req.getFinishedTaskMessage(taskId)), + req.getWorkspace(), + req.getHeaders(), + req.getBaseWebSocketUrl()); + } + }); + // Publish the disposable, then check if cancel() raced ahead. + // If isClosed is already true, cancel() has already run and missed + // this disposable, so we must dispose it ourselves. + streamingDataDisposable = d; + if (isClosed.get()) { + d.dispose(); + } + } catch (Throwable ex) { + log.error(StringUtils.format("sendStreamData exception: %s", ex.getMessage())); + if (responseEmitter != null && !responseEmitter.isCancelled()) { + responseEmitter.onError(ex); + } + } + } + protected CompletableFuture sendStreamRequest(FullDuplexRequest req) { - CompletableFuture future = - CompletableFuture.runAsync( - () -> { - try { - isFirstMessage.set(false); - - JsonObject startMessage = req.getStartTaskMessage(); - String taskId = - startMessage.get("header").getAsJsonObject().get("task_id").getAsString(); - // send start message out. 
- sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(startMessage), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - - Flowable streamingData = req.getStreamingData(); - streamingData.subscribe( - data -> { - try { - if (data instanceof String) { - JsonObject continueData = req.getContinueMessage((String) data, taskId); - sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(continueData), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } else if (data instanceof byte[]) { - sendBinaryWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - ByteString.of((byte[]) data), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } else if (data instanceof ByteBuffer) { - sendBinaryWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - ByteString.of((ByteBuffer) data), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } else { - JsonObject continueData = req.getContinueMessage(data, taskId); - sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(continueData), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } - } catch (Throwable ex) { - log.error( - StringUtils.format("sendStreamData exception: %s", ex.getMessage())); - responseEmitter.onError(ex); - } - }, - err -> { - log.error(StringUtils.format("Get stream data error!")); - responseEmitter.onError(err); - }, - new Action() { - @Override - public void run() throws Exception { - log.debug(StringUtils.format("Stream data send completed!")); - sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(req.getFinishedTaskMessage(taskId)), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } - }); - } catch (Throwable ex) { - log.error(StringUtils.format("sendStreamData exception: %s", ex.getMessage())); - responseEmitter.onError(ex); - } - }); + CompletableFuture future 
= CompletableFuture.runAsync(() -> executeStreamRequest(req)); return future; } diff --git a/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClientForAudio.java b/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClientForAudio.java index 13019f8..a7171e9 100644 --- a/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClientForAudio.java +++ b/src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClientForAudio.java @@ -2,17 +2,12 @@ import com.alibaba.dashscope.protocol.FullDuplexRequest; import com.alibaba.dashscope.utils.JsonUtils; -import com.alibaba.dashscope.utils.StringUtils; import com.google.gson.JsonObject; -import io.reactivex.Flowable; -import io.reactivex.functions.Action; -import java.nio.ByteBuffer; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; -import okio.ByteString; /** @author songsong.shao */ @Slf4j @@ -45,101 +40,20 @@ public OkHttpWebSocketClientForAudio(OkHttpClient client, boolean passTaskStarte } @Override - protected CompletableFuture sendStreamRequest(FullDuplexRequest req) { - CompletableFuture future = - CompletableFuture.runAsync( - () -> { - try { - isFirstMessage.set(false); - - JsonObject startMessage = req.getStartTaskMessage(); - log.info("send run-task request {}", JsonUtils.toJson(startMessage)); - String taskId = - startMessage.get("header").getAsJsonObject().get("task_id").getAsString(); - // send start message out. 
- sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(startMessage), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); + protected void onBeforeSendStartMessage(JsonObject startMessage) { + log.info("send run-task request {}", JsonUtils.toJson(startMessage)); + } - Flowable streamingData = req.getStreamingData(); - streamingData.subscribe( - data -> { - try { - if (data instanceof String) { - JsonObject continueData = req.getContinueMessage((String) data, taskId); - sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(continueData), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } else if (data instanceof byte[]) { - sendBinaryWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - ByteString.of((byte[]) data), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } else if (data instanceof ByteBuffer) { - sendBinaryWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - ByteString.of((ByteBuffer) data), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } else { - JsonObject continueData = req.getContinueMessage(data, taskId); - sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(continueData), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } - } catch (Throwable ex) { - log.error( - StringUtils.format("sendStreamData exception: %s", ex.getMessage())); - responseEmitter.onError(ex); - } - }, - err -> { - log.error(StringUtils.format("Get stream data error!")); - responseEmitter.onError(err); - }, - new Action() { - @Override - public void run() throws Exception { - log.debug(StringUtils.format("Stream data send completed!")); - sendTextWithRetry( - req.getApiKey(), - req.isSecurityCheck(), - JsonUtils.toJson(req.getFinishedTaskMessage(taskId)), - req.getWorkspace(), - req.getHeaders(), - req.getBaseWebSocketUrl()); - } - }); - } catch (Throwable ex) { - 
log.error(StringUtils.format("sendStreamData exception: %s", ex.getMessage())); - responseEmitter.onError(ex); - } - }, - STREAMING_REQUEST_EXECUTOR); - return future; + @Override + protected CompletableFuture sendStreamRequest(FullDuplexRequest req) { + return CompletableFuture.runAsync(() -> executeStreamRequest(req), STREAMING_REQUEST_EXECUTOR); } static { // auto close when jvm shutdown Runtime.getRuntime() .addShutdownHook(new Thread(OkHttpWebSocketClientForAudio::shutdownStreamingExecutor)); } + /** * Shutdown the streaming request executor gracefully. This method should be called when the * application is shutting down to ensure proper resource cleanup.