Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public class Sentence {
@SerializedName("sentence_end")
boolean sentenceEnd;

@SerializedName("speaker_id")
String speakerId;

public static Sentence from(String message) {
return JsonUtils.fromJson(message, Sentence.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Builder;
import lombok.Getter;
Expand Down Expand Up @@ -440,17 +441,67 @@ public void updateInfo(MultiModalRequestParam.UpdateParams updateParams) {
sendTextFrame("UpdateInfo");
}

/** Stops the MultiModalDialog. */
/**
* Stops the MultiModalDialog gracefully. Sends a finish message to the server and waits for the
* server to acknowledge. If the server does not respond within the timeout (30s), falls back to
* force close.
*
* <p>This method is safe to call after onError — it will not deadlock.
*/
public void stop() {
sendFinishTaskMessage();
if (stopLatch.get() != null) {
// If emitter is already null (e.g., close() was called from onError), skip sending and just
// ensure cleanup
boolean sent;
synchronized (MultiModalDialog.this) {
sent = (conversationEmitter != null);
}
if (sent) {
sendFinishTaskMessage();
}
Comment on lines +454 to +460

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The synchronization and null-check on conversationEmitter here is redundant. The sendFinishTaskMessage() method already synchronizes on MultiModalDialog.this and safely checks if conversationEmitter is non-null before calling onComplete(). We can simplify this block to a single call to sendFinishTaskMessage().

Suggested change
boolean sent;
synchronized (MultiModalDialog.this) {
sent = (conversationEmitter != null);
}
if (sent) {
sendFinishTaskMessage();
}
sendFinishTaskMessage();


CountDownLatch latch = stopLatch.get();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a concurrency issue with how stopLatch is managed. stopLatch is declared as a non-volatile, non-final AtomicReference<CountDownLatch>:

private AtomicReference<CountDownLatch> stopLatch = new AtomicReference<>(null);

However, in the start() method (line 213), it is reassigned to a completely new AtomicReference instance:

stopLatch = new AtomicReference<>(new CountDownLatch(1));

Because stopLatch is neither final nor volatile, this reassignment is not thread-safe and can lead to other threads (calling stop() or close()) reading a stale or partially published reference.

Recommendation

To fix this, make stopLatch final and use stopLatch.set(new CountDownLatch(1)) in the start() method instead of reassigning the reference.

if (latch != null) {
try {
stopLatch.get().await();
} catch (InterruptedException ignored) {
// Use timeout to prevent deadlock: if server doesn't respond in 30s, force close
boolean completed = latch.await(30, TimeUnit.SECONDS);
if (!completed) {
log.warn("stop() timed out waiting for server acknowledgement, forcing close");
close();
}
} catch (InterruptedException e) {
close();
Thread.currentThread().interrupt();
}
}
}

/**
* Force closes the dialog immediately without sending any message to the server. Safe to call
* from any callback thread (including onError/onStopped). Does not block.
*
* <p>This method: - Nullifies emitter to prevent further data sending - Disposes upstream
* subscription to stop sendBinaryWithRetry - Forces WebSocket cancel (non-blocking) - Sets
* isClosed to short-circuit any reconnection loops in progress - Releases stopLatch to unblock
* any thread waiting in stop()
*/
public void close() {
// Nullify emitter to prevent further data sending
synchronized (MultiModalDialog.this) {
conversationEmitter = null;
}
// Force cancel: disposes upstream subscription + cancels WebSocket (non-blocking)
// Also sets isClosed=true which breaks reconnection loops in
// establishWebSocketClient/sendTextWithRetry/sendBinaryWithRetry
duplexApi.cancel();
// Release stopLatch to prevent stop() from blocking forever.
// This is critical: if close() is called from onError callback, any thread waiting
// in stop() will be unblocked immediately instead of deadlocking.
CountDownLatch latch = stopLatch.get();
if (latch != null) {
latch.countDown();
}
}

/**
* Gets current dialogue state.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.reactivex.Flowable;
import io.reactivex.FlowableEmitter;
import io.reactivex.Observable;
import io.reactivex.disposables.Disposable;
import io.reactivex.functions.Action;
import java.io.IOException;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -51,6 +52,9 @@ public class OkHttpWebSocketClient extends WebSocketListener

private AtomicBoolean passTaskStarted = new AtomicBoolean(false);

// Disposable for the streaming data subscription, used to cancel upstream when stopping
protected volatile Disposable streamingDataDisposable;

public OkHttpWebSocketClient(OkHttpClient client, boolean passTaskStarted) {
this.client = client;
this.passTaskStarted.set(passTaskStarted);
Expand Down Expand Up @@ -97,6 +101,13 @@ public boolean close(int code, String reason) {
}

public void cancel() {
// Set isClosed BEFORE cancel to suppress onFailure error propagation
isClosed.set(true);
// Dispose upstream subscription to stop sending data
Disposable d = streamingDataDisposable;
if (d != null && !d.isDisposed()) {
d.dispose();
}
if (webSocketClient != null) {
webSocketClient.cancel();
}
Expand All @@ -111,6 +122,11 @@ private void establishWebSocketClient(
int reconnectionTimes = 0;
String errorMessage = "";
while (reconnectionTimes < MAX_CONNECTION_TIMES) {
// Bail out immediately if cancel() has been called
if (isClosed.get()) {
log.debug("Connection cancelled, stop reconnecting.");
return;
}
try {
Flowable<DashScopeResult> flowable =
Flowable.<DashScopeResult>create(
Expand Down Expand Up @@ -144,9 +160,17 @@ private void establishWebSocketClient(
} else if (errorMessage.contains(Constants.NO_API_KEY_ERROR)) {
throw ex;
}
// Check again before sleeping
if (isClosed.get()) {
log.debug("Connection cancelled during retry, stop reconnecting.");
return;
}
try {
Thread.sleep(10000);
} catch (InterruptedException e) {;
} catch (InterruptedException e) {
// Respect interruption - exit the loop
Thread.currentThread().interrupt();
return;
}
}
}
Expand Down Expand Up @@ -379,10 +403,18 @@ protected void sendTextWithRetry(
String workspace,
Map<String, String> customHeaders,
String baseWebSocketUrl) {
// Guard: skip if already cancelled
if (isClosed.get()) {
log.debug("sendTextWithRetry skipped: connection already closed.");
return;
}
// simple retry with fixed delay, no strategy
if (!isOpen.get()) {
establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
}
if (isClosed.get()) {
return;
}
int maxRetries = 3;
if (passTaskStarted.get()) {
// when pass througn task started, no need to retry.
Expand All @@ -395,6 +427,9 @@ protected void sendTextWithRetry(
}
int retryCount = 0;
while (retryCount < maxRetries) {
if (isClosed.get()) {
return;
}
log.debug("Sending message: " + message);
Boolean isOk = webSocketClient.send(message);
if (isOk) {
Expand All @@ -418,12 +453,22 @@ protected void sendBinaryWithRetry(
String workspace,
Map<String, String> customHeaders,
String baseWebSocketUrl) {
// Guard: skip if already cancelled
if (isClosed.get()) {
return;
}
if (!isOpen.get()) {
establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
}
if (isClosed.get()) {
return;
}
int maxRetries = 3;
int retryCount = 0;
while (retryCount < maxRetries) {
if (isClosed.get()) {
return;
}
Boolean isOk = webSocketClient.send(message);
if (isOk) {
break;
Expand Down Expand Up @@ -584,67 +629,70 @@ protected CompletableFuture<Void> sendStreamRequest(FullDuplexRequest req) {
req.getBaseWebSocketUrl());

Flowable<Object> streamingData = req.getStreamingData();
streamingData.subscribe(
data -> {
try {
if (data instanceof String) {
JsonObject continueData = req.getContinueMessage((String) data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof byte[]) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((byte[]) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof ByteBuffer) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((ByteBuffer) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else {
JsonObject continueData = req.getContinueMessage(data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
} catch (Throwable ex) {
log.error(
StringUtils.format("sendStreamData exception: %s", ex.getMessage()));
responseEmitter.onError(ex);
}
},
err -> {
log.error(StringUtils.format("Get stream data error!"));
responseEmitter.onError(err);
},
new Action() {
@Override
public void run() throws Exception {
log.debug(StringUtils.format("Stream data send completed!"));
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(req.getFinishedTaskMessage(taskId)),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
});
streamingDataDisposable =
streamingData.subscribe(
data -> {
try {
if (data instanceof String) {
JsonObject continueData =
req.getContinueMessage((String) data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof byte[]) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((byte[]) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else if (data instanceof ByteBuffer) {
sendBinaryWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
ByteString.of((ByteBuffer) data),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
} else {
JsonObject continueData = req.getContinueMessage(data, taskId);
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(continueData),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
} catch (Throwable ex) {
log.error(
StringUtils.format(
"sendStreamData exception: %s", ex.getMessage()));
responseEmitter.onError(ex);
}
},
err -> {
log.error(StringUtils.format("Get stream data error!"));
responseEmitter.onError(err);
},
new Action() {
@Override
public void run() throws Exception {
log.debug(StringUtils.format("Stream data send completed!"));
sendTextWithRetry(
req.getApiKey(),
req.isSecurityCheck(),
JsonUtils.toJson(req.getFinishedTaskMessage(taskId)),
req.getWorkspace(),
req.getHeaders(),
req.getBaseWebSocketUrl());
}
});
Comment on lines +632 to +695

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a race condition between cancel() and the asynchronous subscription in sendStreamRequest(). Because sendStreamRequest() runs asynchronously via CompletableFuture.runAsync(), cancel() might be called before the subscription is established and assigned to streamingDataDisposable. If that happens, cancel() sees streamingDataDisposable as null and does not dispose it, while the background thread subsequently subscribes and leaks the active subscription.

Recommendation

To prevent this leak, assign the subscription to a local variable first, then assign it to streamingDataDisposable, and finally check isClosed.get(). If isClosed.get() is true, dispose the subscription immediately.

                Disposable d =
                    streamingData.subscribe(
                        data -> {
                          try {
                            if (data instanceof String) {
                              JsonObject continueData =
                                  req.getContinueMessage((String) data, taskId);
                              sendTextWithRetry(
                                  req.getApiKey(),
                                  req.isSecurityCheck(),
                                  JsonUtils.toJson(continueData),
                                  req.getWorkspace(),
                                  req.getHeaders(),
                                  req.getBaseWebSocketUrl());
                            } else if (data instanceof byte[]) {
                              sendBinaryWithRetry(
                                  req.getApiKey(),
                                  req.isSecurityCheck(),
                                  ByteString.of((byte[]) data),
                                  req.getWorkspace(),
                                  req.getHeaders(),
                                  req.getBaseWebSocketUrl());
                            } else if (data instanceof ByteBuffer) {
                              sendBinaryWithRetry(
                                  req.getApiKey(),
                                  req.isSecurityCheck(),
                                  ByteString.of((ByteBuffer) data),
                                  req.getWorkspace(),
                                  req.getHeaders(),
                                  req.getBaseWebSocketUrl());
                            } else {
                              JsonObject continueData = req.getContinueMessage(data, taskId);
                              sendTextWithRetry(
                                  req.getApiKey(),
                                  req.isSecurityCheck(),
                                  JsonUtils.toJson(continueData),
                                  req.getWorkspace(),
                                  req.getHeaders(),
                                  req.getBaseWebSocketUrl());
                            }
                          } catch (Throwable ex) {
                            log.error(
                                StringUtils.format(
                                    "sendStreamData exception: %s", ex.getMessage()));
                            responseEmitter.onError(ex);
                          }
                        },
                        err -> {
                          log.error(StringUtils.format("Get stream data error!"));
                          responseEmitter.onError(err);
                        },
                        new Action() {
                          @Override
                          public void run() throws Exception {
                            log.debug(StringUtils.format("Stream data send completed!"));
                            sendTextWithRetry(
                                req.getApiKey(),
                                req.isSecurityCheck(),
                                JsonUtils.toJson(req.getFinishedTaskMessage(taskId)),
                                req.getWorkspace(),
                                req.getHeaders(),
                                req.getBaseWebSocketUrl());
                          }
                        });
                streamingDataDisposable = d;
                if (isClosed.get()) {
                  d.dispose();
                }

} catch (Throwable ex) {
log.error(StringUtils.format("sendStreamData exception: %s", ex.getMessage()));
responseEmitter.onError(ex);
Expand Down
Loading
Loading