Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public class Sentence {
@SerializedName("sentence_end")
boolean sentenceEnd;

@SerializedName("speaker_id")
String speakerId;

public static Sentence from(String message) {
return JsonUtils.fromJson(message, Sentence.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Builder;
import lombok.Getter;
Expand Down Expand Up @@ -64,7 +65,7 @@ private static class AsyncCmdBuffer { // Asynchronous command buffer class

private final Queue<AsyncCmdBuffer> DialogBuffer = new LinkedList<>(); // Dialogue buffer queue

private AtomicReference<CountDownLatch> stopLatch =
private final AtomicReference<CountDownLatch> stopLatch =
new AtomicReference<>(null); // Stop signal latch

@SuperBuilder
Expand Down Expand Up @@ -209,7 +210,7 @@ public void start() {
},
BackpressureStrategy.BUFFER);

stopLatch = new AtomicReference<>(new CountDownLatch(1)); // Initializes stop signal latch
stopLatch.set(new CountDownLatch(1)); // Initializes stop signal latch

String preTaskId =
requestParam.getTaskId() != null ? requestParam.getTaskId() : UUID.randomUUID().toString();
Expand Down Expand Up @@ -440,17 +441,59 @@ public void updateInfo(MultiModalRequestParam.UpdateParams updateParams) {
sendTextFrame("UpdateInfo");
}

/** Stops the MultiModalDialog. */
/**
* Stops the MultiModalDialog gracefully. Sends a finish message to the server and waits for the
* server to acknowledge. If the server does not respond within the timeout (30s), falls back to
* force close.
*
* <p>This method is safe to call after onError — it will not deadlock.
*/
public void stop() {
sendFinishTaskMessage();
if (stopLatch.get() != null) {

CountDownLatch latch = stopLatch.get();
if (latch != null) {
try {
stopLatch.get().await();
} catch (InterruptedException ignored) {
// Use timeout to prevent deadlock: if server doesn't respond in 30s, force close
boolean completed = latch.await(30, TimeUnit.SECONDS);
if (!completed) {
log.warn("stop() timed out waiting for server acknowledgement, forcing close");
close();
}
} catch (InterruptedException e) {
close();
Thread.currentThread().interrupt();
}
}
}

/**
* Force closes the dialog immediately without sending any message to the server. Safe to call
* from any callback thread (including onError/onStopped). Does not block.
*
* <p>This method: - Nullifies emitter to prevent further data sending - Disposes upstream
* subscription to stop sendBinaryWithRetry - Forces WebSocket cancel (non-blocking) - Sets
* isClosed to short-circuit any reconnection loops in progress - Releases stopLatch to unblock
* any thread waiting in stop()
*/
public void close() {
// Nullify emitter to prevent further data sending
synchronized (MultiModalDialog.this) {
conversationEmitter = null;
}
// Force cancel: disposes upstream subscription + cancels WebSocket (non-blocking)
// Also sets isClosed=true which breaks reconnection loops in
// establishWebSocketClient/sendTextWithRetry/sendBinaryWithRetry
duplexApi.cancel();
// Release stopLatch to prevent stop() from blocking forever.
// This is critical: if close() is called from onError callback, any thread waiting
// in stop() will be unblocked immediately instead of deadlocking.
CountDownLatch latch = stopLatch.get();
if (latch != null) {
latch.countDown();
}
}

/**
* Gets current dialogue state.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.reactivex.Flowable;
import io.reactivex.FlowableEmitter;
import io.reactivex.Observable;
import io.reactivex.disposables.Disposable;
import io.reactivex.functions.Action;
import java.io.IOException;
import java.nio.ByteBuffer;
Expand All @@ -40,7 +41,7 @@ public class OkHttpWebSocketClient extends WebSocketListener
private WebSocket webSocketClient;
// indicate the websocket connection is established.
private AtomicBoolean isOpen = new AtomicBoolean(false);
private AtomicBoolean isClosed = new AtomicBoolean(false);
protected AtomicBoolean isClosed = new AtomicBoolean(false);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The isClosed flag is used to prevent reconnection loops and suppress error propagation when the client is cancelled or closed. However, in onClosed(WebSocket webSocket, int code, String reason), isClosed is reset to false via isClosed.set(false);.

If cancel() is called, it sets isClosed to true and cancels the websocket, which eventually triggers onClosed (either directly or via onFailure calling close()). When onClosed runs, it resets isClosed back to false.

If the reconnection loop in establishWebSocketClient is currently sleeping or retrying, it checks isClosed.get(). Since it was reset to false, the loop will continue to retry and reconnect, defeating the purpose of cancel().

Recommendation:
Modify onClosed so that it does not unconditionally reset isClosed to false if it was explicitly set to true by a user-initiated close or cancel.

// indicate the first response is received.
protected AtomicBoolean isFirstMessage = new AtomicBoolean(false);
// used for get request response
Expand All @@ -51,6 +52,9 @@ public class OkHttpWebSocketClient extends WebSocketListener

private AtomicBoolean passTaskStarted = new AtomicBoolean(false);

// Disposable for the streaming data subscription, used to cancel upstream when stopping
protected volatile Disposable streamingDataDisposable;

public OkHttpWebSocketClient(OkHttpClient client, boolean passTaskStarted) {
this.client = client;
this.passTaskStarted.set(passTaskStarted);
Expand Down Expand Up @@ -97,6 +101,13 @@ public boolean close(int code, String reason) {
}

public void cancel() {
// Set isClosed BEFORE cancel to suppress onFailure error propagation
isClosed.set(true);
// Dispose upstream subscription to stop sending data
Disposable d = streamingDataDisposable;
if (d != null && !d.isDisposed()) {
d.dispose();
}
if (webSocketClient != null) {
webSocketClient.cancel();
}
Expand All @@ -111,6 +122,11 @@ private void establishWebSocketClient(
int reconnectionTimes = 0;
String errorMessage = "";
while (reconnectionTimes < MAX_CONNECTION_TIMES) {
// Bail out immediately if cancel() has been called
if (isClosed.get()) {
log.debug("Connection cancelled, stop reconnecting.");
return;
}
try {
Flowable<DashScopeResult> flowable =
Flowable.<DashScopeResult>create(
Expand Down Expand Up @@ -144,9 +160,17 @@ private void establishWebSocketClient(
} else if (errorMessage.contains(Constants.NO_API_KEY_ERROR)) {
throw ex;
}
// Check again before sleeping
if (isClosed.get()) {
log.debug("Connection cancelled during retry, stop reconnecting.");
return;
}
try {
Thread.sleep(10000);
} catch (InterruptedException e) {;
} catch (InterruptedException e) {
// Respect interruption - exit the loop
Thread.currentThread().interrupt();
return;
}
}
}
Expand Down Expand Up @@ -379,10 +403,18 @@ protected void sendTextWithRetry(
String workspace,
Map<String, String> customHeaders,
String baseWebSocketUrl) {
// Guard: skip if already cancelled
if (isClosed.get()) {
log.debug("sendTextWithRetry skipped: connection already closed.");
return;
}
// simple retry with fixed delay, no strategy
if (!isOpen.get()) {
establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
}
if (isClosed.get()) {
return;
}
int maxRetries = 3;
if (passTaskStarted.get()) {
// when pass througn task started, no need to retry.
Expand All @@ -395,6 +427,9 @@ protected void sendTextWithRetry(
}
int retryCount = 0;
while (retryCount < maxRetries) {
if (isClosed.get()) {
return;
}
log.debug("Sending message: " + message);
Comment on lines +430 to 433

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If establishWebSocketClient is interrupted during retry sleep, it catches InterruptedException, restores the interrupt status, and returns early. In this case, webSocketClient remains null.

If isClosed.get() is still false, the execution in sendTextWithRetry continues to the retry loop and attempts to call webSocketClient.send(message), which will throw a NullPointerException.

Recommendation:
Add a null check for webSocketClient before calling send to prevent potential NullPointerException when the connection establishment is interrupted or fails silently.

      if (isClosed.get()) {
        return;
      }
      if (webSocketClient == null) {
        log.warn("webSocketClient is null, cannot send message.");
        return;
      }
      log.debug("Sending message: " + message);

Boolean isOk = webSocketClient.send(message);
if (isOk) {
Expand All @@ -418,12 +453,22 @@ protected void sendBinaryWithRetry(
String workspace,
Map<String, String> customHeaders,
String baseWebSocketUrl) {
// Guard: skip if already cancelled
if (isClosed.get()) {
return;
}
if (!isOpen.get()) {
establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
}
if (isClosed.get()) {
return;
}
int maxRetries = 3;
int retryCount = 0;
while (retryCount < maxRetries) {
if (isClosed.get()) {
return;
}
Boolean isOk = webSocketClient.send(message);
if (isOk) {
break;
Expand Down Expand Up @@ -584,67 +629,77 @@ protected CompletableFuture<Void> sendStreamRequest(FullDuplexRequest req) {
req.getBaseWebSocketUrl());

Flowable<Object> streamingData = req.getStreamingData();
streamingData.subscribe(
data -> {
try {
if (data instanceof String) {
JsonObject continueData = req.getContinueMessage((String) data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof byte[]) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((byte[]) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof ByteBuffer) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((ByteBuffer) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else {
JsonObject continueData = req.getContinueMessage(data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
} catch (Throwable ex) {
log.error(
StringUtils.format("sendStreamData exception: %s", ex.getMessage()));
responseEmitter.onError(ex);
}
},
err -> {
log.error(StringUtils.format("Get stream data error!"));
responseEmitter.onError(err);
},
new Action() {
@Override
public void run() throws Exception {
log.debug(StringUtils.format("Stream data send completed!"));
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(req.getFinishedTaskMessage(taskId)),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
});
Disposable d =
streamingData.subscribe(
data -> {
try {
if (data instanceof String) {
JsonObject continueData =
req.getContinueMessage((String) data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof byte[]) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((byte[]) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof ByteBuffer) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((ByteBuffer) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else {
JsonObject continueData = req.getContinueMessage(data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
} catch (Throwable ex) {
log.error(
StringUtils.format(
"sendStreamData exception: %s", ex.getMessage()));
responseEmitter.onError(ex);
}
},
err -> {
log.error(StringUtils.format("Get stream data error!"));
responseEmitter.onError(err);
},
new Action() {
@Override
public void run() throws Exception {
log.debug(StringUtils.format("Stream data send completed!"));
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(req.getFinishedTaskMessage(taskId)),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
});
// Publish the disposable, then check if cancel() raced ahead.
// If isClosed is already true, cancel() has already run and missed
// this disposable, so we must dispose it ourselves.
streamingDataDisposable = d;
if (isClosed.get()) {
d.dispose();
}
} catch (Throwable ex) {
log.error(StringUtils.format("sendStreamData exception: %s", ex.getMessage()));
responseEmitter.onError(ex);
Expand Down
Loading
Loading