Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 180 additions & 59 deletions internal/api/handlers/axon_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package handlers

import (
"context"
"database/sql"
"encoding/json"
"errors"
"net/http"
Expand Down Expand Up @@ -155,32 +156,43 @@ func (h *RecorderHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request

remoteIP := extractIP(r.RemoteAddr)
rc := h.hub.NewRecorderConn(conn, deviceID, remoteIP)
if !h.hub.Connect(deviceID, rc) {
staleConn, ok := h.hub.ConnectWithStaleThreshold(deviceID, rc, recorderStaleThreshold(h.cfg))
if !ok {
logger.Printf("[RECORDER] Device %s: connection rejected (already connected)", deviceID)
if err := conn.Close(websocket.StatusPolicyViolation, "device already connected"); err != nil {
logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err)
if !isExpectedWebSocketCloseError(err) {
logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err)
}
}
return
}
closeStaleRecorderConn(deviceID, staleConn)

defer func() {
if err := conn.Close(websocket.StatusNormalClosure, ""); err != nil {
logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err)
if !isExpectedWebSocketCloseError(err) {
logger.Printf("[RECORDER] Device %s: WebSocket close error: %v", deviceID, err)
}
}
}()
defer func() {
if h.hub.Disconnect(deviceID, rc) {
revertRunnableTasksOnDeviceDisconnect(h.db, deviceID, nil, 0, false)
}
}()
defer h.hub.Disconnect(deviceID, rc)
defer revertRunnableTasksOnDeviceDisconnect(h.db, deviceID, nil, 0, false)

ctx := r.Context()
go h.pingLoop(ctx, conn)
go h.pingLoop(ctx, rc)

// #nosec G706 -- Set aside for now
logger.Printf("[RECORDER] Recorder %s connected from %s", deviceID, remoteIP)

for {
_, raw, err := conn.Read(ctx)
if err != nil {
logger.Printf("[RECORDER] Recorder %s disconnected: %v", deviceID, err)
if !isExpectedWebSocketCloseError(err) {
logger.Printf("[RECORDER] Recorder %s disconnected: %v", deviceID, err)
}
return
}

Expand Down Expand Up @@ -228,24 +240,7 @@ func (h *RecorderHandler) Config(c *gin.Context) {
return
}

// If RPC succeeded (HTTP 200), advance task status: pending -> ready.
// This is best-effort; failures should not change the RPC response.
if taskID != "" && h.db != nil {
now := time.Now().UTC()
_, err := h.db.Exec(
`UPDATE tasks
SET
status = 'ready',
ready_at = CASE WHEN ready_at IS NULL THEN ? ELSE ready_at END,
updated_at = ?
WHERE task_id = ? AND status = 'pending' AND deleted_at IS NULL`,
now, now, taskID,
)
if err != nil {
logger.Printf("[RECORDER] Device %s: failed to advance task pending->ready after config: task=%s err=%v", c.Param("device_id"), taskID, err)
return
}
}
advanceTaskPendingToReady(h.db, c.Param("device_id"), taskID, "config")
}

// Begin sends begin recording RPC to the recorder.
Expand Down Expand Up @@ -279,24 +274,17 @@ func (h *RecorderHandler) Begin(c *gin.Context) {
return
}

