Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions server/coding-cli/codex-app-server/remote-proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { allocateLocalhostPort, type LoopbackServerEndpoint } from '../../local-
import {
CodexFsChangedNotificationSchema,
CodexThreadLifecycleNotificationSchema,
CodexTurnInterruptParamsSchema,
CodexTurnCompletedNotificationSchema,
CodexTurnStartedNotificationSchema,
type CodexThreadHandle,
Expand Down Expand Up @@ -47,6 +48,7 @@ type CodexRemoteProxyOptions = {

const DEFAULT_REQUEST_HOLD_TIMEOUT_MS = 5_000
const DEFAULT_CANDIDATE_CAPTURE_TIMEOUT_MS = 45_000
const MAX_COMPLETED_TURN_KEYS = 256

export class CodexRemoteProxy {
private readonly upstreamWsUrl: string
Expand All @@ -67,6 +69,8 @@ export class CodexRemoteProxy {
private readonly repairTriggerHandlers = new Set<(event: CodexRemoteProxyRepairTrigger) => void>()
private readonly lifecycleHandlers = new Set<(event: CodexThreadLifecycleEvent) => void>()
private readonly lifecycleLossHandlers = new Set<(event: CodexThreadLifecycleLossEvent) => void>()
private readonly activeTurnKeys = new Set<string>()
private readonly completedTurnKeys = new Set<string>()

constructor(options: CodexRemoteProxyOptions) {
this.upstreamWsUrl = options.upstreamWsUrl
Expand Down Expand Up @@ -260,16 +264,31 @@ export class CodexRemoteProxy {
const parsed = parseJson(raw)
const method = parsed && typeof parsed === 'object' ? (parsed as Record<string, unknown>).method : undefined
const id = jsonRpcId(parsed)
if (id !== undefined && typeof method === 'string') {
connection.pendingMethods.set(id, method)
}
if (typeof method === 'string') {
log.debug({
proxyWsUrl: this.endpoint ? this.wsUrl : undefined,
upstreamWsUrl: this.upstreamWsUrl,
method,
id,
}, 'Codex remote proxy forwarding client request')
}, 'Codex remote proxy received client request')
}

const completedTurnInterrupt = this.completedTurnInterrupt(parsed)
if (id !== undefined && completedTurnInterrupt) {
log.info({
proxyWsUrl: this.endpoint ? this.wsUrl : undefined,
upstreamWsUrl: this.upstreamWsUrl,
method,
id,
threadId: completedTurnInterrupt.threadId,
turnId: completedTurnInterrupt.turnId,
}, 'Codex remote proxy acknowledged interrupt for completed turn')
this.sendJsonRpcSuccess(connection.client, id, {})
return
}

if (id !== undefined && typeof method === 'string') {
connection.pendingMethods.set(id, method)
}

if (this.requireCandidatePersistence && method === 'turn/start' && !this.candidatePersisted) {
Expand Down Expand Up @@ -348,12 +367,14 @@ export class CodexRemoteProxy {

const turnStarted = CodexTurnStartedNotificationSchema.safeParse(parsed)
if (turnStarted.success) {
this.recordTurnStarted(turnStarted.data.params)
this.emitTurnEvent(this.turnStartedHandlers, turnStarted.data.params)
return
}

const turnCompleted = CodexTurnCompletedNotificationSchema.safeParse(parsed)
if (turnCompleted.success) {
this.recordTurnCompleted(turnCompleted.data.params)
this.emitTurnEvent(this.turnCompletedHandlers, turnCompleted.data.params)
return
}
Expand Down Expand Up @@ -430,6 +451,10 @@ export class CodexRemoteProxy {
}))
}

private sendJsonRpcSuccess(client: WebSocket, id: JsonRpcId, result: Record<string, never>): void {
sendIfOpen(client, JSON.stringify({ id, result }))
}

private ensureCandidateCaptureTimer(): void {
if (!this.requireCandidatePersistence) return
if (this.candidatePersisted || this.candidateCaptureTimer) return
Expand Down Expand Up @@ -467,6 +492,40 @@ export class CodexRemoteProxy {
}
}

private recordTurnStarted(params: { threadId: string; turnId?: string }): void {
if (typeof params.turnId !== 'string') return
const key = turnKey(params.threadId, params.turnId)
this.activeTurnKeys.add(key)
this.completedTurnKeys.delete(key)
}

private recordTurnCompleted(params: { threadId: string; turnId?: string }): void {
if (typeof params.turnId !== 'string') return
const key = turnKey(params.threadId, params.turnId)
this.activeTurnKeys.delete(key)
this.rememberCompletedTurnKey(key)
}

private rememberCompletedTurnKey(key: string): void {
this.completedTurnKeys.delete(key)
this.completedTurnKeys.add(key)
while (this.completedTurnKeys.size > MAX_COMPLETED_TURN_KEYS) {
const oldest = this.completedTurnKeys.values().next().value
if (typeof oldest !== 'string') return
this.completedTurnKeys.delete(oldest)
}
}

private completedTurnInterrupt(parsed: unknown): { threadId: string; turnId: string } | undefined {
if (!parsed || typeof parsed !== 'object') return undefined
const message = parsed as Record<string, unknown>
if (message.method !== 'turn/interrupt') return undefined
const params = CodexTurnInterruptParamsSchema.safeParse(message.params)
if (!params.success) return undefined
const key = turnKey(params.data.threadId, params.data.turnId)
return this.completedTurnKeys.has(key) && !this.activeTurnKeys.has(key) ? params.data : undefined
}

private emitRepairTrigger(event: CodexRemoteProxyRepairTrigger): void {
for (const handler of this.repairTriggerHandlers) {
handler(event)
Expand Down Expand Up @@ -514,6 +573,10 @@ function sendIfOpen(socket: WebSocket, data: WebSocket.RawData | string): void {
}
}

function turnKey(threadId: string, turnId: string): string {
return `${threadId}\u0000${turnId}`
}

function normalizeCandidateThread(thread: unknown): CodexThreadHandle | undefined {
if (!thread || typeof thread !== 'object') return undefined
const candidate = thread as Record<string, unknown>
Expand Down
53 changes: 53 additions & 0 deletions test/unit/server/coding-cli/codex-app-server/remote-proxy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ function nextMessage(socket: WebSocket): Promise<any> {
})
}

function nextMessageWithin(socket: WebSocket, ms: number): Promise<any> {
return Promise.race([
nextMessage(socket),
delay(ms).then(() => {
throw new Error(`Timed out waiting ${ms}ms for websocket message.`)
}),
])
}

function nextMessageFrame(socket: WebSocket): Promise<{ message: any; isBinary: boolean }> {
return new Promise((resolve) => {
socket.once('message', (raw, isBinary) => resolve({
Expand Down Expand Up @@ -336,4 +345,48 @@ describe('CodexRemoteProxy', () => {
params: { threadId: 'thread-1', turnId: 'turn-1', status: 'completed' },
})
})

it('acks duplicate turn/interrupt after the turn already completed', async () => {
const interruptRequests: unknown[] = []
const upstream = await startUpstream((socket, message) => {
if (message.method !== 'turn/interrupt') return
interruptRequests.push(message)
if (interruptRequests.length !== 1) return

socket.send(JSON.stringify({ id: message.id, result: {} }))
socket.send(JSON.stringify({
method: 'thread/status/changed',
params: { threadId: 'thread-1', status: { type: 'idle' } },
}))
socket.send(JSON.stringify({
method: 'turn/completed',
params: { threadId: 'thread-1', turnId: 'turn-1' },
}))
})
const proxy = await startProxy(upstream.wsUrl, {
requireCandidatePersistence: false,
})
const completed = new Promise((resolve) => {
proxy.onTurnCompleted((event) => resolve(event))
})
const tui = await connect(proxy.wsUrl)

tui.send(JSON.stringify({
id: 1,
method: 'turn/interrupt',
params: { threadId: 'thread-1', turnId: 'turn-1' },
}))
await expect(nextMessageWithin(tui, 100)).resolves.toEqual({ id: 1, result: {} })
await expect(completed).resolves.toMatchObject({ threadId: 'thread-1', turnId: 'turn-1' })

tui.send(JSON.stringify({
id: 2,
method: 'turn/interrupt',
params: { threadId: 'thread-1', turnId: 'turn-1' },
}))

await expect(nextMessageWithin(tui, 50)).resolves.toEqual({ id: 2, result: {} })
await delay(25)
expect(interruptRequests).toHaveLength(1)
})
})
Loading