diff --git a/internal/api/handlers/axon_rpc.go b/internal/api/handlers/axon_rpc.go index 6c75aca..7f9ccdd 100644 --- a/internal/api/handlers/axon_rpc.go +++ b/internal/api/handlers/axon_rpc.go @@ -7,6 +7,7 @@ package handlers import ( "context" + "database/sql" "encoding/json" "errors" "net/http" @@ -155,24 +156,33 @@ func (h *RecorderHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request remoteIP := extractIP(r.RemoteAddr) rc := h.hub.NewRecorderConn(conn, deviceID, remoteIP) - if !h.hub.Connect(deviceID, rc) { + staleConn, ok := h.hub.ConnectWithStaleThreshold(deviceID, rc, recorderStaleThreshold(h.cfg)) + if !ok { logger.Printf("[RECORDER] Device %s: connection rejected (already connected)", deviceID) if err := conn.Close(websocket.StatusPolicyViolation, "device already connected"); err != nil { - logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err) + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err) + } } return } + closeStaleRecorderConn(deviceID, staleConn) defer func() { if err := conn.Close(websocket.StatusNormalClosure, ""); err != nil { - logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err) + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err) + } + } + }() + defer func() { + if h.hub.Disconnect(deviceID, rc) { + revertRunnableTasksOnDeviceDisconnect(h.db, deviceID, nil, 0, false) } }() - defer h.hub.Disconnect(deviceID, rc) - defer revertRunnableTasksOnDeviceDisconnect(h.db, deviceID, nil, 0, false) ctx := r.Context() - go h.pingLoop(ctx, conn) + go h.pingLoop(ctx, rc) // #nosec G706 -- Set aside for now logger.Printf("[RECORDER] Recorder %s connected from %s", deviceID, remoteIP) @@ -180,7 +190,9 @@ func (h *RecorderHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request for { _, raw, err := conn.Read(ctx) if err != nil { - logger.Printf("[RECORDER] Recorder %s disconnected: %v", deviceID, err) + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[RECORDER] Recorder %s disconnected: %v", deviceID, err) + } return } @@ -228,24 +240,7 @@ func (h *RecorderHandler) Config(c *gin.Context) { return } - // If RPC succeeded (HTTP 200), advance task status: pending -> ready. - // This is best-effort; failures should not change the RPC response. - if taskID != "" && h.db != nil { - now := time.Now().UTC() - _, err := h.db.Exec( - `UPDATE tasks - SET - status = 'ready', - ready_at = CASE WHEN ready_at IS NULL THEN ? ELSE ready_at END, - updated_at = ? - WHERE task_id = ? AND status = 'pending' AND deleted_at IS NULL`, - now, now, taskID, - ) - if err != nil { - logger.Printf("[RECORDER] Device %s: failed to advance task pending->ready after config: task=%s err=%v", c.Param("device_id"), taskID, err) - return - } - } + advanceTaskPendingToReady(h.db, c.Param("device_id"), taskID, "config") } // Begin sends begin recording RPC to the recorder. @@ -279,24 +274,17 @@ func (h *RecorderHandler) Begin(c *gin.Context) { return } - // If RPC succeeded (HTTP 200), advance task status: ready -> in_progress. + // If RPC succeeded (HTTP 200), advance task status: pending/ready -> in_progress. + // pending is allowed because recorder may preserve ready state across a transient + // WebSocket disconnect while Keystone has already rolled the task back. if taskID != "" && h.db != nil { - now := time.Now().UTC() - res, err := h.db.Exec( - `UPDATE tasks - SET - status = 'in_progress', - started_at = CASE WHEN started_at IS NULL THEN ? ELSE started_at END, - updated_at = ? - WHERE task_id = ? AND status = 'ready' AND deleted_at IS NULL`, - now, now, taskID, - ) + res, err := advanceTaskPendingOrReadyToInProgress(h.db, taskID) if err != nil { - logger.Printf("[RECORDER] Device %s: failed to advance task ready->in_progress after begin: task=%s err=%v", c.Param("device_id"), taskID, err) + logger.Printf("[RECORDER] Device %s: failed to advance task pending/ready->in_progress after begin: task=%s err=%v", c.Param("device_id"), taskID, err) return } if n, _ := res.RowsAffected(); n == 0 { - logger.Printf("[RECORDER] Device %s: task ready->in_progress skipped after begin (not found or not ready): task=%s", c.Param("device_id"), taskID) + h.logBeginTransitionNoop(c.Param("device_id"), taskID) } } } @@ -642,24 +630,7 @@ func (h *RecorderHandler) handleMessage(deviceID string, rc *services.RecorderCo taskID := stringValue(data, "task_id") // #nosec G706 -- Set aside for now logger.Printf("[RECORDER] Recorder %s config applied task=%s", deviceID, taskID) - // Advance task status: pending -> ready (best-effort, mirrors HTTP Config handler). - if taskID != "" && h.db != nil { - now := time.Now().UTC() - res, err := h.db.Exec( - `UPDATE tasks - SET - status = 'ready', - ready_at = CASE WHEN ready_at IS NULL THEN ? ELSE ready_at END, - updated_at = ? - WHERE task_id = ? AND status = 'pending' AND deleted_at IS NULL`, - now, now, taskID, - ) - if err != nil { - logger.Printf("[RECORDER] Recorder %s: failed to advance task pending->ready after config_applied: task=%s err=%v", deviceID, taskID, err) - } else if n, _ := res.RowsAffected(); n == 0 { - logger.Printf("[RECORDER] Recorder %s: task pending->ready skipped after config_applied (not found or not pending): task=%s", deviceID, taskID) - } - } + advanceTaskPendingToReady(h.db, deviceID, taskID, "config_applied") default: // #nosec G706 -- Set aside for now logger.Printf("[RECORDER] Recorder %s unknown message type %q", deviceID, msgType) @@ -687,25 +658,175 @@ func (h *RecorderHandler) handleStateUpdate(rc *services.RecorderConn, msg map[s Raw: data, } rc.UpdateState(state) + h.reconcileRecorderTaskState(rc.DeviceID, state) // #nosec G706 -- Set aside for now logger.Printf("[RECORDER] Recorder %s state=%s task=%s", rc.DeviceID, state.CurrentState, state.TaskID) } -func (h *RecorderHandler) pingLoop(ctx context.Context, conn *websocket.Conn) { - ticker := time.NewTicker(time.Duration(h.cfg.PingInterval) * time.Second) +func (h *RecorderHandler) reconcileRecorderTaskState(deviceID string, state services.RecorderState) { + if h.db == nil { + return + } + taskID := strings.TrimSpace(state.TaskID) + if taskID == "" { + return + } + + switch strings.TrimSpace(state.CurrentState) { + case "ready": + advanceTaskPendingToReady(h.db, deviceID, taskID, "state_update ready") + case "recording": + res, err := advanceTaskPendingOrReadyToInProgress(h.db, taskID) + if err != nil { + logger.Printf("[RECORDER] Recorder %s: failed to advance task pending/ready->in_progress after state_update recording: task=%s err=%v", deviceID, taskID, err) + return + } + if n, _ := res.RowsAffected(); n > 0 { + logger.Printf("[RECORDER] Device %s: task status reconciled: task=%s source=state_update_recording status=in_progress", deviceID, taskID) + } + } +} + +func advanceTaskPendingToReady(db *sqlx.DB, deviceID, taskID, source string) { + taskID = strings.TrimSpace(taskID) + if db == nil || taskID == "" { + return + } + now := time.Now().UTC() + res, err := db.Exec( + `UPDATE tasks + SET + status = 'ready', + ready_at = CASE WHEN ready_at IS NULL THEN ? ELSE ready_at END, + updated_at = ? + WHERE task_id = ? AND status = 'pending' AND deleted_at IS NULL`, + now, now, taskID, + ) + if err != nil { + logger.Printf("[RECORDER] Device %s: failed to advance task pending->ready after %s: task=%s err=%v", deviceID, source, taskID, err) + return + } + if n, _ := res.RowsAffected(); n > 0 { + logger.Printf("[RECORDER] Device %s: task status reconciled: task=%s source=%s status=ready", deviceID, taskID, source) + } +} + +func advanceTaskPendingOrReadyToInProgress(db *sqlx.DB, taskID string) (sql.Result, error) { + if db == nil { + return nil, nil + } + now := time.Now().UTC() + return db.Exec( + `UPDATE tasks + SET + status = 'in_progress', + started_at = CASE WHEN started_at IS NULL THEN ? ELSE started_at END, + updated_at = ? + WHERE task_id = ? AND status IN ('pending', 'ready') AND deleted_at IS NULL`, + now, now, strings.TrimSpace(taskID), + ) +} + +func (h *RecorderHandler) logBeginTransitionNoop(deviceID, taskID string) { + status, ok, err := currentTaskStatus(h.db, taskID) + if err != nil { + logger.Printf("[RECORDER] Device %s: task status lookup failed after begin: task=%s err=%v", deviceID, taskID, err) + return + } + if ok && (status == "in_progress" || status == "completed") { + return + } + if !ok { + logger.Printf("[RECORDER] Device %s: task pending/ready->in_progress skipped after begin (task not found): task=%s", deviceID, taskID) + return + } + logger.Printf("[RECORDER] Device %s: task pending/ready->in_progress skipped after begin (current_status=%s): task=%s", deviceID, status, taskID) +} + +func currentTaskStatus(db *sqlx.DB, taskID string) (string, bool, error) { + if db == nil { + return "", false, nil + } + var status string + err := db.Get(&status, `SELECT status FROM tasks WHERE task_id = ? AND deleted_at IS NULL`, strings.TrimSpace(taskID)) + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + if err != nil { + return "", false, err + } + return status, true, nil +} + +func (h *RecorderHandler) pingLoop(ctx context.Context, rc *services.RecorderConn) { + interval := recorderPingInterval(h.cfg) + if interval <= 0 || rc == nil || rc.Conn == nil { + return + } + timeout := recorderPingTimeout(h.cfg) + if timeout <= 0 { + timeout = interval + } + + ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if err := conn.Ping(ctx); err != nil { + pingCtx, cancel := context.WithTimeout(ctx, timeout) + err := rc.Conn.Ping(pingCtx) + cancel() + if err != nil { + if ctx.Err() == nil { + logger.Printf("[RECORDER] Recorder %s ping failed: %v", rc.DeviceID, err) + if closeErr := rc.Conn.CloseNow(); closeErr != nil { + if !isExpectedWebSocketCloseError(closeErr) { + logger.Printf("[RECORDER] Recorder %s close after ping failure: %v", rc.DeviceID, closeErr) + } + } + } return } + rc.LastSeenAt = time.Now() case <-ctx.Done(): return } } } +func recorderPingInterval(cfg *config.RecorderConfig) time.Duration { + if cfg == nil || cfg.PingInterval <= 0 { + return 0 + } + return time.Duration(cfg.PingInterval) * time.Second +} + +func recorderPingTimeout(cfg *config.RecorderConfig) time.Duration { + if cfg == nil || cfg.PingTimeout <= 0 { + return 0 + } + return time.Duration(cfg.PingTimeout) * time.Second +} + +func recorderStaleThreshold(cfg *config.RecorderConfig) time.Duration { + if cfg == nil || cfg.StaleThreshold <= 0 { + return 0 + } + return time.Duration(cfg.StaleThreshold) * time.Second +} + +func closeStaleRecorderConn(deviceID string, rc *services.RecorderConn) { + if rc == nil || rc.Conn == nil { + return + } + logger.Printf("[RECORDER] Device %s: closing stale WebSocket connection", deviceID) + if err := rc.Conn.CloseNow(); err != nil { + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[RECORDER] Device %s: stale WebSocket close error: %v", deviceID, err) + } + } +} + func stringValue(m map[string]interface{}, key string) string { if m == nil { return "" diff --git a/internal/api/handlers/task.go b/internal/api/handlers/task.go index 8ca92db..9f6db6c 100644 --- a/internal/api/handlers/task.go +++ b/internal/api/handlers/task.go @@ -964,11 +964,22 @@ func (h *TaskHandler) OnRecordingStart(c *gin.Context) { return } + taskStatus := "unknown" + if h.db != nil { + res, err := advanceTaskPendingOrReadyToInProgress(h.db, callback.TaskID) + if err != nil { + logger.Printf("[RECORDER] Device %s: failed to advance task pending/ready->in_progress after start callback: task=%s err=%v", callback.DeviceID, callback.TaskID, err) + } else if n, _ := res.RowsAffected(); n > 0 { + taskStatus = "in_progress" + logger.Printf("[RECORDER] Device %s: task status reconciled: task=%s source=start_callback status=in_progress", callback.DeviceID, callback.TaskID) + } + } + now := time.Now() nowStr := now.Format(time.RFC3339) c.JSON(http.StatusOK, gin.H{ "status": "acknowledged", - "task_status": "unknown", + "task_status": taskStatus, "acknowledged_at": nowStr, }) } diff --git a/internal/api/handlers/task_state_recovery_test.go b/internal/api/handlers/task_state_recovery_test.go new file mode 100644 index 0000000..bb1f17e --- /dev/null +++ b/internal/api/handlers/task_state_recovery_test.go @@ -0,0 +1,152 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "archebase.com/keystone-edge/internal/config" + "archebase.com/keystone-edge/internal/services" + "github.com/gin-gonic/gin" + "github.com/jmoiron/sqlx" + _ "modernc.org/sqlite" +) + +func TestRecorderStateUpdateReadyRestoresPendingTask(t *testing.T) { + db := newTaskStateRecoveryDB(t) + defer db.Close() + seedTaskStateRecoveryTask(t, db, "task-ready", "pending") + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{}, db) + rc := hub.NewRecorderConn(nil, "robot-001", "127.0.0.1") + + handler.handleStateUpdate(rc, map[string]interface{}{ + "data": map[string]interface{}{ + "current": "ready", + "task_id": "task-ready", + }, + }) + + assertTaskStateRecoveryStatus(t, db, "task-ready", "ready") + assertTaskStateRecoveryTimestampSet(t, db, "task-ready", "ready_at") +} + +func TestRecorderStateUpdateRecordingAdvancesPendingTask(t *testing.T) { + db := newTaskStateRecoveryDB(t) + defer db.Close() + seedTaskStateRecoveryTask(t, db, "task-recording", "pending") + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{}, db) + rc := hub.NewRecorderConn(nil, "robot-001", "127.0.0.1") + + handler.handleStateUpdate(rc, map[string]interface{}{ + "data": map[string]interface{}{ + "current": "recording", + "task_id": "task-recording", + }, + }) + + assertTaskStateRecoveryStatus(t, db, "task-recording", "in_progress") + assertTaskStateRecoveryTimestampSet(t, db, "task-recording", "started_at") +} + +func TestRecordingStartCallbackAdvancesPendingTask(t *testing.T) { + db := newTaskStateRecoveryDB(t) + defer db.Close() + seedTaskStateRecoveryTask(t, db, "task-start", "pending") + + gin.SetMode(gin.TestMode) + router := gin.New() + NewTaskHandler(db, nil, nil, 0).RegisterCallbackRoutes(router.Group("/callbacks")) + + body, err := json.Marshal(RecordingStartCallback{ + TaskID: "task-start", + DeviceID: "robot-001", + Status: "recording", + StartedAt: time.Now().UTC().Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal callback: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/callbacks/start", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d want=%d body=%s", w.Code, http.StatusOK, w.Body.String()) + } + assertTaskStateRecoveryStatus(t, db, "task-start", "in_progress") + assertTaskStateRecoveryTimestampSet(t, db, "task-start", "started_at") +} + +func newTaskStateRecoveryDB(t *testing.T) *sqlx.DB { + t.Helper() + db, err := sqlx.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite db: %v", err) + } + if _, err := db.Exec(`CREATE TABLE tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id TEXT NOT NULL, + status TEXT NOT NULL, + ready_at TIMESTAMP NULL, + started_at TIMESTAMP NULL, + completed_at TIMESTAMP NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + deleted_at TIMESTAMP NULL + )`); err != nil { + t.Fatalf("create tasks schema: %v", err) + } + return db +} + +func seedTaskStateRecoveryTask(t *testing.T, db *sqlx.DB, taskID string, status string) { + t.Helper() + now := time.Now().UTC() + if _, err := db.Exec( + `INSERT INTO tasks (task_id, status, created_at, updated_at) VALUES (?, ?, ?, ?)`, + taskID, + status, + now, + now, + ); err != nil { + t.Fatalf("seed task: %v", err) + } +} + +func assertTaskStateRecoveryStatus(t *testing.T, db *sqlx.DB, taskID string, want string) { + t.Helper() + var got string + if err := db.Get(&got, `SELECT status FROM tasks WHERE task_id = ?`, taskID); err != nil { + t.Fatalf("query task status: %v", err) + } + if got != want { + t.Fatalf("task status=%q want=%q", got, want) + } +} + +func assertTaskStateRecoveryTimestampSet(t *testing.T, db *sqlx.DB, taskID string, column string) { + t.Helper() + if column != "ready_at" && column != "started_at" { + t.Fatalf("unexpected timestamp column %q", column) + } + var got int + if err := db.Get(&got, `SELECT CASE WHEN `+column+` IS NULL THEN 0 ELSE 1 END FROM tasks WHERE task_id = ?`, taskID); err != nil { + t.Fatalf("query task timestamp %s: %v", column, err) + } + if got != 1 { + t.Fatalf("task %s was not set", column) + } +} diff --git a/internal/api/handlers/transfer.go b/internal/api/handlers/transfer.go index 4884f4b..b282db7 100644 --- a/internal/api/handlers/transfer.go +++ b/internal/api/handlers/transfer.go @@ -149,42 +149,36 @@ func (h *TransferHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request remoteIP := extractIP(r.RemoteAddr) dc := h.hub.NewTransferConn(conn, deviceID, remoteIP) - if !h.hub.Connect(deviceID, dc) { + staleConn, ok := h.hub.ConnectWithStaleThreshold(deviceID, dc, transferStaleThreshold(h.cfg)) + if !ok { logger.Printf("[TRANSFER] Device %s: connection rejected (already connected)", deviceID) if err := conn.Close(websocket.StatusPolicyViolation, "device already connected"); err != nil { - logger.Printf("[TRANSFER] WebSocket close error for device %s: %v", deviceID, err) + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[TRANSFER] WebSocket close error for device %s: %v", deviceID, err) + } } return } + closeStaleTransferConn(deviceID, staleConn) defer func() { if err := conn.Close(websocket.StatusNormalClosure, ""); err != nil { - logger.Printf("[TRANSFER] WebSocket close error for device %s: %v", deviceID, err) + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[TRANSFER] WebSocket close error for device %s: %v", deviceID, err) + } } }() defer h.clearUploadNotFoundAttemptsByDevice(deviceID) - defer h.hub.Disconnect(deviceID, dc) - defer revertRunnableTasksOnDeviceDisconnect(h.db, deviceID, h.recorderHub, h.recorderRPCTimeout, true) + defer func() { + if h.hub.Disconnect(deviceID, dc) { + revertRunnableTasksOnDeviceDisconnect(h.db, deviceID, h.recorderHub, h.recorderRPCTimeout, true) + } + }() // Create context for this connection ctx := r.Context() - // Start ping handler to automatically respond to client pings - // This prevents connection timeout due to idle connections - go func() { - ticker := time.NewTicker(25 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if err := conn.Ping(ctx); err != nil { - return - } - case <-ctx.Done(): - return - } - } - }() + go h.pingLoop(ctx, dc) // #nosec G706 -- Set aside for now logger.Printf("[TRANSFER] Transfer %s connected from %s", deviceID, remoteIP) @@ -196,7 +190,9 @@ func (h *TransferHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request for { _, raw, err := conn.Read(ctx) if err != nil { - logger.Printf("[TRANSFER] Device %s disconnected: %v", deviceID, err) + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[TRANSFER] Device %s disconnected: %v", deviceID, err) + } break } @@ -214,6 +210,75 @@ func (h *TransferHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request } } +func (h *TransferHandler) pingLoop(ctx context.Context, dc *services.TransferConn) { + interval := transferPingInterval(h.cfg) + if interval <= 0 || dc == nil || dc.Conn == nil { + return + } + timeout := transferPingTimeout(h.cfg) + if timeout <= 0 { + timeout = interval + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + pingCtx, cancel := context.WithTimeout(ctx, timeout) + err := dc.Conn.Ping(pingCtx) + cancel() + if err != nil { + if ctx.Err() == nil { + logger.Printf("[TRANSFER] Device %s ping failed: %v", dc.DeviceID, err) + if closeErr := dc.Conn.CloseNow(); closeErr != nil { + if !isExpectedWebSocketCloseError(closeErr) { + logger.Printf("[TRANSFER] Device %s close after ping failure: %v", dc.DeviceID, closeErr) + } + } + } + return + } + dc.LastSeenAt = time.Now() + case <-ctx.Done(): + return + } + } +} + +func transferPingInterval(cfg *config.TransferConfig) time.Duration { + if cfg == nil || cfg.PingInterval <= 0 { + return 0 + } + return time.Duration(cfg.PingInterval) * time.Second +} + +func transferPingTimeout(cfg *config.TransferConfig) time.Duration { + if cfg == nil || cfg.PingTimeout <= 0 { + return 0 + } + return time.Duration(cfg.PingTimeout) * time.Second +} + +func transferStaleThreshold(cfg *config.TransferConfig) time.Duration { + if cfg == nil || cfg.StaleThreshold <= 0 { + return 0 + } + return time.Duration(cfg.StaleThreshold) * time.Second +} + +func closeStaleTransferConn(deviceID string, dc *services.TransferConn) { + if dc == nil || dc.Conn == nil { + return + } + logger.Printf("[TRANSFER] Device %s: closing stale WebSocket connection", deviceID) + if err := dc.Conn.CloseNow(); err != nil { + if !isExpectedWebSocketCloseError(err) { + logger.Printf("[TRANSFER] Device %s: stale WebSocket close error: %v", deviceID, err) + } + } +} + // handleMessage dispatches an inbound WebSocket message to the appropriate handler func (h *TransferHandler) handleMessage(ctx context.Context, dc *services.TransferConn, msg map[string]interface{}) { msgType, _ := msg["type"].(string) @@ -626,8 +691,9 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra // #nosec G706 -- Set aside for now logger.Printf("[TRANSFER] Device %s: upload_ack sent for task=%s", dc.DeviceID, taskID) - // After upload_ack is sent, mark task as completed (ready or in_progress -> completed). - // ready is allowed when begin RPC timed out on edge but recording+upload succeeded (episode row committed above). + // After upload_ack is sent, mark task as completed (pending, ready, or in_progress -> completed). + // pending is allowed when Keystone rolled the task back during a transient recorder disconnect, + // while the device kept recording and successfully uploaded the episode. // Best-effort: do not affect the already-sent acknowledgement. now := time.Now().UTC() if _, err := h.db.ExecContext(ctx, ` @@ -636,10 +702,10 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra status = 'completed', completed_at = CASE WHEN completed_at IS NULL THEN ? ELSE completed_at END, updated_at = ? - WHERE id = ? AND status IN ('in_progress', 'ready') AND deleted_at IS NULL + WHERE id = ? AND status IN ('pending', 'in_progress', 'ready') AND deleted_at IS NULL `, now, now, taskPK); err != nil { // #nosec G706 -- Set aside for now - logger.Printf("[TRANSFER] Device %s: failed to mark task ready/in_progress->completed after upload_ack: task=%s err=%v", dc.DeviceID, taskID, err) + logger.Printf("[TRANSFER] Device %s: failed to mark task pending/ready/in_progress->completed after upload_ack: task=%s err=%v", dc.DeviceID, taskID, err) } else { if batchIDForAdvance > 0 { // Must run after the task row is terminal: tryAdvanceBatchStatus counts tasks in DB. diff --git a/internal/api/handlers/websocket_log.go b/internal/api/handlers/websocket_log.go new file mode 100644 index 0000000..1fa46ce --- /dev/null +++ b/internal/api/handlers/websocket_log.go @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package handlers + +import ( + "errors" + "net" + "strings" +) + +func isExpectedWebSocketCloseError(err error) bool { + if err == nil { + return true + } + if errors.Is(err, net.ErrClosed) { + return true + } + msg := err.Error() + return strings.Contains(msg, "use of closed network connection") || + strings.Contains(msg, "failed to close WebSocket: use of closed network connection") +} diff --git a/internal/config/config.go b/internal/config/config.go index 137c1c7..914a944 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -116,16 +116,21 @@ type ResourceLimitsConfig struct { // TransferConfig Transfer service configuration type TransferConfig struct { - WSPort int - MaxEvents int - ReadTimeout int // seconds - FactoryID string + WSPort int + MaxEvents int + ReadTimeout int // seconds + PingInterval int // seconds + PingTimeout int // seconds + StaleThreshold int // seconds + FactoryID string } // RecorderConfig Axon Recorder RPC gateway configuration type RecorderConfig struct { WSPort int PingInterval int // seconds + PingTimeout int // seconds + StaleThreshold int // seconds ResponseTimeout int // seconds } @@ -225,14 +230,19 @@ func Load() (*Config, error) { DiskWatermarkHigh: getEnvInt("KEYSTONE_DISK_WATERMARK_HIGH", 10), }, AxonTransfer: TransferConfig{ - WSPort: getEnvInt("KEYSTONE_AXON_TRANSFER_WS_PORT", 8090), - MaxEvents: getEnvInt("KEYSTONE_AXON_TRANSFER_MAX_EVENTS", 10000), - ReadTimeout: getEnvInt("KEYSTONE_AXON_TRANSFER_READ_TIMEOUT", 30), - FactoryID: getEnv("KEYSTONE_FACTORY_ID", "factory-default"), + WSPort: getEnvInt("KEYSTONE_AXON_TRANSFER_WS_PORT", 8090), + MaxEvents: getEnvInt("KEYSTONE_AXON_TRANSFER_MAX_EVENTS", 10000), + ReadTimeout: getEnvInt("KEYSTONE_AXON_TRANSFER_READ_TIMEOUT", 30), + PingInterval: getEnvInt("KEYSTONE_AXON_TRANSFER_PING_INTERVAL", 25), + PingTimeout: getEnvInt("KEYSTONE_AXON_TRANSFER_PING_TIMEOUT", 10), + StaleThreshold: getEnvInt("KEYSTONE_AXON_TRANSFER_STALE_THRESHOLD", 60), + FactoryID: getEnv("KEYSTONE_FACTORY_ID", "factory-default"), }, AxonRecorder: RecorderConfig{ WSPort: getEnvInt("KEYSTONE_AXON_RECORDER_WS_PORT", 8091), PingInterval: getEnvInt("KEYSTONE_AXON_RECORDER_PING_INTERVAL", 30), + PingTimeout: getEnvInt("KEYSTONE_AXON_RECORDER_PING_TIMEOUT", 10), + StaleThreshold: getEnvInt("KEYSTONE_AXON_RECORDER_STALE_THRESHOLD", 60), ResponseTimeout: getEnvInt("KEYSTONE_AXON_RECORDER_RESPONSE_TIMEOUT", 15), }, } diff --git a/internal/services/hub.go b/internal/services/hub.go index 040e9e9..ee7e30b 100644 --- a/internal/services/hub.go +++ b/internal/services/hub.go @@ -23,6 +23,8 @@ type Connection interface { GetWSConn() *websocket.Conn // GetConnectedAt returns the time the connection was established. GetConnectedAt() time.Time + // GetLastSeenAt returns when the connection last proved it was alive. + GetLastSeenAt() time.Time } // Hub is a generic, concurrency-safe registry of WebSocket connections keyed @@ -49,18 +51,44 @@ func newHub[T Connection](label string) *Hub[T] { // registered for the same device, the new connection is rejected (caller must // close it) and false is returned. Callers must pass a non-nil conn. func (h *Hub[T]) connect(deviceID string, conn T) bool { + _, ok := h.connectWithStaleThreshold(deviceID, conn, 0) + return ok +} + +// connectWithStaleThreshold registers conn under deviceID. If another +// connection exists and has not exceeded staleThreshold, the new connection is +// rejected. If the old connection is stale, it is replaced and returned so the +// caller can close it outside the hub lock. +func (h *Hub[T]) connectWithStaleThreshold(deviceID string, conn T, staleThreshold time.Duration) (T, bool) { + var zero T + h.mu.Lock() defer h.mu.Unlock() if old, exists := h.connections[deviceID]; exists { if old.GetWSConn() != nil && old.GetWSConn() != conn.GetWSConn() { + lastSeenAt := old.GetLastSeenAt() + isStale := staleThreshold > 0 && !lastSeenAt.IsZero() && time.Since(lastSeenAt) > staleThreshold + if isStale { + h.connections[deviceID] = conn + logger.Printf( + "[%s] Hub: replacing stale connection for device %s (last_seen_at=%s, stale_threshold=%s)", + h.label, + deviceID, + lastSeenAt.Format(time.RFC3339), + staleThreshold, + ) + logger.Printf("[%s] Hub: device %s registered, total connections=%d", h.label, deviceID, len(h.connections)) + return old, true + } + logger.Printf("[%s] Hub: rejecting new connection for device %s (already connected)", h.label, deviceID) - return false + return zero, false } } h.connections[deviceID] = conn logger.Printf("[%s] Hub: device %s registered, total connections=%d", h.label, deviceID, len(h.connections)) - return true + return zero, true } // disconnect removes the connection for deviceID only if it matches conn. diff --git a/internal/services/hub_test.go b/internal/services/hub_test.go new file mode 100644 index 0000000..4287f4e --- /dev/null +++ b/internal/services/hub_test.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package services + +import ( + "testing" + "time" + + "github.com/coder/websocket" +) + +func TestRecorderHubConnectWithStaleThresholdRejectsFreshConnection(t *testing.T) { + hub := NewRecorderHub() + deviceID := "robot-001" + + oldConn := hub.NewRecorderConn(&websocket.Conn{}, deviceID, "127.0.0.1") + oldConn.LastSeenAt = time.Now() + if !hub.Connect(deviceID, oldConn) { + t.Fatalf("initial connect failed") + } + + newConn := hub.NewRecorderConn(&websocket.Conn{}, deviceID, "127.0.0.1") + replaced, ok := hub.ConnectWithStaleThreshold(deviceID, newConn, time.Minute) + if ok { + t.Fatalf("fresh duplicate connection was accepted") + } + if replaced != nil { + t.Fatalf("fresh duplicate returned replaced connection") + } + if got := hub.Get(deviceID); got != oldConn { + t.Fatalf("hub connection changed on rejected duplicate") + } +} + +func TestRecorderHubConnectWithStaleThresholdReplacesStaleConnection(t *testing.T) { + hub := NewRecorderHub() + deviceID := "robot-001" + + oldConn := hub.NewRecorderConn(&websocket.Conn{}, deviceID, "127.0.0.1") + oldConn.LastSeenAt = time.Now().Add(-2 * time.Minute) + if !hub.Connect(deviceID, oldConn) { + t.Fatalf("initial connect failed") + } + + newConn := hub.NewRecorderConn(&websocket.Conn{}, deviceID, "127.0.0.1") + replaced, ok := hub.ConnectWithStaleThreshold(deviceID, newConn, time.Minute) + if !ok { + t.Fatalf("stale duplicate connection was rejected") + } + if replaced != oldConn { + t.Fatalf("replaced=%p want oldConn=%p", replaced, oldConn) + } + if got := hub.Get(deviceID); got != newConn { + t.Fatalf("hub connection=%p want newConn=%p", got, newConn) + } + + if hub.Disconnect(deviceID, oldConn) { + t.Fatalf("old stale connection disconnected the current hub entry") + } + if got := hub.Get(deviceID); got != newConn { + t.Fatalf("old disconnect changed current hub connection") + } + if !hub.Disconnect(deviceID, newConn) { + t.Fatalf("new connection did not disconnect") + } + if got := hub.Get(deviceID); got != nil { + t.Fatalf("hub connection=%p want nil", got) + } +} diff --git a/internal/services/recorder_hub.go b/internal/services/recorder_hub.go index 52553b7..902989e 100644 --- a/internal/services/recorder_hub.go +++ b/internal/services/recorder_hub.go @@ -89,6 +89,9 @@ func (r *RecorderConn) GetWSConn() *websocket.Conn { return r.Conn } // GetConnectedAt implements Connection. func (r *RecorderConn) GetConnectedAt() time.Time { return r.ConnectedAt } +// GetLastSeenAt implements Connection. +func (r *RecorderConn) GetLastSeenAt() time.Time { return r.LastSeenAt } + // GetState returns a copy of the recorder state. func (r *RecorderConn) GetState() RecorderState { r.StateMu.RLock() @@ -140,10 +143,16 @@ func (h *RecorderHub) Connect(deviceID string, rc *RecorderConn) bool { return h.connect(deviceID, rc) } +// ConnectWithStaleThreshold registers a recorder connection, replacing and +// returning the old connection when it has not been seen within staleThreshold. +func (h *RecorderHub) ConnectWithStaleThreshold(deviceID string, rc *RecorderConn, staleThreshold time.Duration) (*RecorderConn, bool) { + return h.connectWithStaleThreshold(deviceID, rc, staleThreshold) +} + // Disconnect removes a recorder connection and drains any pending RPC waiters. -func (h *RecorderHub) Disconnect(deviceID string, rc *RecorderConn) { +func (h *RecorderHub) Disconnect(deviceID string, rc *RecorderConn) bool { if !h.disconnect(deviceID, rc) { - return + return false } // Unblock any goroutines waiting for an RPC response from this device. @@ -161,6 +170,7 @@ func (h *RecorderHub) Disconnect(deviceID string, rc *RecorderConn) { } } rc.PendingMu.Unlock() + return true } // Get returns the recorder connection for a device, or nil if not connected. diff --git a/internal/services/transfer_hub.go b/internal/services/transfer_hub.go index fbe3a7d..0272fd0 100644 --- a/internal/services/transfer_hub.go +++ b/internal/services/transfer_hub.go @@ -116,6 +116,9 @@ func (d *TransferConn) GetWSConn() *websocket.Conn { return d.Conn } // GetConnectedAt implements Connection. func (d *TransferConn) GetConnectedAt() time.Time { return d.ConnectedAt } +// GetLastSeenAt implements Connection. +func (d *TransferConn) GetLastSeenAt() time.Time { return d.LastSeenAt } + // RecordEvent appends an event to the device's ring buffer func (d *TransferConn) RecordEvent(direction string, payload map[string]interface{}) { d.events.Push(DeviceEvent{ @@ -188,11 +191,15 @@ func (h *TransferHub) Connect(deviceID string, dc *TransferConn) bool { return h.connect(deviceID, dc) } +// ConnectWithStaleThreshold registers a transfer connection, replacing and +// returning the old connection when it has not been seen within staleThreshold. +func (h *TransferHub) ConnectWithStaleThreshold(deviceID string, dc *TransferConn, staleThreshold time.Duration) (*TransferConn, bool) { + return h.connectWithStaleThreshold(deviceID, dc, staleThreshold) +} + // Disconnect removes a device connection -func (h *TransferHub) Disconnect(deviceID string, dc *TransferConn) { - if !h.disconnect(deviceID, dc) { - return - } +func (h *TransferHub) Disconnect(deviceID string, dc *TransferConn) bool { + return h.disconnect(deviceID, dc) } // Get returns the TransferConn for a device, or nil if not connected