// If RPC succeeded (HTTP 200), advance task status: ready -> in_progress.
// If RPC succeeded (HTTP 200), advance task status: pending/ready -> in_progress.
// pending is allowed because recorder may preserve ready state across a transient
// WebSocket disconnect while Keystone has already rolled the task back.
if taskID != "" && h.db != nil {
now := time.Now().UTC()
res, err := h.db.Exec(
`UPDATE tasks
SET
status = 'in_progress',
started_at = CASE WHEN started_at IS NULL THEN ? ELSE started_at END,
updated_at = ?
WHERE task_id = ? AND status = 'ready' AND deleted_at IS NULL`,
now, now, taskID,
)
res, err := advanceTaskPendingOrReadyToInProgress(h.db, taskID)
if err != nil {
logger.Printf("[RECORDER] Device %s: failed to advance task ready->in_progress after begin: task=%s err=%v", c.Param("device_id"), taskID, err)
logger.Printf("[RECORDER] Device %s: failed to advance task pending/ready->in_progress after begin: task=%s err=%v", c.Param("device_id"), taskID, err)
return
}
if n, _ := res.RowsAffected(); n == 0 {
logger.Printf("[RECORDER] Device %s: task ready->in_progress skipped after begin (not found or not ready): task=%s", c.Param("device_id"), taskID)
h.logBeginTransitionNoop(c.Param("device_id"), taskID)
}
}
}
Expand Down Expand Up @@ -642,24 +630,7 @@ func (h *RecorderHandler) handleMessage(deviceID string, rc *services.RecorderCo
taskID := stringValue(data, "task_id")
// #nosec G706 -- Set aside for now
logger.Printf("[RECORDER] Recorder %s config applied task=%s", deviceID, taskID)
// Advance task status: pending -> ready (best-effort, mirrors HTTP Config handler).
if taskID != "" && h.db != nil {
now := time.Now().UTC()
res, err := h.db.Exec(
`UPDATE tasks
SET
status = 'ready',
ready_at = CASE WHEN ready_at IS NULL THEN ? ELSE ready_at END,
updated_at = ?
WHERE task_id = ? AND status = 'pending' AND deleted_at IS NULL`,
now, now, taskID,
)
if err != nil {
logger.Printf("[RECORDER] Recorder %s: failed to advance task pending->ready after config_applied: task=%s err=%v", deviceID, taskID, err)
} else if n, _ := res.RowsAffected(); n == 0 {
logger.Printf("[RECORDER] Recorder %s: task pending->ready skipped after config_applied (not found or not pending): task=%s", deviceID, taskID)
}
}
advanceTaskPendingToReady(h.db, deviceID, taskID, "config_applied")
default:
// #nosec G706 -- Set aside for now
logger.Printf("[RECORDER] Recorder %s unknown message type %q", deviceID, msgType)
Expand Down Expand Up @@ -687,25 +658,175 @@ func (h *RecorderHandler) handleStateUpdate(rc *services.RecorderConn, msg map[s
Raw: data,
}
rc.UpdateState(state)
h.reconcileRecorderTaskState(rc.DeviceID, state)
// #nosec G706 -- Set aside for now
logger.Printf("[RECORDER] Recorder %s state=%s task=%s", rc.DeviceID, state.CurrentState, state.TaskID)
}

func (h *RecorderHandler) pingLoop(ctx context.Context, conn *websocket.Conn) {
ticker := time.NewTicker(time.Duration(h.cfg.PingInterval) * time.Second)
// reconcileRecorderTaskState brings the stored task status in line with the
// state a recorder reports in a state update.
//
// A reported state of "ready" advances the task pending -> ready; a reported
// state of "recording" advances it pending/ready -> in_progress. Any other
// state is ignored. Nothing happens when the handler has no database or the
// update carries no task ID. Failures are only logged.
func (h *RecorderHandler) reconcileRecorderTaskState(deviceID string, state services.RecorderState) {
	id := strings.TrimSpace(state.TaskID)
	if h.db == nil || id == "" {
		return
	}

	reported := strings.TrimSpace(state.CurrentState)
	if reported == "ready" {
		advanceTaskPendingToReady(h.db, deviceID, id, "state_update ready")
		return
	}
	if reported != "recording" {
		return
	}

	result, err := advanceTaskPendingOrReadyToInProgress(h.db, id)
	if err != nil {
		logger.Printf("[RECORDER] Recorder %s: failed to advance task pending/ready->in_progress after state_update recording: task=%s err=%v", deviceID, id, err)
		return
	}
	// Only log when a row actually changed; a zero count means the task was
	// not in a pending/ready state (or did not match at all).
	affected, _ := result.RowsAffected()
	if affected > 0 {
		logger.Printf("[RECORDER] Device %s: task status reconciled: task=%s source=state_update_recording status=in_progress", deviceID, id)
	}
}

// advanceTaskPendingToReady moves a task from 'pending' to 'ready' on a
// best-effort basis.
//
// deviceID and source are used only in log messages; source names the event
// that triggered the transition. The call is a no-op when db is nil or taskID
// is blank after trimming. ready_at is set only if it is still NULL. Errors
// are logged and not returned; a success line is logged only when a row was
// actually updated.
func advanceTaskPendingToReady(db *sqlx.DB, deviceID, taskID, source string) {
	const promoteToReady = `UPDATE tasks
SET
status = 'ready',
ready_at = CASE WHEN ready_at IS NULL THEN ? ELSE ready_at END,
updated_at = ?
WHERE task_id = ? AND status = 'pending' AND deleted_at IS NULL`

	if db == nil {
		return
	}
	id := strings.TrimSpace(taskID)
	if id == "" {
		return
	}

	stamp := time.Now().UTC()
	result, err := db.Exec(promoteToReady, stamp, stamp, id)
	if err != nil {
		logger.Printf("[RECORDER] Device %s: failed to advance task pending->ready after %s: task=%s err=%v", deviceID, source, id, err)
		return
	}

	affected, _ := result.RowsAffected()
	if affected > 0 {
		logger.Printf("[RECORDER] Device %s: task status reconciled: task=%s source=%s status=ready", deviceID, id, source)
	}
}

// skippedTaskUpdateResult is the sql.Result returned when no UPDATE statement
// was issued. It reports zero affected rows so callers can use the result
// without a nil check.
type skippedTaskUpdateResult struct{}

// LastInsertId always reports 0; no statement was executed.
func (skippedTaskUpdateResult) LastInsertId() (int64, error) { return 0, nil }

// RowsAffected always reports 0; no statement was executed.
func (skippedTaskUpdateResult) RowsAffected() (int64, error) { return 0, nil }

// advanceTaskPendingOrReadyToInProgress moves a task from 'pending' or 'ready'
// to 'in_progress', setting started_at only if it is still NULL.
//
// It returns the result of the UPDATE so callers can inspect RowsAffected to
// tell whether the transition happened. When db is nil or taskID is blank
// after trimming, no statement is run and a non-nil result reporting zero
// affected rows is returned with a nil error. Returning a non-nil result
// matters because callers call RowsAffected as soon as err is nil; a nil
// sql.Result would panic there.
func advanceTaskPendingOrReadyToInProgress(db *sqlx.DB, taskID string) (sql.Result, error) {
	id := strings.TrimSpace(taskID)
	if db == nil || id == "" {
		return skippedTaskUpdateResult{}, nil
	}
	now := time.Now().UTC()
	return db.Exec(
		`UPDATE tasks
SET
status = 'in_progress',
started_at = CASE WHEN started_at IS NULL THEN ? ELSE started_at END,
updated_at = ?
WHERE task_id = ? AND status IN ('pending', 'ready') AND deleted_at IS NULL`,
		now, now, id,
	)
}

// logBeginTransitionNoop explains in the log why a begin RPC did not move the
// task to in_progress.
//
// It looks up the task's current status and logs one of: the lookup error,
// "task not found", or the unexpected current status. It stays silent when the
// task is already in_progress or completed.
func (h *RecorderHandler) logBeginTransitionNoop(deviceID, taskID string) {
	current, found, err := currentTaskStatus(h.db, taskID)
	switch {
	case err != nil:
		logger.Printf("[RECORDER] Device %s: task status lookup failed after begin: task=%s err=%v", deviceID, taskID, err)
	case !found:
		logger.Printf("[RECORDER] Device %s: task pending/ready->in_progress skipped after begin (task not found): task=%s", deviceID, taskID)
	case current == "in_progress" || current == "completed":
		// Already at or past the target state; nothing to report.
	default:
		logger.Printf("[RECORDER] Device %s: task pending/ready->in_progress skipped after begin (current_status=%s): task=%s", deviceID, current, taskID)
	}
}

// currentTaskStatus reads the status column of the non-deleted task with the
// given (trimmed) task ID.
//
// The boolean reports whether a row was found. A nil db or a missing row
// yields ("", false, nil); any other query failure yields ("", false, err).
func currentTaskStatus(db *sqlx.DB, taskID string) (string, bool, error) {
	if db == nil {
		return "", false, nil
	}

	const lookup = `SELECT status FROM tasks WHERE task_id = ? AND deleted_at IS NULL`

	var current string
	switch err := db.Get(&current, lookup, strings.TrimSpace(taskID)); {
	case err == nil:
		return current, true, nil
	case errors.Is(err, sql.ErrNoRows):
		// A missing task is an expected outcome, not a failure.
		return "", false, nil
	default:
		return "", false, err
	}
}

func (h *RecorderHandler) pingLoop(ctx context.Context, rc *services.RecorderConn) {
interval := recorderPingInterval(h.cfg)
if interval <= 0 || rc == nil || rc.Conn == nil {
return
}
timeout := recorderPingTimeout(h.cfg)
if timeout <= 0 {
timeout = interval
}

ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := conn.Ping(ctx); err != nil {
pingCtx, cancel := context.WithTimeout(ctx, timeout)
err := rc.Conn.Ping(pingCtx)
cancel()
if err != nil {
if ctx.Err() == nil {
logger.Printf("[RECORDER] Recorder %s ping failed: %v", rc.DeviceID, err)
if closeErr := rc.Conn.CloseNow(); closeErr != nil {
if !isExpectedWebSocketCloseError(closeErr) {
logger.Printf("[RECORDER] Recorder %s close after ping failure: %v", rc.DeviceID, closeErr)
}
}
}
return
}
rc.LastSeenAt = time.Now()
case <-ctx.Done():
return
}
}
}

// recorderPingInterval converts cfg.PingInterval, interpreted as seconds, to a
// time.Duration. It returns 0 when cfg is nil or the value is not positive.
func recorderPingInterval(cfg *config.RecorderConfig) time.Duration {
	if cfg == nil {
		return 0
	}
	seconds := cfg.PingInterval
	if seconds <= 0 {
		return 0
	}
	return time.Duration(seconds) * time.Second
}

// recorderPingTimeout converts cfg.PingTimeout, interpreted as seconds, to a
// time.Duration. It returns 0 when cfg is nil or the value is not positive.
func recorderPingTimeout(cfg *config.RecorderConfig) time.Duration {
	if cfg == nil {
		return 0
	}
	seconds := cfg.PingTimeout
	if seconds <= 0 {
		return 0
	}
	return time.Duration(seconds) * time.Second
}

// recorderStaleThreshold converts cfg.StaleThreshold, interpreted as seconds,
// to a time.Duration. It returns 0 when cfg is nil or the value is not
// positive.
func recorderStaleThreshold(cfg *config.RecorderConfig) time.Duration {
	if cfg == nil {
		return 0
	}
	seconds := cfg.StaleThreshold
	if seconds <= 0 {
		return 0
	}
	return time.Duration(seconds) * time.Second
}

// closeStaleRecorderConn closes the WebSocket of a recorder connection that
// has been superseded, using CloseNow.
//
// It does nothing when rc or its connection is nil. The close is announced in
// the log; a close error is logged only when it is not one of the expected
// WebSocket close errors.
func closeStaleRecorderConn(deviceID string, rc *services.RecorderConn) {
	if rc == nil {
		return
	}
	stale := rc.Conn
	if stale == nil {
		return
	}

	logger.Printf("[RECORDER] Device %s: closing stale WebSocket connection", deviceID)

	err := stale.CloseNow()
	if err == nil || isExpectedWebSocketCloseError(err) {
		return
	}
	logger.Printf("[RECORDER] Device %s: stale WebSocket close error: %v", deviceID, err)
}

func stringValue(m map[string]interface{}, key string) string {
if m == nil {
return ""
Expand Down
13 changes: 12 additions & 1 deletion internal/api/handlers/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -964,11 +964,22 @@ func (h *TaskHandler) OnRecordingStart(c *gin.Context) {
return
}

taskStatus := "unknown"
if h.db != nil {
res, err := advanceTaskPendingOrReadyToInProgress(h.db, callback.TaskID)
if err != nil {
logger.Printf("[RECORDER] Device %s: failed to advance task pending/ready->in_progress after start callback: task=%s err=%v", callback.DeviceID, callback.TaskID, err)
} else if n, _ := res.RowsAffected(); n > 0 {
taskStatus = "in_progress"
logger.Printf("[RECORDER] Device %s: task status reconciled: task=%s source=start_callback status=in_progress", callback.DeviceID, callback.TaskID)
}
}

now := time.Now()
nowStr := now.Format(time.RFC3339)
c.JSON(http.StatusOK, gin.H{
"status": "acknowledged",
"task_status": "unknown",
"task_status": taskStatus,
"acknowledged_at": nowStr,
})
}
Expand Down
Loading
Loading