diff --git a/go/core/internal/database/fake/client.go b/go/core/internal/database/fake/client.go deleted file mode 100644 index 801514f3d..000000000 --- a/go/core/internal/database/fake/client.go +++ /dev/null @@ -1,1069 +0,0 @@ -package fake - -import ( - "context" - "encoding/json" - "fmt" - "math" - "slices" - "strings" - "sync" - "time" - - "github.com/jackc/pgx/v5" - "github.com/kagent-dev/kagent/go/api/database" - "github.com/kagent-dev/kagent/go/api/v1alpha2" - "github.com/pgvector/pgvector-go" - "trpc.group/trpc-go/trpc-a2a-go/protocol" -) - -// InMemoryFakeClient is a fake implementation of database.Client for testing -type InMemoryFakeClient struct { - mu sync.RWMutex - feedback map[string]*database.Feedback - tasks map[string]*database.Task // changed from runs, key: taskID - sessions map[string]*database.Session // key: sessionID_userID - agents map[string]*database.Agent // changed from teams - toolServers map[string]*database.ToolServer - tools map[string]*database.Tool - eventsBySession map[string][]*database.Event // key: sessionId - events map[string]*database.Event // key: eventID - pushNotifications map[string]*protocol.TaskPushNotificationConfig // key: taskID - checkpoints map[string]*database.LangGraphCheckpoint // key: user_id:thread_id:checkpoint_ns:checkpoint_id - checkpointWrites map[string][]*database.LangGraphCheckpointWrite // key: user_id:thread_id:checkpoint_ns:checkpoint_id - crewaiMemory map[string][]*database.CrewAIAgentMemory // key: user_id:thread_id:agent_id - crewaiFlowStates map[string]*database.CrewAIFlowState // key: user_id:thread_id - memories map[string]*database.Memory // key: user_id:thread_id:agent_id - nextFeedbackID int -} - -// NewClient creates a new fake database client -func NewClient() database.Client { - return &InMemoryFakeClient{ - feedback: make(map[string]*database.Feedback), - tasks: make(map[string]*database.Task), - sessions: make(map[string]*database.Session), - agents: make(map[string]*database.Agent), - toolServers: make(map[string]*database.ToolServer), - tools: make(map[string]*database.Tool), - eventsBySession: make(map[string][]*database.Event), - events: make(map[string]*database.Event), - pushNotifications: make(map[string]*protocol.TaskPushNotificationConfig), - checkpoints: make(map[string]*database.LangGraphCheckpoint), - checkpointWrites: make(map[string][]*database.LangGraphCheckpointWrite), - crewaiMemory: make(map[string][]*database.CrewAIAgentMemory), - crewaiFlowStates: make(map[string]*database.CrewAIFlowState), - memories: make(map[string]*database.Memory), - nextFeedbackID: 1, - } -} - -func (c *InMemoryFakeClient) sessionKey(sessionID, userID string) string { - return fmt.Sprintf("%s_%s", sessionID, userID) -} - -func (c *InMemoryFakeClient) DeletePushNotification(_ context.Context, taskID string) error { - c.mu.Lock() - defer c.mu.Unlock() - - delete(c.pushNotifications, taskID) - return nil -} - -func (c *InMemoryFakeClient) GetPushNotification(_ context.Context, taskID string, configID string) (*protocol.TaskPushNotificationConfig, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.pushNotifications[taskID], nil -} - -func (c *InMemoryFakeClient) GetTask(_ context.Context, taskID string) (*protocol.Task, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - task, exists := c.tasks[taskID] - if !exists { - return nil, pgx.ErrNoRows - } - parsedTask := &protocol.Task{} - err := json.Unmarshal([]byte(task.Data), parsedTask) - if err != nil { - return nil, err - } - return parsedTask, nil -} - -func (c *InMemoryFakeClient) DeleteTask(_ context.Context, taskID string) error { - c.mu.Lock() - defer c.mu.Unlock() - - delete(c.tasks, taskID) - return nil -} - -// StoreFeedback creates a new feedback record -func (c *InMemoryFakeClient) StoreFeedback(_ context.Context, feedback *database.Feedback) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Copy the feedback and assign an ID - newFeedback := *feedback - id := int64(c.nextFeedbackID) - newFeedback.MessageID = &id - c.nextFeedbackID++ - - key := fmt.Sprintf("%d", id) - c.feedback[key] = &newFeedback - return nil -} - -// StoreEvents creates a new event record -func (c *InMemoryFakeClient) StoreEvents(_ context.Context, events ...*database.Event) error { - c.mu.Lock() - defer c.mu.Unlock() - - for _, event := range events { - c.events[event.ID] = event - c.eventsBySession[event.SessionID] = append(c.eventsBySession[event.SessionID], event) - } - - return nil -} - -// StoreSession creates a new session record -func (c *InMemoryFakeClient) StoreSession(_ context.Context, session *database.Session) error { - c.mu.Lock() - defer c.mu.Unlock() - - key := c.sessionKey(session.ID, session.UserID) - c.sessions[key] = session - return nil -} - -// StoreAgent creates a new agent record -func (c *InMemoryFakeClient) StoreAgent(_ context.Context, agent *database.Agent) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.agents[agent.ID] = agent - return nil -} - -// StoreTask creates a new task record -func (c *InMemoryFakeClient) StoreTask(_ context.Context, task *protocol.Task) error { - c.mu.Lock() - defer c.mu.Unlock() - - jsn, err := json.Marshal(task) - if err != nil { - return err - } - c.tasks[task.ID] = &database.Task{ - ID: task.ID, - Data: string(jsn), - } - return nil -} - -// StorePushNotification creates a new push notification record -func (c *InMemoryFakeClient) StorePushNotification(_ context.Context, config *protocol.TaskPushNotificationConfig) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.pushNotifications[config.TaskID] = config - return nil -} - -// StoreToolServer creates a new tool server record -func (c *InMemoryFakeClient) StoreToolServer(_ context.Context, toolServer *database.ToolServer) (*database.ToolServer, error) { - c.mu.Lock() - defer c.mu.Unlock() - - c.toolServers[toolServer.Name] = toolServer - return toolServer, nil -} - -// CreateTool creates a new tool record -func (c *InMemoryFakeClient) CreateTool(_ context.Context, tool *database.Tool) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.tools[tool.ID] = tool - return nil -} - -// DeleteSession deletes a session by ID and user ID -func (c *InMemoryFakeClient) DeleteSession(_ context.Context, sessionID string, userID string) error { - c.mu.Lock() - defer c.mu.Unlock() - - key := c.sessionKey(sessionID, userID) - delete(c.sessions, key) - return nil -} - -// DeleteAgent deletes an agent by name -func (c *InMemoryFakeClient) DeleteAgent(_ context.Context, agentName string) error { - c.mu.Lock() - defer c.mu.Unlock() - - _, exists := c.agents[agentName] - if !exists { - return pgx.ErrNoRows - } - - delete(c.agents, agentName) - - return nil -} - -// DeleteToolServer deletes a tool server by name -func (c *InMemoryFakeClient) DeleteToolServer(_ context.Context, serverName string, groupKind string) error { - c.mu.Lock() - defer c.mu.Unlock() - - delete(c.toolServers, serverName) - return nil -} - -// DeleteToolsForServer deletes tools for a tool server by name -func (c *InMemoryFakeClient) DeleteToolsForServer(_ context.Context, serverName string, groupKind string) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Delete all tools that belong to the specified server - for toolID, tool := range c.tools { - if tool.ServerName == serverName { - delete(c.tools, toolID) - } - } - return nil -} - -// GetSession retrieves a session by ID and user ID -func (c *InMemoryFakeClient) GetSession(_ context.Context, sessionID string, userID string) (*database.Session, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - key := c.sessionKey(sessionID, userID) - session, exists := c.sessions[key] - if !exists { - return nil, pgx.ErrNoRows - } - return session, nil -} - -// GetAgent retrieves an agent by name -func (c *InMemoryFakeClient) GetAgent(_ context.Context, agentName string) (*database.Agent, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - agent, exists := c.agents[agentName] - if !exists { - return nil, pgx.ErrNoRows - } - return agent, nil -} - -// GetTool retrieves a tool by name -func (c *InMemoryFakeClient) GetTool(_ context.Context, toolName string) (*database.Tool, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - tool, exists := c.tools[toolName] - if !exists { - return nil, pgx.ErrNoRows - } - return tool, nil -} - -// GetToolServer retrieves a tool server by name -func (c *InMemoryFakeClient) GetToolServer(_ context.Context, serverName string) (*database.ToolServer, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - server, exists := c.toolServers[serverName] - if !exists { - return nil, pgx.ErrNoRows - } - return server, nil -} - -// ListFeedback lists all feedback for a user -func (c *InMemoryFakeClient) ListFeedback(_ context.Context, userID string) ([]database.Feedback, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Feedback - for _, feedback := range c.feedback { - if feedback.UserID == userID { - result = append(result, *feedback) - } - } - return result, nil -} - -func (c *InMemoryFakeClient) ListTasksForSession(_ context.Context, sessionID string) ([]*protocol.Task, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []*protocol.Task - for _, task := range c.tasks { - if task.SessionID == sessionID { - parsed, err := task.Parse() - if err != nil { - return nil, err - } - result = append(result, &parsed) - } - } - return result, nil -} - -// ListSessions lists all sessions for a user -func (c *InMemoryFakeClient) ListSessions(_ context.Context, userID string) ([]database.Session, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Session - for _, session := range c.sessions { - if session.UserID == userID { - result = append(result, *session) - } - } - slices.SortStableFunc(result, func(i, j database.Session) int { - return strings.Compare(i.ID, j.ID) - }) - return result, nil -} - -// ListSessionsForAgent lists all sessions for an agent, excluding agent-initiated sessions. -func (c *InMemoryFakeClient) ListSessionsForAgent(_ context.Context, agentID string, userID string) ([]database.Session, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Session - for _, session := range c.sessions { - if session.AgentID != nil && *session.AgentID == agentID && session.UserID == userID { - // Exclude agent-initiated sessions from the listing - if session.Source != nil && *session.Source == database.SessionSourceAgent { - continue - } - result = append(result, *session) - } - } - slices.SortStableFunc(result, func(i, j database.Session) int { - return strings.Compare(i.ID, j.ID) - }) - return result, nil -} - -func (c *InMemoryFakeClient) ListSessionsForAgentAllUsers(_ context.Context, agentID string) ([]database.Session, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Session - for _, session := range c.sessions { - if session.AgentID != nil && *session.AgentID == agentID { - if session.Source != nil && *session.Source == database.SessionSourceAgent { - continue - } - result = append(result, *session) - } - } - slices.SortStableFunc(result, func(i, j database.Session) int { - return strings.Compare(i.ID, j.ID) - }) - return result, nil -} - -// ListAgents lists all agents -func (c *InMemoryFakeClient) ListAgents(_ context.Context) ([]database.Agent, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Agent - for _, agent := range c.agents { - result = append(result, *agent) - } - slices.SortStableFunc(result, func(i, j database.Agent) int { - return strings.Compare(i.ID, j.ID) - }) - return result, nil -} - -// ListToolServers lists all tool servers -func (c *InMemoryFakeClient) ListToolServers(_ context.Context) ([]database.ToolServer, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.ToolServer - for _, server := range c.toolServers { - result = append(result, *server) - } - slices.SortStableFunc(result, func(i, j database.ToolServer) int { - return strings.Compare(i.Name+i.GroupKind, j.Name+j.GroupKind) - }) - return result, nil -} - -// ListTools lists all tools for a user -func (c *InMemoryFakeClient) ListTools(_ context.Context) ([]database.Tool, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Tool - for _, tool := range c.tools { - result = append(result, *tool) - } - slices.SortStableFunc(result, func(i, j database.Tool) int { - return strings.Compare(i.ServerName+i.ID, j.ServerName+j.ID) - }) - return result, nil -} - -// ListToolsForServer lists all tools for a specific server and toolserver type -func (c *InMemoryFakeClient) ListToolsForServer(_ context.Context, serverName string, groupKind string) ([]database.Tool, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Tool - for _, tool := range c.tools { - // Search for tool server by name - toolServer, exists := c.toolServers[serverName] - if !exists { - continue - } - if tool.ServerName == toolServer.Name && tool.GroupKind == groupKind { - result = append(result, *tool) - } - } - - slices.SortStableFunc(result, func(i, j database.Tool) int { - return strings.Compare(i.ServerName+i.ID, j.ServerName+j.ID) - }) - return result, nil -} - -func (c *InMemoryFakeClient) ListPushNotifications(_ context.Context, taskID string) ([]*protocol.TaskPushNotificationConfig, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []*protocol.TaskPushNotificationConfig - config, exists := c.pushNotifications[taskID] - if exists { - result = append(result, config) - } - return result, nil -} - -// ListEventsForSession retrieves events for a specific session -func (c *InMemoryFakeClient) ListEventsForSession(_ context.Context, sessionID, userID string, options database.QueryOptions) ([]*database.Event, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - events, exists := c.eventsBySession[sessionID] - if !exists { - return nil, nil - } - - // Make a copy to avoid mutating the stored slice - result := make([]*database.Event, len(events)) - copy(result, events) - - if !options.OrderAsc { - // Default is DESC (newest first), reverse the insertion-order slice - for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { - result[i], result[j] = result[j], result[i] - } - } - - return result, nil -} - -// RefreshToolsForServer refreshes a tool server -func (c *InMemoryFakeClient) RefreshToolsForServer(_ context.Context, serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Simple implementation: remove all existing tools for this server+groupKind and add new ones - for toolID, tool := range c.tools { - if tool.ServerName == serverName && tool.GroupKind == groupKind { - delete(c.tools, toolID) - } - } - - // Add new tools - for _, tool := range tools { - c.tools[tool.Name] = &database.Tool{ - ID: tool.Name, - ServerName: serverName, - GroupKind: groupKind, - Description: tool.Description, - } - } - - return nil -} - -// UpdateSession updates a session -func (c *InMemoryFakeClient) UpdateSession(_ context.Context, session *database.Session) error { - c.mu.Lock() - defer c.mu.Unlock() - - key := c.sessionKey(session.ID, session.UserID) - c.sessions[key] = session - return nil -} - -// UpdateToolServer updates a tool server -func (c *InMemoryFakeClient) UpdateToolServer(_ context.Context, server *database.ToolServer) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.toolServers[server.Name] = server - return nil -} - -// UpdateAgent updates an agent record -func (c *InMemoryFakeClient) UpdateAgent(_ context.Context, agent *database.Agent) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.agents[agent.ID] = agent - return nil -} - -// UpdateTask updates a task record -func (c *InMemoryFakeClient) UpdateTask(_ context.Context, task *database.Task) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.tasks[task.ID] = task - return nil -} - -// AddTool adds a tool for testing purposes -func (c *InMemoryFakeClient) AddTool(tool *database.Tool) { - c.mu.Lock() - defer c.mu.Unlock() - - c.tools[tool.ID] = tool -} - -// AddTask adds a task for testing purposes -func (c *InMemoryFakeClient) AddTask(task *database.Task) { - c.mu.Lock() - defer c.mu.Unlock() - - c.tasks[task.ID] = task -} - -// Clear clears all data for testing purposes -func (c *InMemoryFakeClient) Clear() { - c.mu.Lock() - defer c.mu.Unlock() - - c.feedback = make(map[string]*database.Feedback) - c.tasks = make(map[string]*database.Task) - c.sessions = make(map[string]*database.Session) - c.agents = make(map[string]*database.Agent) - c.toolServers = make(map[string]*database.ToolServer) - c.tools = make(map[string]*database.Tool) - c.eventsBySession = make(map[string][]*database.Event) - c.events = make(map[string]*database.Event) - c.pushNotifications = make(map[string]*protocol.TaskPushNotificationConfig) - c.checkpoints = make(map[string]*database.LangGraphCheckpoint) - c.checkpointWrites = make(map[string][]*database.LangGraphCheckpointWrite) - c.memories = make(map[string]*database.Memory) - c.nextFeedbackID = 1 -} - -// UpsertAgent upserts an agent record -func (c *InMemoryFakeClient) UpsertAgent(_ context.Context, agent *database.Agent) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.agents[agent.ID] = agent - return nil -} - -// checkpointKey creates a key for checkpoint storage -func (c *InMemoryFakeClient) checkpointKey(userID, threadID, checkpointNS, checkpointID string) string { - return fmt.Sprintf("%s:%s:%s:%s", userID, threadID, checkpointNS, checkpointID) -} - -// StoreCheckpoint stores a LangGraph checkpoint -func (c *InMemoryFakeClient) StoreCheckpoint(_ context.Context, checkpoint *database.LangGraphCheckpoint) error { - c.mu.Lock() - defer c.mu.Unlock() - - key := c.checkpointKey(checkpoint.UserID, checkpoint.ThreadID, checkpoint.CheckpointNS, checkpoint.CheckpointID) - - // Check for idempotent retry - if existing, exists := c.checkpoints[key]; exists { - if existing.Metadata == checkpoint.Metadata && existing.Checkpoint == checkpoint.Checkpoint { - return nil // Idempotent success - } - return fmt.Errorf("checkpoint already exists with different data") - } - - // Store checkpoint - c.checkpoints[key] = checkpoint - - return nil -} - -// StoreCheckpointWrites stores checkpoint writes -func (c *InMemoryFakeClient) StoreCheckpointWrites(_ context.Context, writes []*database.LangGraphCheckpointWrite) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Group writes by checkpoint key - writesByKey := make(map[string][]*database.LangGraphCheckpointWrite) - for _, write := range writes { - key := c.checkpointKey(write.UserID, write.ThreadID, write.CheckpointNS, write.CheckpointID) - writesByKey[key] = append(writesByKey[key], write) - } - - // Store writes for each checkpoint - for key, keyWrites := range writesByKey { - c.checkpointWrites[key] = append(c.checkpointWrites[key], keyWrites...) - } - - return nil -} - -// GetLatestCheckpoint retrieves the most recent checkpoint for a thread -func (c *InMemoryFakeClient) GetLatestCheckpoint(_ context.Context, userID, threadID, checkpointNS string) (*database.LangGraphCheckpoint, []*database.LangGraphCheckpointWrite, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var latest *database.LangGraphCheckpoint - var latestKey string - - // Find the latest checkpoint by creation time - for key, checkpoint := range c.checkpoints { - if checkpoint.UserID == userID && checkpoint.ThreadID == threadID && checkpoint.CheckpointNS == checkpointNS { - if latest == nil || checkpoint.CreatedAt.After(latest.CreatedAt) { - latest = checkpoint - latestKey = key - } - } - } - - if latest == nil { - return nil, nil, nil - } - - // Get writes for this checkpoint - writes := c.checkpointWrites[latestKey] - - return latest, writes, nil -} - -// GetCheckpoint retrieves a specific checkpoint by ID -func (c *InMemoryFakeClient) GetCheckpoint(_ context.Context, userID, threadID, checkpointNS, checkpointID string) (*database.LangGraphCheckpoint, []*database.LangGraphCheckpointWrite, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - key := c.checkpointKey(userID, threadID, checkpointNS, checkpointID) - checkpoint, exists := c.checkpoints[key] - if !exists { - return nil, nil, nil - } - - // Get writes for this checkpoint - writes := c.checkpointWrites[key] - - return checkpoint, writes, nil -} - -// ListCheckpoints lists checkpoints for a thread, optionally filtered by checkpointID -func (c *InMemoryFakeClient) ListCheckpoints(_ context.Context, userID, threadID, checkpointNS string, checkpointID *string, limit int) ([]*database.LangGraphCheckpointTuple, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []*database.LangGraphCheckpointTuple - - // Find matching checkpoints - for key, checkpoint := range c.checkpoints { - if checkpoint.UserID == userID && checkpoint.ThreadID == threadID && checkpoint.CheckpointNS == checkpointNS { - // If a specific checkpoint ID is requested, only return that one - if checkpointID != nil && checkpoint.CheckpointID != *checkpointID { - continue - } - - // Get writes for this checkpoint - writes := c.checkpointWrites[key] - if writes == nil { - writes = []*database.LangGraphCheckpointWrite{} - } - - result = append(result, &database.LangGraphCheckpointTuple{ - Checkpoint: checkpoint, - Writes: writes, - }) - } - } - - // Sort by creation time (newest first) - for i := 0; i < len(result)-1; i++ { - for j := i + 1; j < len(result); j++ { - if result[i].Checkpoint.CreatedAt.Before(result[j].Checkpoint.CreatedAt) { - result[i], result[j] = result[j], result[i] - } - } - } - - // Apply limit - if limit > 0 && len(result) > limit { - result = result[:limit] - } - - return result, nil -} - -// DeleteCheckpoint deletes a checkpoint and its writes atomically -func (c *InMemoryFakeClient) DeleteCheckpoint(_ context.Context, userID, threadID string) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Find and delete all checkpoints for the thread - keysToDelete := make([]string, 0) - for key, checkpoint := range c.checkpoints { - if checkpoint.UserID == userID && checkpoint.ThreadID == threadID { - keysToDelete = append(keysToDelete, key) - } - } - - // Delete checkpoints and their writes - for _, key := range keysToDelete { - delete(c.checkpoints, key) - delete(c.checkpointWrites, key) - } - - return nil -} - -// ListWrites retrieves writes for a specific checkpoint -func (c *InMemoryFakeClient) ListWrites(_ context.Context, userID, threadID, checkpointNS, checkpointID string, offset, limit int) ([]*database.LangGraphCheckpointWrite, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - key := c.checkpointKey(userID, threadID, checkpointNS, checkpointID) - writes := c.checkpointWrites[key] - - if writes == nil { - return []*database.LangGraphCheckpointWrite{}, nil - } - - // Apply pagination - start := offset - if start >= len(writes) { - return []*database.LangGraphCheckpointWrite{}, nil - } - - end := len(writes) - if limit > 0 && start+limit < end { - end = start + limit - } - - return writes[start:end], nil -} - -// CrewAI methods - -// StoreCrewAIMemory stores CrewAI agent memory -func (c *InMemoryFakeClient) StoreCrewAIMemory(_ context.Context, memory *database.CrewAIAgentMemory) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.crewaiMemory == nil { - c.crewaiMemory = make(map[string][]*database.CrewAIAgentMemory) - } - - key := fmt.Sprintf("%s:%s", memory.UserID, memory.ThreadID) - c.crewaiMemory[key] = append(c.crewaiMemory[key], memory) - - return nil -} - -// SearchCrewAIMemoryByTask searches CrewAI agent memory by task description across all agents for a session -func (c *InMemoryFakeClient) SearchCrewAIMemoryByTask(_ context.Context, userID, threadID, taskDescription string, limit int) ([]*database.CrewAIAgentMemory, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - if c.crewaiMemory == nil { - return []*database.CrewAIAgentMemory{}, nil - } - - var allMemories []*database.CrewAIAgentMemory - - // Search across all agents for this user/thread - for key, memories := range c.crewaiMemory { - // Key format is "user_id:thread_id" - if strings.HasPrefix(key, userID+":"+threadID) { - for _, memory := range memories { - // Parse the JSON memory data and search for task_description - var memoryData map[string]any - if err := json.Unmarshal([]byte(memory.MemoryData), &memoryData); err == nil { - if taskDesc, ok := memoryData["task_description"].(string); ok { - if strings.Contains(strings.ToLower(taskDesc), strings.ToLower(taskDescription)) { - allMemories = append(allMemories, memory) - } - } - } - // Fallback to simple string search if JSON parsing fails - if len(allMemories) == 0 && strings.Contains(strings.ToLower(memory.MemoryData), strings.ToLower(taskDescription)) { - allMemories = append(allMemories, memory) - } - } - } - } - - // Sort by created_at DESC, then by score ASC (if score exists in JSON) - slices.SortStableFunc(allMemories, func(i, j *database.CrewAIAgentMemory) int { - // First sort by created_at DESC (most recent first) - if !i.CreatedAt.Equal(j.CreatedAt) { - if i.CreatedAt.After(j.CreatedAt) { - return -1 - } else { - return 1 - } - } - - // If created_at is equal, sort by score ASC - var scoreI, scoreJ float64 - var memoryDataI, memoryDataJ map[string]any - - if err := json.Unmarshal([]byte(i.MemoryData), &memoryDataI); err == nil { - if score, ok := memoryDataI["score"].(float64); ok { - scoreI = score - } - } - - if err := json.Unmarshal([]byte(j.MemoryData), &memoryDataJ); err == nil { - if score, ok := memoryDataJ["score"].(float64); ok { - scoreJ = score - } - } - - if scoreI < scoreJ { - return -1 - } else if scoreI > scoreJ { - return 1 - } else { - return 0 - } - }) - - // Apply limit - if limit > 0 && len(allMemories) > limit { - allMemories = allMemories[:limit] - } - - return allMemories, nil -} - -// ResetCrewAIMemory deletes all CrewAI agent memory for a session -func (c *InMemoryFakeClient) ResetCrewAIMemory(_ context.Context, userID, threadID string) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.crewaiMemory == nil { - return nil - } - - // Find and delete all memory entries for this user/thread combination - keysToDelete := make([]string, 0) - for key := range c.crewaiMemory { - // Key format is "user_id:thread_id" - if strings.HasPrefix(key, userID+":"+threadID) { - keysToDelete = append(keysToDelete, key) - } - } - - // Delete the entries - for _, key := range keysToDelete { - delete(c.crewaiMemory, key) - } - - return nil -} - -// StoreCrewAIFlowState stores CrewAI flow state -func (c *InMemoryFakeClient) StoreCrewAIFlowState(_ context.Context, state *database.CrewAIFlowState) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.crewaiFlowStates == nil { - c.crewaiFlowStates = make(map[string]*database.CrewAIFlowState) - } - - key := fmt.Sprintf("%s:%s", state.UserID, state.ThreadID) - c.crewaiFlowStates[key] = state - - return nil -} - -// GetCrewAIFlowState retrieves CrewAI flow state -func (c *InMemoryFakeClient) GetCrewAIFlowState(_ context.Context, userID, threadID string) (*database.CrewAIFlowState, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - if c.crewaiFlowStates == nil { - return nil, nil - } - - key := fmt.Sprintf("%s:%s", userID, threadID) - state := c.crewaiFlowStates[key] - - return state, nil -} - -// memoryKey creates a unique key for a memory record -func (c *InMemoryFakeClient) memoryKey(agentName, userID, id string) string { - return fmt.Sprintf("%s:%s:%s", agentName, userID, id) -} - -// StoreAgentMemory stores agent memory -func (c *InMemoryFakeClient) StoreAgentMemory(_ context.Context, memory *database.Memory) error { - c.mu.Lock() - defer c.mu.Unlock() - - if memory.ID == "" { - memory.ID = fmt.Sprintf("%d", len(c.memories)+1) - } - key := c.memoryKey(memory.AgentName, memory.UserID, memory.ID) - c.memories[key] = memory - return nil -} - -// StoreAgentMemories stores multiple agent memories -func (c *InMemoryFakeClient) StoreAgentMemories(_ context.Context, memories []*database.Memory) error { - c.mu.Lock() - defer c.mu.Unlock() - - for _, memory := range memories { - if memory.ID == "" { - memory.ID = fmt.Sprintf("%d", len(c.memories)+1) - } - key := c.memoryKey(memory.AgentName, memory.UserID, memory.ID) - c.memories[key] = memory - } - return nil -} - -// cosineSimilarity computes the cosine similarity between two float32 slices. -// Returns 0 if either vector has zero magnitude. -func cosineSimilarity(a, b []float32) float64 { - if len(a) != len(b) { - return 0 - } - var dot, normA, normB float64 - for i := range a { - ai := float64(a[i]) - bi := float64(b[i]) - dot += ai * bi - normA += ai * ai - normB += bi * bi - } - if normA == 0 || normB == 0 { - return 0 - } - return dot / (math.Sqrt(normA) * math.Sqrt(normB)) -} - -// SearchAgentMemory searches agent memory by vector similarity -func (c *InMemoryFakeClient) SearchAgentMemory(_ context.Context, agentName, userID string, embedding pgvector.Vector, limit int) ([]database.AgentMemorySearchResult, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - queryVec := embedding.Slice() - now := time.Now() - - var results []database.AgentMemorySearchResult - for _, memory := range c.memories { - if memory.AgentName != agentName || memory.UserID != userID { - continue - } - // Skip expired memories - if memory.ExpiresAt != nil && memory.ExpiresAt.Before(now) { - continue - } - score := cosineSimilarity(queryVec, memory.Embedding.Slice()) - results = append(results, database.AgentMemorySearchResult{ - Memory: *memory, - Score: score, - }) - } - - // Sort by score descending - slices.SortStableFunc(results, func(i, j database.AgentMemorySearchResult) int { - if i.Score > j.Score { - return -1 - } else if i.Score < j.Score { - return 1 - } - return 0 - }) - - if limit > 0 && len(results) > limit { - results = results[:limit] - } - - return results, nil -} - -// ListAgentMemories lists agent memories ordered by access count descending -func (c *InMemoryFakeClient) ListAgentMemories(_ context.Context, agentName, userID string) ([]database.Memory, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - var result []database.Memory - for _, memory := range c.memories { - if memory.AgentName == agentName && memory.UserID == userID { - result = append(result, *memory) - } - } - - // Sort by access_count DESC - slices.SortStableFunc(result, func(i, j database.Memory) int { - if i.AccessCount > j.AccessCount { - return -1 - } else if i.AccessCount < j.AccessCount { - return 1 - } - return 0 - }) - - return result, nil -} - -// DeleteAgentMemory deletes all agent memory for a given agent and user -func (c *InMemoryFakeClient) DeleteAgentMemory(_ context.Context, agentName, userID string) error { - c.mu.Lock() - defer c.mu.Unlock() - - for key, memory := range c.memories { - if memory.AgentName == agentName && memory.UserID == userID { - delete(c.memories, key) - } - } - return nil -} - -// PruneExpiredMemories removes all memories whose ExpiresAt is in the past -func (c *InMemoryFakeClient) PruneExpiredMemories(_ context.Context) error { - c.mu.Lock() - defer c.mu.Unlock() - - now := time.Now() - for key, memory := range c.memories { - if memory.ExpiresAt != nil && memory.ExpiresAt.Before(now) { - delete(c.memories, key) - } - } - return nil -} diff --git a/go/core/internal/httpserver/handlers/agents_test.go b/go/core/internal/httpserver/handlers/agents_test.go index a6f27d3a0..b6dfa1a66 100644 --- a/go/core/internal/httpserver/handlers/agents_test.go +++ b/go/core/internal/httpserver/handlers/agents_test.go @@ -19,7 +19,6 @@ import ( "github.com/kagent-dev/kagent/go/api/database" api "github.com/kagent-dev/kagent/go/api/httpapi" "github.com/kagent-dev/kagent/go/api/v1alpha2" - database_fake "github.com/kagent-dev/kagent/go/core/internal/database/fake" "github.com/kagent-dev/kagent/go/core/internal/httpserver/auth" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" common "github.com/kagent-dev/kagent/go/core/internal/utils" @@ -80,14 +79,14 @@ func createTestSandboxAgentCRD(name string, modelConfig *v1alpha2.ModelConfig, c } } -func setupTestHandler(objects ...client.Object) (*handlers.AgentsHandler, string) { +func setupTestHandler(t *testing.T, objects ...client.Object) (*handlers.AgentsHandler, string) { kubeClient := fake.NewClientBuilder(). WithScheme(setupScheme()). WithObjects(objects...). Build() userID := "test-user" - dbClient := database_fake.NewClient() + dbClient := setupTestDBClient(t) base := &handlers.Base{ KubeClient: kubeClient, @@ -116,7 +115,7 @@ func TestHandleGetAgent(t *testing.T) { modelConfig := createTestModelConfig() team := createTestAgent("test-team", modelConfig) - handler, _ := setupTestHandler(team, modelConfig) + handler, _ := setupTestHandler(t, team, modelConfig) createAgent(handler.DatabaseService, team) req := httptest.NewRequest("GET", "/api/agents/default/test-team", nil) @@ -154,7 +153,7 @@ func TestHandleGetAgent(t *testing.T) { } agent := createTestAgentWithStatus("test-agent-ready", modelConfig, conditions) - handler, _ := setupTestHandler(agent, modelConfig) + handler, _ := setupTestHandler(t, agent, modelConfig) createAgent(handler.DatabaseService, agent) req := httptest.NewRequest("GET", "/api/agents/default/test-agent-ready", nil) @@ -184,7 +183,7 @@ func TestHandleGetAgent(t *testing.T) { } agent := createTestAgentWithStatus("test-agent-not-ready", modelConfig, conditions) - handler, _ := setupTestHandler(agent, modelConfig) + handler, _ := setupTestHandler(t, agent, modelConfig) createAgent(handler.DatabaseService, agent) req := httptest.NewRequest("GET", "/api/agents/default/test-agent-not-ready", nil) @@ -213,7 +212,7 @@ func TestHandleGetAgent(t *testing.T) { } agent := createTestAgentWithStatus("test-agent-different-reason", modelConfig, conditions) - handler, _ := setupTestHandler(agent, modelConfig) + handler, _ := setupTestHandler(t, agent, modelConfig) createAgent(handler.DatabaseService, agent) req := httptest.NewRequest("GET", "/api/agents/default/test-agent-different-reason", nil) @@ -247,7 +246,7 @@ func TestHandleGetAgent(t *testing.T) { } sa := createTestSandboxAgentCRD("sandbox-accepted", modelConfig, conditions) - handler, _ := setupTestHandler(sa, modelConfig) + handler, _ := setupTestHandler(t, sa, modelConfig) req := httptest.NewRequest("GET", "/api/agents/default/sandbox-accepted", nil) req = mux.SetURLVars(req, map[string]string{"namespace": "default", "name": "sandbox-accepted"}) @@ -260,7 +259,7 @@ func TestHandleGetAgent(t *testing.T) { }) t.Run("returns 404 for missing agent", func(t *testing.T) { - handler, _ := setupTestHandler() + handler, _ := setupTestHandler(t) req := httptest.NewRequest("GET", "/api/agents/default/test-team", nil) req = mux.SetURLVars(req, map[string]string{"namespace": "default", "name": "test-team"}) @@ -282,7 +281,7 @@ func TestHandleGetSandboxAgent(t *testing.T) { } sa := createTestSandboxAgentCRD("sandbox-accepted", modelConfig, conditions) - handler, _ := setupTestHandler(sa, modelConfig) + handler, _ := setupTestHandler(t, sa, modelConfig) req := httptest.NewRequest("GET", "/api/sandboxagents/default/sandbox-accepted", nil) req = mux.SetURLVars(req, map[string]string{"namespace": "default", "name": "sandbox-accepted"}) @@ -305,7 +304,7 @@ func TestHandleGetSandboxAgent(t *testing.T) { modelConfig := createTestModelConfig() agent := createTestAgent("shared-name", modelConfig) sa := createTestSandboxAgentCRD("shared-name", modelConfig, nil) - handler, _ := setupTestHandler(agent, sa, modelConfig) + handler, _ := setupTestHandler(t, agent, sa, modelConfig) req := httptest.NewRequest("GET", "/api/sandboxagents/default/shared-name", nil) req = mux.SetURLVars(req, map[string]string{"namespace": "default", "name": "shared-name"}) @@ -344,7 +343,7 @@ func TestHandleListAgents(t *testing.T) { // Agent with DeploymentReady=false notReadyAgent := createTestAgent("not-ready-agent", modelConfig) - handler, _ := setupTestHandler(readyAgent, notReadyAgent, modelConfig) + handler, _ := setupTestHandler(t, readyAgent, notReadyAgent, modelConfig) createAgent(handler.DatabaseService, readyAgent) createAgent(handler.DatabaseService, notReadyAgent) @@ -404,7 +403,7 @@ func TestHandleListAgents(t *testing.T) { readyAgent := createTestAgentWithStatus("ready-agent", modelConfig, readyConditions) invalidAgent := createTestAgentWithStatus("invalid-agent", modelConfig, invalidConditions) - handler, _ := setupTestHandler(readyAgent, invalidAgent, modelConfig) + handler, _ := setupTestHandler(t, readyAgent, invalidAgent, modelConfig) createAgent(handler.DatabaseService, readyAgent) createAgent(handler.DatabaseService, invalidAgent) @@ -437,7 +436,7 @@ func TestHandleListAgents(t *testing.T) { {Type: "Ready", Status: "True", Reason: "WorkloadReady"}, } sa := createTestSandboxAgentCRD("mysandbox", modelConfig, conditions) - handler, _ := setupTestHandler(sa, modelConfig) + handler, _ := setupTestHandler(t, sa, modelConfig) req := httptest.NewRequest("GET", "/api/agents", nil) req = setUser(req, "test-user") @@ -463,7 +462,7 @@ func TestHandleListSandboxAgents(t *testing.T) { } sa := createTestSandboxAgentCRD("mysandbox", modelConfig, conditions) agent := createTestAgent("myagent", modelConfig) - handler, _ := setupTestHandler(sa, agent, modelConfig) + handler, _ := setupTestHandler(t, sa, agent, modelConfig) req := httptest.NewRequest("GET", "/api/sandboxagents", nil) req = setUser(req, "test-user") @@ -487,7 +486,7 @@ func TestHandleListSandboxAgents(t *testing.T) { modelConfig := createTestModelConfig() agent := createTestAgent("shared-name", modelConfig) sa := createTestSandboxAgentCRD("shared-name", modelConfig, nil) - handler, _ := setupTestHandler(agent, sa, modelConfig) + handler, _ := setupTestHandler(t, agent, sa, modelConfig) agentReq := httptest.NewRequest("GET", "/api/agents", nil) agentReq = setUser(agentReq, "test-user") @@ -540,7 +539,7 @@ func TestHandleUpdateAgent(t *testing.T) { }, } - handler, _ := setupTestHandler(existingAgent, oldModelConfig, newModelConfig) + handler, _ := setupTestHandler(t, existingAgent, oldModelConfig, newModelConfig) updatedAgent := &v1alpha2.Agent{ ObjectMeta: metav1.ObjectMeta{Name: "test-team", Namespace: "default"}, @@ -589,7 +588,7 @@ func TestHandleUpdateAgent(t *testing.T) { }, } - handler, _ := setupTestHandler(existingAgent, modelConfig) + handler, _ := setupTestHandler(t, existingAgent, modelConfig) updatedAgent := &v1alpha2.Agent{ ObjectMeta: metav1.ObjectMeta{Name: "test-team", Namespace: "default"}, @@ -615,7 +614,7 @@ func TestHandleUpdateAgent(t *testing.T) { }) t.Run("returns 404 for non-existent team", func(t *testing.T) { - handler, _ := setupTestHandler() + handler, _ := setupTestHandler(t) agent := &v1alpha2.Agent{ ObjectMeta: metav1.ObjectMeta{Name: "non-existent", Namespace: "default"}, @@ -645,7 +644,7 @@ func TestHandleCreateAgent(t *testing.T) { }, } - handler, _ := setupTestHandler(modelConfig) + handler, _ := setupTestHandler(t, modelConfig) agent := &v1alpha2.Agent{ ObjectMeta: metav1.ObjectMeta{Name: "test-team", Namespace: "default"}, @@ -685,7 +684,7 @@ func TestHandleDeleteTeam(t *testing.T) { ObjectMeta: metav1.ObjectMeta{Name: "test-team", Namespace: "default"}, } - handler, _ := setupTestHandler(team) + handler, _ := setupTestHandler(t, team) createAgent(handler.DatabaseService, team) req := httptest.NewRequest("DELETE", "/api/agents/default/test-team", nil) @@ -699,7 +698,7 @@ func TestHandleDeleteTeam(t *testing.T) { }) t.Run("returns 404 for non-existent team", func(t *testing.T) { - handler, _ := setupTestHandler() + handler, _ := setupTestHandler(t) req := httptest.NewRequest("DELETE", "/api/teams/default/non-existent", nil) req = mux.SetURLVars(req, map[string]string{ @@ -718,7 +717,7 @@ func TestHandleDeleteTeam(t *testing.T) { modelConfig := createTestModelConfig() agent := createTestAgent("shared-name", modelConfig) sa := createTestSandboxAgentCRD("shared-name", modelConfig, nil) - handler, _ := setupTestHandler(agent, sa, modelConfig) + handler, _ := setupTestHandler(t, agent, sa, modelConfig) req := httptest.NewRequest("DELETE", "/api/agents/default/shared-name", nil) req = mux.SetURLVars(req, map[string]string{"namespace": "default", "name": "shared-name"}) @@ -739,7 +738,7 @@ func TestHandleDeleteSandboxAgent(t *testing.T) { t.Run("deletes sandbox agent successfully", func(t *testing.T) { modelConfig := createTestModelConfig() sa := createTestSandboxAgentCRD("test-sandbox", modelConfig, nil) - handler, _ := setupTestHandler(sa, modelConfig) + handler, _ := setupTestHandler(t, sa, modelConfig) req := httptest.NewRequest("DELETE", "/api/sandboxagents/default/test-sandbox", nil) req = mux.SetURLVars(req, map[string]string{"namespace": "default", "name": "test-sandbox"}) diff --git a/go/core/internal/httpserver/handlers/database_test.go b/go/core/internal/httpserver/handlers/database_test.go new file mode 100644 index 000000000..43ead845c --- /dev/null +++ b/go/core/internal/httpserver/handlers/database_test.go @@ -0,0 +1,122 @@ +package handlers_test + +import ( + "context" + "flag" + "fmt" + "os" + "slices" + "strings" + "sync" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" + apidatabase "github.com/kagent-dev/kagent/go/api/database" + coredatabase "github.com/kagent-dev/kagent/go/core/internal/database" + "github.com/kagent-dev/kagent/go/core/internal/dbtest" + "github.com/stretchr/testify/require" +) + +var ( + sharedDB *pgxpool.Pool + sharedDBCleanup func() + sharedDBInitErr error + sharedDBInit sync.Once +) + +func TestMain(m *testing.M) { + flag.Parse() + code := m.Run() + if sharedDB != nil { + sharedDB.Close() + } + if sharedDBCleanup != nil { + sharedDBCleanup() + } + os.Exit(code) +} + +func setupTestDBClient(t *testing.T) apidatabase.Client { + t.Helper() + if testing.Short() { + t.Skip("skipping database-backed handler test in short mode") + } + + initSharedDB(t) + + tableNames, err := truncatableTables(context.Background()) + require.NoError(t, err, "failed to list tables for truncation") + + _, err = sharedDB.Exec(context.Background(), fmt.Sprintf( + "TRUNCATE TABLE %s RESTART IDENTITY CASCADE", + strings.Join(tableNames, ", "), + )) + require.NoError(t, err, "failed to truncate test tables") + + return coredatabase.NewClient(sharedDB) +} + +func initSharedDB(t *testing.T) { + t.Helper() + + sharedDBInit.Do(func() { + connStr, cleanup, err := dbtest.Start(context.Background()) + if err != nil { + sharedDBInitErr = fmt.Errorf("start postgres container: %w", err) + return + } + + if err := dbtest.Migrate(connStr, true); err != nil { + cleanup() + sharedDBInitErr = fmt.Errorf("migrate test database: %w", err) + return + } + + db, err := coredatabase.Connect(context.Background(), &coredatabase.PostgresConfig{ + URL: connStr, + VectorEnabled: true, + }) + if err != nil { + cleanup() + sharedDBInitErr = fmt.Errorf("connect to test database: %w", err) + return + } + + sharedDB = db + sharedDBCleanup = cleanup + }) + + require.NoError(t, sharedDBInitErr, "failed to initialize shared test database") +} + +func truncatableTables(ctx context.Context) ([]string, error) { + rows, err := sharedDB.Query(ctx, ` + SELECT tablename + FROM pg_tables + WHERE schemaname = current_schema() + AND tablename NOT IN ('schema_migrations', 'vector_schema_migrations') + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var tableNames []string + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return nil, err + } + tableNames = append(tableNames, quoteIdentifier(tableName)) + } + if err := rows.Err(); err != nil { + return nil, err + } + + slices.Sort(tableNames) + return tableNames, nil +} + +func quoteIdentifier(identifier string) string { + return `"` + strings.ReplaceAll(identifier, `"`, `""`) + `"` +} diff --git a/go/core/internal/httpserver/handlers/memory_test.go b/go/core/internal/httpserver/handlers/memory_test.go index 7b4755640..34bea2709 100644 --- a/go/core/internal/httpserver/handlers/memory_test.go +++ b/go/core/internal/httpserver/handlers/memory_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/types" - database_fake "github.com/kagent-dev/kagent/go/core/internal/database/fake" "github.com/kagent-dev/kagent/go/core/internal/httpserver/auth" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" ) @@ -27,10 +26,10 @@ func makeVector(n int, val float32) []float32 { } func TestMemoryHandler(t *testing.T) { - setupHandler := func() (*handlers.MemoryHandler, *mockErrorResponseWriter) { + setupHandler := func(t *testing.T) (*handlers.MemoryHandler, *mockErrorResponseWriter) { base := &handlers.Base{ DefaultModelConfig: types.NamespacedName{Namespace: "default", Name: "default"}, - DatabaseService: database_fake.NewClient(), + DatabaseService: setupTestDBClient(t), Authorizer: &auth.NoopAuthorizer{}, } handler := handlers.NewMemoryHandler(base) @@ -40,7 +39,7 @@ func TestMemoryHandler(t *testing.T) { t.Run("AddSession", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.AddSessionMemoryRequest{ AgentName: "test-agent", @@ -64,7 +63,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("MissingFields", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.AddSessionMemoryRequest{UserID: "user123", Vector: makeVector(768, 0.1)} jsonBody, _ := json.Marshal(reqBody) @@ -77,7 +76,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("WrongVectorDimension", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.AddSessionMemoryRequest{ AgentName: "test-agent", @@ -96,7 +95,7 @@ func TestMemoryHandler(t *testing.T) { t.Run("AddSessionBatch", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.AddSessionMemoryBatchRequest{ Items: []handlers.AddSessionMemoryRequest{ @@ -118,7 +117,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("EmptyBatch", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.AddSessionMemoryBatchRequest{Items: []handlers.AddSessionMemoryRequest{}} jsonBody, _ := json.Marshal(reqBody) @@ -131,7 +130,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("BatchTooLarge", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) items := make([]handlers.AddSessionMemoryRequest, 51) for i := range items { @@ -149,7 +148,7 @@ func TestMemoryHandler(t *testing.T) { t.Run("Search", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.SearchSessionMemoryRequest{ AgentName: "test-agent", @@ -170,7 +169,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("MissingFields", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) reqBody := handlers.SearchSessionMemoryRequest{AgentName: "test-agent", Vector: makeVector(768, 0.1)} jsonBody, _ := json.Marshal(reqBody) @@ -185,7 +184,7 @@ func TestMemoryHandler(t *testing.T) { t.Run("List", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) req := httptest.NewRequest("GET", "/api/memories?agent_name=test-agent&user_id=user123", nil) req = setUser(req, "test-user") @@ -198,7 +197,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("MissingFields", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) req := httptest.NewRequest("GET", "/api/memories?agent_name=test-agent", nil) req = setUser(req, "test-user") @@ -211,7 +210,7 @@ func TestMemoryHandler(t *testing.T) { t.Run("Delete", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) req := httptest.NewRequest("DELETE", "/api/memories?agent_name=test-agent&user_id=user123", nil) req = setUser(req, "test-user") @@ -225,7 +224,7 @@ func TestMemoryHandler(t *testing.T) { }) t.Run("MissingFields", func(t *testing.T) { - handler, responseRecorder := setupHandler() + handler, responseRecorder := setupHandler(t) req := httptest.NewRequest("DELETE", "/api/memories?agent_name=test-agent", nil) req = setUser(req, "test-user") diff --git a/go/core/internal/httpserver/handlers/sessions_test.go b/go/core/internal/httpserver/handlers/sessions_test.go index cb265f10a..517ea3682 100644 --- a/go/core/internal/httpserver/handlers/sessions_test.go +++ b/go/core/internal/httpserver/handlers/sessions_test.go @@ -19,12 +19,12 @@ import ( "github.com/kagent-dev/kagent/go/api/database" api "github.com/kagent-dev/kagent/go/api/httpapi" "github.com/kagent-dev/kagent/go/api/v1alpha2" - database_fake "github.com/kagent-dev/kagent/go/core/internal/database/fake" authimpl "github.com/kagent-dev/kagent/go/core/internal/httpserver/auth" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" "github.com/kagent-dev/kagent/go/core/internal/utils" "github.com/kagent-dev/kagent/go/core/pkg/auth" "github.com/kagent-dev/kmcp/api/v1alpha1" + "trpc.group/trpc-go/trpc-a2a-go/protocol" ) func setUser(req *http.Request, userID string) *http.Request { @@ -43,9 +43,9 @@ func TestSessionsHandler(t *testing.T) { err := v1alpha1.AddToScheme(scheme) require.NoError(t, err) - setupHandler := func() (*handlers.SessionsHandler, *database_fake.InMemoryFakeClient, *mockErrorResponseWriter) { + setupHandler := func(t *testing.T) (*handlers.SessionsHandler, database.Client, *mockErrorResponseWriter) { kubeClient := fake.NewClientBuilder().WithScheme(scheme).Build() - dbClient := database_fake.NewClient() + dbClient := setupTestDBClient(t) base := &handlers.Base{ KubeClient: kubeClient, @@ -54,40 +54,40 @@ func TestSessionsHandler(t *testing.T) { } handler := handlers.NewSessionsHandler(base) responseRecorder := newMockErrorResponseWriter() - return handler, dbClient.(*database_fake.InMemoryFakeClient), responseRecorder + return handler, dbClient, responseRecorder } - createTestAgent := func(dbClient database.Client, agentRef string) *database.Agent { + createTestAgent := func(t *testing.T, dbClient database.Client, agentRef string) *database.Agent { + t.Helper() agent := &database.Agent{ ID: agentRef, WorkloadType: v1alpha2.WorkloadModeDeployment, } - dbClient.StoreAgent(context.Background(), agent) //nolint:errcheck - // The fake client should assign an ID, but we'll use a default for testing - agent.ID = "1" // Simulate the ID that would be assigned by GORM + require.NoError(t, dbClient.StoreAgent(context.Background(), agent)) return agent } - createTestSession := func(dbClient database.Client, sessionID, userID string, agentID string) *database.Session { + createTestSession := func(t *testing.T, dbClient database.Client, sessionID, userID string, agentID string) *database.Session { + t.Helper() session := &database.Session{ ID: sessionID, Name: new(sessionID), UserID: userID, AgentID: &agentID, } - dbClient.StoreSession(context.Background(), session) //nolint:errcheck + require.NoError(t, dbClient.StoreSession(context.Background(), session)) return session } t.Run("HandleListSessions", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" // Create test sessions agentID := "1" - session1 := createTestSession(dbClient, "session-1", userID, agentID) - session2 := createTestSession(dbClient, "session-2", userID, agentID) + session1 := createTestSession(t, dbClient, "session-1", userID, agentID) + session2 := createTestSession(t, dbClient, "session-2", userID, agentID) req := httptest.NewRequest("GET", "/api/sessions?user_id="+userID, nil) req = setUser(req, userID) @@ -104,7 +104,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("MissingUserID", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) req := httptest.NewRequest("GET", "/api/sessions", nil) handler.HandleListSessions(responseRecorder, req) @@ -116,12 +116,12 @@ func TestSessionsHandler(t *testing.T) { t.Run("HandleCreateSession", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" agentRef := utils.ConvertToPythonIdentifier("default/test-agent") // Create test agent - createTestAgent(dbClient, agentRef) + createTestAgent(t, dbClient, agentRef) sessionReq := api.SessionRequest{ AgentRef: &agentRef, @@ -146,7 +146,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("MissingUserID", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) agentRef := utils.ConvertToPythonIdentifier("default/test-agent") sessionReq := api.SessionRequest{ @@ -164,7 +164,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("MissingAgentRef", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) userID := "test-user" sessionReq := api.SessionRequest{} @@ -182,7 +182,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("AgentNotFound", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) agentRef := utils.ConvertToPythonIdentifier("default/non-existent-agent") sessionReq := api.SessionRequest{ @@ -201,7 +201,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("InvalidJSON", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) req := httptest.NewRequest("POST", "/api/sessions", bytes.NewBufferString("invalid json")) req.Header.Set("Content-Type", "application/json") @@ -213,7 +213,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("SandboxAgentAllowsOnlyOneSessionGlobally", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" agentRef := utils.ConvertToPythonIdentifier("default/test-sandbox-agent") @@ -223,7 +223,7 @@ func TestSessionsHandler(t *testing.T) { })) existingAgentID := agentRef - createTestSession(dbClient, "existing-session", "other-user", existingAgentID) + createTestSession(t, dbClient, "existing-session", "other-user", existingAgentID) sessionReq := api.SessionRequest{ AgentRef: &agentRef, @@ -244,13 +244,13 @@ func TestSessionsHandler(t *testing.T) { t.Run("HandleGetSession", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" sessionID := "test-session" // Create test session agentID := "1" - session := createTestSession(dbClient, sessionID, userID, agentID) + session := createTestSession(t, dbClient, sessionID, userID, agentID) req := httptest.NewRequest("GET", "/api/sessions/"+sessionID, nil) req = mux.SetURLVars(req, map[string]string{"session_id": sessionID}) @@ -268,7 +268,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("SessionNotFound", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) userID := "test-user" sessionID := "non-existent-session" @@ -283,7 +283,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("MissingUserID", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) sessionID := "test-session" req := httptest.NewRequest("GET", "/api/sessions/"+sessionID, nil) @@ -296,13 +296,13 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("OrderAsc", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" sessionID := "test-session" // Create test session agentID := "1" - createTestSession(dbClient, sessionID, userID, agentID) + createTestSession(t, dbClient, sessionID, userID, agentID) // Create events with different timestamps event1 := &database.Event{ @@ -340,17 +340,17 @@ func TestSessionsHandler(t *testing.T) { t.Run("HandleUpdateSession", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" sessionName := "test-session" // Create test agent and session agentRef := utils.ConvertToPythonIdentifier("default/test-agent") - agent := createTestAgent(dbClient, agentRef) - session := createTestSession(dbClient, sessionName, userID, agent.ID) + agent := createTestAgent(t, dbClient, agentRef) + session := createTestSession(t, dbClient, sessionName, userID, agent.ID) newAgentRef := utils.ConvertToPythonIdentifier("default/new-agent") - newAgent := createTestAgent(dbClient, newAgentRef) + newAgent := createTestAgent(t, dbClient, newAgentRef) sessionReq := api.SessionRequest{ Name: &sessionName, @@ -374,7 +374,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("MissingSessionName", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) userID := "test-user" agentRef := "default/test-agent" @@ -394,12 +394,12 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("SessionNotFound", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" sessionName := "non-existent-session" agentRef := "default/test-agent" - createTestAgent(dbClient, agentRef) + createTestAgent(t, dbClient, agentRef) sessionReq := api.SessionRequest{ Name: &sessionName, @@ -420,7 +420,7 @@ func TestSessionsHandler(t *testing.T) { t.Run("HandleDeleteSession", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" sessionID := "test-session" @@ -430,7 +430,7 @@ func TestSessionsHandler(t *testing.T) { Type: "Declarative", })) agentID := "1" - createTestSession(dbClient, sessionID, userID, agentID) + createTestSession(t, dbClient, sessionID, userID, agentID) req := httptest.NewRequest("DELETE", "/api/sessions/"+sessionID, nil) req = mux.SetURLVars(req, map[string]string{"session_id": sessionID}) @@ -447,7 +447,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("MissingUserID", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) sessionID := "test-session" req := httptest.NewRequest("DELETE", "/api/sessions/"+sessionID, nil) @@ -462,16 +462,16 @@ func TestSessionsHandler(t *testing.T) { t.Run("HandleGetSessionsForAgent", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" namespace := "default" agentName := "test-agent" agentRef := utils.ConvertToPythonIdentifier(namespace + "/" + agentName) // Create test agent and sessions - agent := createTestAgent(dbClient, agentRef) - session1 := createTestSession(dbClient, "session-1", userID, agent.ID) - session2 := createTestSession(dbClient, "session-2", userID, agent.ID) + agent := createTestAgent(t, dbClient, agentRef) + session1 := createTestSession(t, dbClient, "session-1", userID, agent.ID) + session2 := createTestSession(t, dbClient, "session-2", userID, agent.ID) req := httptest.NewRequest("GET", "/api/agents/"+namespace+"/"+agentName+"/sessions", nil) req = mux.SetURLVars(req, map[string]string{"namespace": namespace, "name": agentName}) @@ -490,7 +490,7 @@ func TestSessionsHandler(t *testing.T) { }) t.Run("AgentNotFound", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) userID := "test-user" namespace := "default" agentName := "non-existent-agent" @@ -508,27 +508,22 @@ func TestSessionsHandler(t *testing.T) { t.Run("HandleListTasksForSession", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, dbClient, responseRecorder := setupHandler() + handler, dbClient, responseRecorder := setupHandler(t) userID := "test-user" sessionID := "test-session" // Create test session and tasks agentID := "1" - createTestSession(dbClient, sessionID, userID, agentID) + createTestSession(t, dbClient, sessionID, userID, agentID) - task1 := &database.Task{ + require.NoError(t, dbClient.StoreTask(context.Background(), &protocol.Task{ ID: "task-1", - SessionID: sessionID, - Data: "{}", - } - task2 := &database.Task{ + ContextID: sessionID, + })) + require.NoError(t, dbClient.StoreTask(context.Background(), &protocol.Task{ ID: "task-2", - SessionID: sessionID, - Data: "{}", - } - // Use the fake client's AddTask method for testing - dbClient.AddTask(task1) - dbClient.AddTask(task2) + ContextID: sessionID, + })) req := httptest.NewRequest("GET", "/api/sessions/"+sessionID+"/tasks", nil) req = mux.SetURLVars(req, map[string]string{"session_id": sessionID}) @@ -538,14 +533,14 @@ func TestSessionsHandler(t *testing.T) { assert.Equal(t, http.StatusOK, responseRecorder.Code) - var response api.StandardResponse[[]*database.Task] + var response api.StandardResponse[[]*protocol.Task] err := json.Unmarshal(responseRecorder.Body.Bytes(), &response) require.NoError(t, err) assert.Len(t, response.Data, 2) }) t.Run("MissingUserID", func(t *testing.T) { - handler, _, responseRecorder := setupHandler() + handler, _, responseRecorder := setupHandler(t) sessionID := "test-session" req := httptest.NewRequest("GET", "/api/sessions/"+sessionID+"/tasks", nil) diff --git a/go/core/internal/httpserver/handlers/toolservers_test.go b/go/core/internal/httpserver/handlers/toolservers_test.go index 7efd6735f..79d0caadc 100644 --- a/go/core/internal/httpserver/handlers/toolservers_test.go +++ b/go/core/internal/httpserver/handlers/toolservers_test.go @@ -24,7 +24,6 @@ import ( "github.com/kagent-dev/kagent/go/api/database" api "github.com/kagent-dev/kagent/go/api/httpapi" "github.com/kagent-dev/kagent/go/api/v1alpha2" - database_fake "github.com/kagent-dev/kagent/go/core/internal/database/fake" "github.com/kagent-dev/kagent/go/core/internal/httpserver/auth" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" common "github.com/kagent-dev/kagent/go/core/internal/utils" @@ -41,7 +40,7 @@ func TestToolServersHandler(t *testing.T) { err = corev1.AddToScheme(scheme) require.NoError(t, err) - setupHandler := func() (*handlers.ToolServersHandler, ctrl_client.Client, *database_fake.InMemoryFakeClient, *mockErrorResponseWriter) { + setupHandler := func(t *testing.T) (*handlers.ToolServersHandler, ctrl_client.Client, database.Client, *mockErrorResponseWriter) { // Create a RESTMapper that knows about the MCPServer type restMapper := meta.NewDefaultRESTMapper([]schema.GroupVersion{v1alpha1.GroupVersion}) restMapper.Add(schema.GroupVersionKind{ @@ -54,7 +53,7 @@ func TestToolServersHandler(t *testing.T) { WithScheme(scheme). WithRESTMapper(restMapper). Build() - dbClient := database_fake.NewClient() + dbClient := setupTestDBClient(t) base := &handlers.Base{ KubeClient: kubeClient, DefaultModelConfig: types.NamespacedName{Namespace: "default", Name: "default"}, @@ -65,12 +64,12 @@ func TestToolServersHandler(t *testing.T) { _ = handlers.NewToolServerTypesHandler(base) handler := handlers.NewToolServersHandler(base) responseRecorder := newMockErrorResponseWriter() - return handler, kubeClient, dbClient.(*database_fake.InMemoryFakeClient), responseRecorder + return handler, kubeClient, dbClient, responseRecorder } t.Run("HandleListToolServers", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, _, dbClient, responseRecorder := setupHandler() + handler, _, dbClient, responseRecorder := setupHandler(t) // Create test tool servers in database toolServer1 := &database.ToolServer{ @@ -90,14 +89,12 @@ func TestToolServersHandler(t *testing.T) { _, err = dbClient.StoreToolServer(context.Background(), toolServer2) require.NoError(t, err) - // Create test tools in database - tool1 := &database.Tool{ - ID: "test-tool", - ServerName: "default/test-toolserver-1", - GroupKind: "kagent.dev/RemoteMCPServer", - Description: "Test tool", - } - err = dbClient.CreateTool(context.Background(), tool1) + err = dbClient.RefreshToolsForServer(context.Background(), "default/test-toolserver-1", "kagent.dev/RemoteMCPServer", + &v1alpha2.MCPTool{ + Name: "test-tool", + Description: "Test tool", + }, + ) require.NoError(t, err) req := httptest.NewRequest("GET", "/api/toolservers/", nil) @@ -123,7 +120,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("EmptyList", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) req := httptest.NewRequest("GET", "/api/toolservers/", nil) req = setUser(req, "test-user") @@ -140,7 +137,7 @@ func TestToolServersHandler(t *testing.T) { t.Run("HandleCreateToolServer", func(t *testing.T) { t.Run("Success_RemoteMCPServer_StreamableHttp", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "RemoteMCPServer", @@ -186,7 +183,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("Success_RemoteMCPServer_Sse", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "RemoteMCPServer", @@ -234,7 +231,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("Success_MCPServer_Stdio", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "MCPServer", @@ -279,7 +276,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("Success_DefaultNamespace", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "RemoteMCPServer", @@ -312,7 +309,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("InvalidType", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "InvalidType", @@ -330,7 +327,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("MissingRemoteMCPServerData", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "RemoteMCPServer", @@ -349,7 +346,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("MissingMCPServerData", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) reqBody := &handlers.ToolServerCreateRequest{ Type: "MCPServer", @@ -368,7 +365,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("InvalidJSON", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) req := httptest.NewRequest("POST", "/api/toolservers/", bytes.NewBufferString("invalid json")) req.Header.Set("Content-Type", "application/json") @@ -381,7 +378,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("ToolServerAlreadyExists", func(t *testing.T) { - handler, kubeClient, _, responseRecorder := setupHandler() + handler, kubeClient, _, responseRecorder := setupHandler(t) // Create existing tool server existingToolServer := &v1alpha2.RemoteMCPServer{ @@ -425,7 +422,7 @@ func TestToolServersHandler(t *testing.T) { t.Run("HandleDeleteToolServer", func(t *testing.T) { t.Run("Success", func(t *testing.T) { - handler, kubeClient, dbClient, responseRecorder := setupHandler() + handler, kubeClient, dbClient, responseRecorder := setupHandler(t) // Create tool server to delete toolServer := &v1alpha2.RemoteMCPServer{ @@ -462,7 +459,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("NotFound", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) req := httptest.NewRequest("DELETE", "/api/toolservers/default/nonexistent", nil) req = setUser(req, "test-user") @@ -479,7 +476,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("MissingNamespaceParam", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) // Request without namespace param should fail req := httptest.NewRequest("DELETE", "/api/toolservers/", nil) @@ -491,7 +488,7 @@ func TestToolServersHandler(t *testing.T) { }) t.Run("MissingToolServerNameParam", func(t *testing.T) { - handler, _, _, responseRecorder := setupHandler() + handler, _, _, responseRecorder := setupHandler(t) req := httptest.NewRequest("DELETE", "/api/toolservers/default/", nil) req = mux.SetURLVars(req, map[string]string{