From 41f63caa08c03cf506fa69f0cdaa8b1a57de95de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:05:47 +0700 Subject: [PATCH 01/16] refactor(ai-chat): replace animated focus border with native macOS focus stroke --- CHANGELOG.md | 6 ++ TablePro/Views/AIChat/ChatComposerView.swift | 70 ++------------------ 2 files changed, 11 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c4ec1dc9..0fabde150 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- AI Chat: composer focus ring uses the standard macOS accent stroke instead of a colored gradient +- AI inline suggestions: debounce now uses structured Swift concurrency, and the delay is configurable via the `inlineSuggestionDebounceMs` setting (default 500ms) +- Copilot LSP shutdown caps at 10 seconds, closes pipes explicitly, and strips the quarantine attribute from the downloaded binary + ## [0.39.1] - 2026-05-08 ### Added diff --git a/TablePro/Views/AIChat/ChatComposerView.swift b/TablePro/Views/AIChat/ChatComposerView.swift index f06568f57..a191065a4 100644 --- a/TablePro/Views/AIChat/ChatComposerView.swift +++ b/TablePro/Views/AIChat/ChatComposerView.swift @@ -55,15 +55,12 @@ struct ChatComposerView: View { return shape .fill(Color(nsColor: .textBackgroundColor)) .overlay { - if isFocused { - IntelligenceFocusBorder(shape: shape) - .transition(.opacity) - } else { - shape.stroke(Color(nsColor: .separatorColor), lineWidth: 0.5) - .transition(.opacity) - } + shape.stroke( + isFocused ? Color.accentColor : Color(nsColor: .separatorColor), + lineWidth: isFocused ? 1 : 0.5 + ) } - .animation(.easeOut(duration: 0.25), value: isFocused) + .animation(.default, value: isFocused) } private var popoverBinding: Binding { @@ -111,60 +108,3 @@ struct ChatComposerView: View { mentionState.reset() } } - -private enum IntelligenceShimmer { - static let palette: [Color] = [ - Color(red: 1.0, green: 0.404, blue: 0.471), - Color(red: 1.0, green: 0.553, blue: 0.443), - Color(red: 1.0, green: 0.729, blue: 0.443), - Color(red: 0.961, green: 0.725, blue: 0.918), - Color(red: 0.776, green: 0.525, blue: 1.0), - Color(red: 0.737, green: 0.510, blue: 0.953), - Color(red: 0.553, green: 0.624, blue: 1.0) - ] - - struct Layer: Identifiable { - let id: Int - let lineWidth: CGFloat - let blur: CGFloat - let opacity: Double - } - - static let layers: [Layer] = [ - Layer(id: 0, lineWidth: 1.5, blur: 2, opacity: 1.0), - Layer(id: 1, lineWidth: 5, blur: 4, opacity: 0.75), - Layer(id: 2, lineWidth: 9, blur: 10, opacity: 0.5), - Layer(id: 3, lineWidth: 14, blur: 16, opacity: 0.35) - ] - - static func generateStops() -> [Gradient.Stop] { - let count = palette.count - var stops = palette.enumerated().map { index, color in - Gradient.Stop(color: color, location: Double(index) / Double(count)) - } - if let first = palette.first { - stops.append(Gradient.Stop(color: first, location: 1.0)) - } - return stops - } -} - -private struct IntelligenceFocusBorder: View { - let shape: S - - @State private var stops: [Gradient.Stop] = IntelligenceShimmer.generateStops() - - var body: some View { - ZStack { - ForEach(IntelligenceShimmer.layers) { layer in - shape - .stroke( - AngularGradient(gradient: Gradient(stops: stops), center: .center), - lineWidth: layer.lineWidth - ) - .blur(radius: layer.blur) - .opacity(layer.opacity) - } - } - } -} From 6730b6c4b99c40f1d94e51d8c08757f46604a728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:19:08 +0700 Subject: [PATCH 02/16] refactor(ai-chat): decompose AIChatViewModel into responsibility extensions --- CHANGELOG.md | 1 + TablePro/Core/AI/AIProvider.swift | 12 +- TablePro/Core/AI/AnthropicProvider.swift | 22 +- TablePro/Core/AI/JSONValue+Encoding.swift | 22 + .../Core/AI/OpenAICompatibleProvider.swift | 44 +- TablePro/Core/AI/String+AIEndpoint.swift | 12 + .../AIChatViewModel+MessageEditing.swift | 106 ++ .../AIChatViewModel+Persistence.swift | 71 + .../AIChatViewModel+SchemaContext.swift | 178 +++ .../AIChatViewModel+SlashCommands.swift | 99 ++ .../AIChatViewModel+Streaming.swift | 467 +++++++ .../AIChatViewModel+ToolApproval.swift | 17 + TablePro/ViewModels/AIChatViewModel.swift | 1160 ++--------------- TablePro/Views/AIChat/AIChatPanelView.swift | 2 +- 14 files changed, 1124 insertions(+), 1089 deletions(-) create mode 100644 TablePro/Core/AI/JSONValue+Encoding.swift create mode 100644 TablePro/Core/AI/String+AIEndpoint.swift create mode 100644 TablePro/ViewModels/AIChatViewModel+MessageEditing.swift create mode 100644 TablePro/ViewModels/AIChatViewModel+Persistence.swift create mode 100644 TablePro/ViewModels/AIChatViewModel+SchemaContext.swift create mode 100644 TablePro/ViewModels/AIChatViewModel+SlashCommands.swift create mode 100644 TablePro/ViewModels/AIChatViewModel+Streaming.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fabde150..c886bbe55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - AI Chat: composer focus ring uses the standard macOS accent stroke instead of a colored gradient - AI inline suggestions: debounce now uses structured Swift concurrency, and the delay is configurable via the `inlineSuggestionDebounceMs` setting (default 500ms) - Copilot LSP shutdown caps at 10 seconds, closes pipes explicitly, and strips the quarantine attribute from the downloaded binary +- AI Chat: streaming view model split into focused extensions backed by a single `streamingState` enum ## [0.39.1] - 2026-05-08 diff --git a/TablePro/Core/AI/AIProvider.swift b/TablePro/Core/AI/AIProvider.swift index da14fca62..baeb49991 100644 --- a/TablePro/Core/AI/AIProvider.swift +++ b/TablePro/Core/AI/AIProvider.swift @@ -5,6 +5,10 @@ import Foundation +enum AIProvider { + static let modelListTimeout: TimeInterval = 5.0 +} + enum AIProviderError: Error, LocalizedError { case invalidEndpoint(String) case authenticationFailed(String) @@ -36,11 +40,17 @@ enum AIProviderError: Error, LocalizedError { } } - static func mapHTTPError(statusCode: Int, body: String) -> AIProviderError { + static func mapHTTPError( + statusCode: Int, + body: String, + treatForbiddenAsAuthFailure: Bool = false + ) -> AIProviderError { let message = parseErrorMessage(from: body) ?? body switch statusCode { case 401: return .authenticationFailed(message) + case 403 where treatForbiddenAsAuthFailure: + return .authenticationFailed(message) case 429: return .rateLimited case 404: diff --git a/TablePro/Core/AI/AnthropicProvider.swift b/TablePro/Core/AI/AnthropicProvider.swift index 571f19510..8a9be6f9c 100644 --- a/TablePro/Core/AI/AnthropicProvider.swift +++ b/TablePro/Core/AI/AnthropicProvider.swift @@ -15,7 +15,7 @@ final class AnthropicProvider: ChatTransport { private let session: URLSession init(endpoint: String, apiKey: String, maxOutputTokens: Int = 4_096) { - self.endpoint = endpoint.hasSuffix("/") ? String(endpoint.dropLast()) : endpoint + self.endpoint = endpoint.normalizedEndpoint() self.apiKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines) self.maxOutputTokens = maxOutputTokens self.session = URLSession(configuration: .ephemeral) @@ -73,16 +73,25 @@ final class AnthropicProvider: ChatTransport { var request = URLRequest(url: url) request.httpMethod = "GET" + request.timeoutInterval = AIProvider.modelListTimeout request.setValue(apiKey, forHTTPHeaderField: "x-api-key") request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version") - let (data, response) = try await session.data(for: request) + let data: Data + let response: URLResponse + do { + (data, response) = try await session.data(for: request) + } catch { + Self.logger.warning("Anthropic model fetch failed; using known models: \(error.localizedDescription, privacy: .public)") + return Self.knownModels + } guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200, let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], let models = json["data"] as? [[String: Any]] else { + Self.logger.warning("Anthropic model fetch returned unexpected response; using known models") return Self.knownModels } @@ -240,7 +249,7 @@ final class AnthropicProvider: ChatTransport { [ "name": spec.name, "description": spec.description, - "input_schema": try jsonObject(from: spec.inputSchema) + "input_schema": try spec.inputSchema.asJSONObject() ] } @@ -276,7 +285,7 @@ final class AnthropicProvider: ChatTransport { "type": "tool_use", "id": toolUse.id, "name": toolUse.name, - "input": try jsonObject(from: toolUse.input) + "input": try toolUse.input.asJSONObject() ] case .toolResult(let result): var encoded: [String: Any] = [ @@ -292,11 +301,6 @@ final class AnthropicProvider: ChatTransport { return nil } } - - static func jsonObject(from value: JSONValue) throws -> Any { - let data = try JSONEncoder().encode(value) - return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) - } } /// Mutable state carried across `AnthropicProvider.parseChunk` calls. diff --git a/TablePro/Core/AI/JSONValue+Encoding.swift b/TablePro/Core/AI/JSONValue+Encoding.swift new file mode 100644 index 000000000..6e3f4570b --- /dev/null +++ b/TablePro/Core/AI/JSONValue+Encoding.swift @@ -0,0 +1,22 @@ +// +// JSONValue+Encoding.swift +// TablePro +// + +import Foundation + +extension JSONValue { + func asJSONObject() throws -> Any { + let data = try JSONEncoder().encode(self) + return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) + } + + func asJSONString() -> String { + guard let data = try? JSONEncoder().encode(self), + let string = String(data: data, encoding: .utf8) + else { + return "{}" + } + return string + } +} diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index e2389d173..672422ef3 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -30,7 +30,7 @@ final class OpenAICompatibleProvider: ChatTransport { maxOutputTokens: Int? = nil, session: URLSession = URLSession(configuration: .ephemeral) ) { - self.endpoint = endpoint.hasSuffix("/") ? String(endpoint.dropLast()) : endpoint + self.endpoint = endpoint.normalizedEndpoint() self.apiKey = apiKey?.trimmingCharacters(in: .whitespacesAndNewlines) self.providerType = providerType self.model = model.trimmingCharacters(in: .whitespacesAndNewlines) @@ -375,7 +375,7 @@ final class OpenAICompatibleProvider: ChatTransport { "type": "function", "function": [ "name": block.name, - "arguments": jsonString(from: block.input) + "arguments": block.input.asJSONString() ] ] } @@ -407,7 +407,7 @@ final class OpenAICompatibleProvider: ChatTransport { } func encodeTool(_ tool: ChatToolSpec) throws -> [String: Any] { - let parameters = try jsonObject(from: tool.inputSchema) + let parameters = try tool.inputSchema.asJSONObject() return [ "type": "function", "function": [ @@ -418,26 +418,13 @@ final class OpenAICompatibleProvider: ChatTransport { ] } - func jsonString(from value: JSONValue) -> String { - guard let data = try? JSONEncoder().encode(value), - let string = String(data: data, encoding: .utf8) - else { - return "{}" - } - return string - } - - func jsonObject(from value: JSONValue) throws -> Any { - let data = try JSONEncoder().encode(value) - return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) - } - private func fetchOpenAIModels() async throws -> [String] { guard let url = URL(string: "\(endpoint)/v1/models") else { throw AIProviderError.invalidEndpoint(endpoint) } var request = URLRequest(url: url) + request.timeoutInterval = AIProvider.modelListTimeout if let apiKey, !apiKey.isEmpty { request.setValue( "Bearer \(apiKey)", @@ -445,7 +432,14 @@ final class OpenAICompatibleProvider: ChatTransport { ) } - let (data, response) = try await session.data(for: request) + let data: Data + let response: URLResponse + do { + (data, response) = try await session.data(for: request) + } catch { + Self.logger.warning("OpenAI-compatible model fetch failed: \(error.localizedDescription, privacy: .public)") + throw AIProviderError.networkError("Failed to fetch models") + } guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 @@ -468,8 +462,18 @@ final class OpenAICompatibleProvider: ChatTransport { throw AIProviderError.invalidEndpoint(endpoint) } - let request = URLRequest(url: url) - let (data, response) = try await session.data(for: request) + var request = URLRequest(url: url) + request.timeoutInterval = AIProvider.modelListTimeout + let data: Data + let response: URLResponse + do { + (data, response) = try await session.data(for: request) + } catch { + Self.logger.warning("Ollama model fetch failed: \(error.localizedDescription, privacy: .public)") + throw AIProviderError.networkError( + String(format: String(localized: "Failed to fetch models from %@"), endpoint) + ) + } guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 diff --git a/TablePro/Core/AI/String+AIEndpoint.swift b/TablePro/Core/AI/String+AIEndpoint.swift new file mode 100644 index 000000000..962edaec4 --- /dev/null +++ b/TablePro/Core/AI/String+AIEndpoint.swift @@ -0,0 +1,12 @@ +// +// String+AIEndpoint.swift +// TablePro +// + +import Foundation + +extension String { + func normalizedEndpoint() -> String { + hasSuffix("/") ? String(dropLast()) : self + } +} diff --git a/TablePro/ViewModels/AIChatViewModel+MessageEditing.swift b/TablePro/ViewModels/AIChatViewModel+MessageEditing.swift new file mode 100644 index 000000000..88b635d65 --- /dev/null +++ b/TablePro/ViewModels/AIChatViewModel+MessageEditing.swift @@ -0,0 +1,106 @@ +// +// AIChatViewModel+MessageEditing.swift +// TablePro +// + +import Foundation + +extension AIChatViewModel { + func editMessage(_ message: ChatTurn) { + guard message.role == .user, !isStreaming else { return } + guard let idx = messages.firstIndex(where: { $0.id == message.id }) else { return } + + inputText = message.plainText + attachedContext = message.blocks.compactMap { block in + if case .attachment(let item) = block { return item } + return nil + } + messages.removeSubrange(idx...) + persistCurrentConversation() + } + + func resolveTurnForWire(_ turn: ChatTurn) async -> ChatTurn { + let attachments = turn.blocks.compactMap { block -> ContextItem? in + if case .attachment(let item) = block { return item } + return nil + } + guard !attachments.isEmpty else { return turn } + + for item in attachments { + await primeAttachmentData(for: item) + } + + let typed = turn.blocks.compactMap { block -> String? in + if case .text(let value) = block { return value } + return nil + }.joined() + + let resolved = attachments + .compactMap { resolveAttachment($0) } + .joined(separator: "\n\n") + if resolved.isEmpty { return turn } + + let combined = typed.isEmpty ? resolved : typed + "\n\n---\n\n" + resolved + return ChatTurn( + id: turn.id, + role: turn.role, + blocks: [.text(combined)], + timestamp: turn.timestamp, + usage: turn.usage, + modelId: turn.modelId, + providerId: turn.providerId + ) + } + + func resolveAttachment(_ item: ContextItem) -> String? { + switch item { + case .schema: + return resolveSchemaAttachment() + case .table(_, let name): + return resolveTableAttachment(name: name) + case .currentQuery(let text): + let snapshot = text.isEmpty ? (currentQuery ?? "") : text + guard !snapshot.isEmpty else { return nil } + return "## Current Query\n```\n\(snapshot)\n```" + case .queryResult(let summary): + let snapshot = summary.isEmpty ? (queryResults ?? "") : summary + guard !snapshot.isEmpty else { return nil } + return "## Query Results\n\(snapshot)" + case .savedQuery(let id, let name): + return resolveSavedQueryAttachment(id: id, fallbackName: name) + case .file: + return nil + } + } + + private func resolveSavedQueryAttachment(id: UUID, fallbackName: String) -> String? { + guard let favorite = cachedSavedQueries[id] else { return nil } + let displayName = favorite.name.isEmpty ? fallbackName : favorite.name + let header = displayName.isEmpty + ? String(localized: "Saved Query") + : "\(String(localized: "Saved Query")): \(displayName)" + return "## \(header)\n```sql\n\(favorite.query)\n```" + } + + private func resolveSchemaAttachment() -> String? { + guard let section = renderedSchemaSection() else { return nil } + return "## Schema\n\(section)" + } + + private func resolveTableAttachment(name: String) -> String? { + let columns = columnsByTable[name] ?? [] + guard !columns.isEmpty else { return nil } + let foreignKeys = foreignKeysByTable[name] ?? [] + var lines: [String] = ["## Table \(name)"] + for column in columns { + lines.append("- \(column.name): \(column.dataType)") + } + if !foreignKeys.isEmpty { + lines.append("Foreign keys:") + for foreign in foreignKeys { + lines.append("- \(foreign.column) -> \(foreign.referencedTable).\(foreign.referencedColumn)") + } + } + return lines.joined(separator: "\n") + } +} diff --git a/TablePro/ViewModels/AIChatViewModel+Persistence.swift b/TablePro/ViewModels/AIChatViewModel+Persistence.swift new file mode 100644 index 000000000..01eeb7b82 --- /dev/null +++ b/TablePro/ViewModels/AIChatViewModel+Persistence.swift @@ -0,0 +1,71 @@ +// +// AIChatViewModel+Persistence.swift +// TablePro +// + +import Foundation + +extension AIChatViewModel { + func loadConversations() { + let storage = chatStorage + Task.detached(priority: .utility) { [weak self] in + let loaded = await storage.loadAll() + await MainActor.run { + guard let self else { return } + self.conversations = loaded + if let mostRecent = loaded.first { + self.activeConversationID = mostRecent.id + self.messages = mostRecent.messages + } + } + } + } + + func clearConversation() { + cancelStream() + AIProviderFactory.resetCopilotConversation() + Task { await chatStorage.deleteAll() } + conversations.removeAll() + messages.removeAll() + activeConversationID = nil + clearError() + } + + func deleteConversation(_ id: UUID) { + if activeConversationID == id { + AIProviderFactory.resetCopilotConversation() + } + Task { await chatStorage.delete(id) } + conversations.removeAll { $0.id == id } + if activeConversationID == id { + activeConversationID = nil + messages.removeAll() + } + } + + func persistCurrentConversation() { + guard !messages.isEmpty else { return } + + if let existingID = activeConversationID, + var conversation = conversations.first(where: { $0.id == existingID }) { + conversation.messages = messages + conversation.updatedAt = Date() + conversation.updateTitle() + conversation.connectionName = connection?.name + Task { await chatStorage.save(conversation) } + + if let index = conversations.firstIndex(where: { $0.id == existingID }) { + conversations[index] = conversation + } + } else { + var conversation = AIConversation( + messages: messages, + connectionName: connection?.name + ) + conversation.updateTitle() + Task { await chatStorage.save(conversation) } + activeConversationID = conversation.id + conversations.insert(conversation, at: 0) + } + } +} diff --git a/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift b/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift new file mode 100644 index 000000000..f76b1982c --- /dev/null +++ b/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift @@ -0,0 +1,178 @@ +// +// AIChatViewModel+SchemaContext.swift +// TablePro +// + +import Foundation +import TableProPluginKit + +extension AIChatViewModel { + struct PromptContext: Sendable { + let databaseType: DatabaseType + let databaseName: String + let tables: [TableInfo] + let columnsByTable: [String: [ColumnInfo]] + let foreignKeys: [String: [ForeignKeyInfo]] + let currentQuery: String? + let queryResults: String? + let settings: AISettings + let identifierQuote: String + let editorLanguage: EditorLanguage + let queryLanguageName: String + let connectionRules: String? + } + + func ensureColumnsLoaded(forTable tableName: String) async { + if let existing = columnsByTable[tableName], !existing.isEmpty { return } + if let inFlight = inFlightColumnFetches[tableName] { + await inFlight.value + return + } + guard let connection, + let driver = DatabaseManager.shared.driver(for: connection.id) else { return } + let task: Task = Task { [weak self] in + let columns: [ColumnInfo] + do { + columns = try await driver.fetchColumns(table: tableName) + } catch { + Self.logger.warning("Column fetch failed for \(tableName, privacy: .public): \(error.localizedDescription, privacy: .public)") + columns = [] + } + let fkMap: [String: [ForeignKeyInfo]] + do { + fkMap = try await driver.fetchForeignKeys(forTables: [tableName]) + } catch { + Self.logger.warning("Foreign key fetch failed for \(tableName, privacy: .public): \(error.localizedDescription, privacy: .public)") + fkMap = [:] + } + guard !Task.isCancelled, let self else { return } + self.columnsByTable[tableName] = columns + if let fks = fkMap[tableName] { + self.foreignKeysByTable[tableName] = fks + } + self.inFlightColumnFetches[tableName] = nil + } + inFlightColumnFetches[tableName] = task + await task.value + } + + func ensureSchemaLoaded() async { + if let inFlight = inFlightSchemaLoad { + await inFlight.value + return + } + let task: Task = Task { [weak self] in + guard let self else { return } + await self.runSchemaLoad() + } + inFlightSchemaLoad = task + await task.value + inFlightSchemaLoad = nil + } + + func ensureSavedQueryLoaded(id: UUID) async { + if cachedSavedQueries[id] != nil { return } + if let favorite = await SQLFavoriteManager.shared.fetchFavorite(id: id) { + cachedSavedQueries[id] = favorite + } + } + + func primeAttachmentData(for item: ContextItem) async { + switch item { + case .schema: + await ensureSchemaLoaded() + case .table(_, let name): + await ensureColumnsLoaded(forTable: name) + case .savedQuery(let id, _): + await ensureSavedQueryLoaded(id: id) + case .currentQuery, .queryResult, .file: + break + } + } + + private func runSchemaLoad() async { + guard let connection, + let driver = DatabaseManager.shared.driver(for: connection.id) else { return } + let settings = AppSettingsManager.shared.ai + let tablesToFetch = Array(tables.prefix(settings.maxSchemaTables)) + guard !tablesToFetch.isEmpty else { return } + + await withTaskGroup(of: (String, [ColumnInfo]).self) { group in + for table in tablesToFetch where (columnsByTable[table.name] ?? []).isEmpty { + let name = table.name + group.addTask { + do { + let cols = try await driver.fetchColumns(table: name) + return (name, cols) + } catch { + Self.logger.warning("Schema column fetch failed for \(name, privacy: .public): \(error.localizedDescription, privacy: .public)") + return (name, []) + } + } + } + for await (name, cols) in group { + columnsByTable[name] = cols + } + } + + guard !Task.isCancelled else { return } + + let needsFKFetch = tablesToFetch.contains { foreignKeysByTable[$0.name] == nil } + guard needsFKFetch else { return } + do { + let fkMap = try await driver.fetchForeignKeys(forTables: tablesToFetch.map(\.name)) + for (name, fks) in fkMap { + foreignKeysByTable[name] = fks + } + } catch { + Self.logger.warning("Foreign key bulk fetch failed: \(error.localizedDescription, privacy: .public)") + } + } + + func capturePromptContext(settings: AISettings) -> PromptContext? { + guard let connection else { return nil } + return PromptContext( + databaseType: connection.type, + databaseName: DatabaseManager.shared.activeDatabaseName(for: connection), + tables: tables, + columnsByTable: columnsByTable, + foreignKeys: foreignKeysByTable, + currentQuery: settings.includeCurrentQuery ? currentQuery : nil, + queryResults: settings.includeQueryResults ? queryResults : nil, + settings: settings, + identifierQuote: PluginManager.shared.sqlDialect(for: connection.type)?.identifierQuote ?? "\"", + editorLanguage: PluginManager.shared.editorLanguage(for: connection.type), + queryLanguageName: PluginManager.shared.queryLanguageName(for: connection.type), + connectionRules: connection.aiRules + ) + } + + func resolveConnectionPolicy(settings: AISettings) -> AIConnectionPolicy? { + let policy = connection?.aiPolicy ?? settings.defaultConnectionPolicy + + if policy == .askEachTime { + if let connectionID = connection?.id, sessionApprovedConnections.contains(connectionID) { + return .alwaysAllow + } + return .askEachTime + } + + return policy + } + + func renderedSchemaSection() -> String? { + guard !tables.isEmpty else { return nil } + let settings = AppSettingsManager.shared.ai + let identifierQuote = connection.flatMap { + PluginManager.shared.sqlDialect(for: $0.type)?.identifierQuote + } ?? "\"" + let section = AISchemaContext.buildSchemaSection( + tables: tables, + columnsByTable: columnsByTable, + foreignKeys: foreignKeysByTable, + maxTables: settings.maxSchemaTables, + identifierQuote: identifierQuote + ) + return section.isEmpty ? nil : section + } +} diff --git a/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift b/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift new file mode 100644 index 000000000..bb887b156 --- /dev/null +++ b/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift @@ -0,0 +1,99 @@ +// +// AIChatViewModel+SlashCommands.swift +// TablePro +// + +import Foundation + +extension AIChatViewModel { + static let helpMarkdown: String = { + let lines = SlashCommand.allCommands + .map { "- `/\($0.name)` · \($0.description)" } + .joined(separator: "\n") + return String(localized: "**Available commands:**") + "\n\n" + lines + }() + + func runSlashCommand(_ command: SlashCommand, body: String = "") { + inputText = "" + clearError() + + let invocationText = body.isEmpty ? "/\(command.name)" : "/\(command.name) \(body)" + let databaseType = connection?.type ?? .mysql + + switch command { + case .help: + let helpMarkdown = Self.helpMarkdown + if let last = messages.last, last.role == .assistant, last.plainText == helpMarkdown { + return + } + messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) + messages.append(ChatTurn(role: .assistant, blocks: [.text(helpMarkdown)])) + case .explain: + guard let query = resolveQuery(body: body, command: command) else { return } + messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) + sendWithContext(prompt: AIPromptTemplates.explainQuery(query, databaseType: databaseType)) + case .optimize: + guard let query = resolveQuery(body: body, command: command) else { return } + messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) + sendWithContext(prompt: AIPromptTemplates.optimizeQuery(query, databaseType: databaseType)) + case .fix: + guard let query = resolveQuery(body: body, command: command) else { return } + messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) + let lastError = queryResults ?? "" + sendWithContext(prompt: AIPromptTemplates.fixError(query: query, error: lastError, databaseType: databaseType)) + } + } + + func runCustomSlashCommand(_ command: CustomSlashCommand, body: String = "") async { + guard command.isValid else { + Self.logger.warning("runCustomSlashCommand called with invalid command: name=\(command.name, privacy: .public)") + return + } + inputText = "" + clearError() + let invocationText = body.isEmpty ? "/\(command.name)" : "/\(command.name) \(body)" + let needsSchema = command.promptTemplate.contains(CustomSlashCommandVariable.schema.placeholder) + if needsSchema { + await ensureSchemaLoaded() + } + let renderingContext = CustomSlashCommandRenderer.Context( + query: currentQuery, + schema: needsSchema ? renderedSchemaSection() : nil, + database: connection.flatMap { DatabaseManager.shared.activeDatabaseName(for: $0) }, + body: body + ) + let prompt = CustomSlashCommandRenderer.render(command, context: renderingContext) + messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) + sendWithContext(prompt: prompt) + } + + func handleExplainSelection(_ selectedText: String) { + guard !selectedText.isEmpty else { return } + startNewConversation() + let databaseType = connection?.type ?? .mysql + let prompt = AIPromptTemplates.explainQuery(selectedText, databaseType: databaseType) + sendWithContext(prompt: prompt) + } + + func handleOptimizeSelection(_ selectedText: String) { + guard !selectedText.isEmpty else { return } + startNewConversation() + let databaseType = connection?.type ?? .mysql + let prompt = AIPromptTemplates.optimizeQuery(selectedText, databaseType: databaseType) + sendWithContext(prompt: prompt) + } + + private func resolveQuery(body: String, command: SlashCommand) -> String? { + if !body.isEmpty { + return body + } + if let editorQuery = currentQuery, !editorQuery.isEmpty { + return editorQuery + } + errorMessage = String( + format: String(localized: "/%@ needs a query: type one in the editor or after the command."), + command.name + ) + return nil + } +} diff --git a/TablePro/ViewModels/AIChatViewModel+Streaming.swift b/TablePro/ViewModels/AIChatViewModel+Streaming.swift new file mode 100644 index 000000000..e39fc9a06 --- /dev/null +++ b/TablePro/ViewModels/AIChatViewModel+Streaming.swift @@ -0,0 +1,467 @@ +// +// AIChatViewModel+Streaming.swift +// TablePro +// + +import Foundation +import TableProPluginKit + +extension AIChatViewModel { + static let maxToolRoundtrips = 10 + + struct ToolRoundtripContinuation { + let nextAssistantID: UUID + let assistantTurn: ChatTurn + let userTurn: ChatTurn + } + + private struct StreamRoundResult { + let toolUseOrder: [String] + let toolUseNames: [String: String] + let toolUseInputs: [String: String] + let cancelled: Bool + } + + func startStreaming() { + guard case .idle = streamingState else { return } + + let settings = AppSettingsManager.shared.ai + + let resolved = AIProviderFactory.resolve( + settings: settings, + overrideProviderId: selectedProviderId, + overrideModel: selectedModel + ) + guard let resolved else { + errorMessage = String(localized: "No AI provider configured. Go to Settings > AI to add one.") + return + } + + if connection != nil, let policy = resolveConnectionPolicy(settings: settings) { + if policy == .never { + errorMessage = String(localized: "AI is disabled for this connection.") + if let last = messages.last, last.role == .user { + messages.removeLast() + } + return + } + if policy == .askEachTime { + streamingState = .awaitingApproval + showAIAccessConfirmation = true + return + } + } + + let assistantMessage = ChatTurn( + role: .assistant, + blocks: [], + modelId: resolved.model, + providerId: resolved.config.id.uuidString + ) + messages.append(assistantMessage) + trimMessagesIfNeeded() + let assistantID = assistantMessage.id + streamingState = .streaming(assistantID: assistantID) + + prepTask = Task { [weak self] in + guard let self else { return } + if settings.includeSchema { + await self.ensureSchemaLoaded() + } + guard !Task.isCancelled else { return } + let promptContext = self.capturePromptContext(settings: settings) + var chatMessages: [ChatTurn] = [] + for turn in self.messages.dropLast() { + chatMessages.append(await self.resolveTurnForWire(turn)) + } + guard !Task.isCancelled else { return } + self.runStream( + chatMessages: chatMessages, + promptContext: promptContext, + resolved: resolved, + assistantID: assistantID, + settings: settings + ) + self.prepTask = nil + } + } + + func runStream( + chatMessages: [ChatTurn], + promptContext: PromptContext?, + resolved: AIProviderFactory.ResolvedProvider, + assistantID: UUID, + settings: AISettings + ) { + let chatMode = settings.chatMode + streamingTask = Task.detached(priority: .userInitiated) { [weak self] in + do { + let systemPrompt = Self.buildSystemPrompt(promptContext, mode: chatMode) + guard let self else { return } + let preflightOK = await self.preflightCheck( + systemPrompt: systemPrompt, + turns: chatMessages, + assistantID: assistantID + ) + guard preflightOK else { return } + + let toolSpecs = await MainActor.run { ChatToolRegistry.shared.allSpecs(for: chatMode) } + var workingTurns = chatMessages + var currentAssistantID = assistantID + + for roundtrip in 0.. StreamRoundResult { + let stream = resolved.provider.streamChat( + turns: workingTurns, + options: ChatTransportOptions( + model: resolved.model, + systemPrompt: systemPrompt, + tools: toolSpecs + ) + ) + + var pendingContent = "" + var pendingUsage: AITokenUsage? + var toolUseOrder: [String] = [] + var toolUseNames: [String: String] = [:] + var toolUseInputs: [String: String] = [:] + let flushInterval: ContinuousClock.Duration = .milliseconds(150) + var lastFlushTime: ContinuousClock.Instant = .now + + for try await event in stream { + guard !Task.isCancelled else { break } + switch event { + case .textDelta(let token): + pendingContent += token + case .usage(let usage): + pendingUsage = usage + case .toolUseStart(let id, let name): + if toolUseInputs[id] == nil { + toolUseOrder.append(id) + toolUseInputs[id] = "" + } + toolUseNames[id] = name + case .toolUseDelta(let id, let inputJSONDelta): + toolUseInputs[id, default: ""] += inputJSONDelta + case .toolUseEnd: + break + case .toolInvocationRequest(let block, let replyToken): + await self.dispatchCopilotInvocation( + block: block, replyToken: replyToken, + assistantID: assistantID, mode: chatMode + ) + } + + if ContinuousClock.now - lastFlushTime >= flushInterval { + await self.flushPending(content: pendingContent, usage: pendingUsage, into: assistantID) + pendingContent = "" + pendingUsage = nil + lastFlushTime = .now + } + } + + if !Task.isCancelled, !pendingContent.isEmpty || pendingUsage != nil { + await self.flushPending(content: pendingContent, usage: pendingUsage, into: assistantID) + } + + return StreamRoundResult( + toolUseOrder: toolUseOrder, + toolUseNames: toolUseNames, + toolUseInputs: toolUseInputs, + cancelled: Task.isCancelled + ) + } + + nonisolated static func buildSystemPrompt(_ promptContext: PromptContext?, mode: AIChatMode) -> String? { + let schemaPrompt = promptContext.map { + AISchemaContext.buildSystemPrompt( + databaseType: $0.databaseType, + databaseName: $0.databaseName, + tables: $0.tables, + columnsByTable: $0.columnsByTable, + foreignKeys: $0.foreignKeys, + currentQuery: $0.currentQuery, + queryResults: $0.queryResults, + settings: $0.settings, + identifierQuote: $0.identifierQuote, + editorLanguage: $0.editorLanguage, + queryLanguageName: $0.queryLanguageName, + connectionRules: $0.connectionRules + ) + } + let modeNote = mode.systemPromptNote + guard let schemaPrompt, !schemaPrompt.isEmpty else { return modeNote } + return "\(schemaPrompt)\n\n\(modeNote)" + } + + private func failTooManyRoundtrips(assistantID: UUID) async { + await MainActor.run { [weak self] in + guard let self else { return } + self.errorMessage = String( + localized: "AI made too many tool calls in one response. Try simplifying the request." + ) + if let idx = self.messages.firstIndex(where: { $0.id == assistantID }), + self.messages[idx].plainText.isEmpty { + self.messages.remove(at: idx) + } + self.streamingState = .failed(nil) + } + } + + func completeToolRoundtrip( + assistantIDForRound: UUID, + toolUseBlocks: [ToolUseBlock], + toolResultBlocks: [ToolResultBlock], + resolved: AIProviderFactory.ResolvedProvider + ) async -> ToolRoundtripContinuation { + await MainActor.run { [weak self] () -> ToolRoundtripContinuation in + let assistantText: String = { + guard let self, + let idx = self.messages.firstIndex(where: { $0.id == assistantIDForRound }) + else { return "" } + return self.messages[idx].plainText + }() + var assistantBlocks: [ChatContentBlock] = [] + if !assistantText.isEmpty { assistantBlocks.append(.text(assistantText)) } + assistantBlocks.append(contentsOf: toolUseBlocks.map { .toolUse($0) }) + let assistantTurn = ChatTurn( + id: assistantIDForRound, + role: .assistant, + blocks: assistantBlocks, + modelId: resolved.model, + providerId: resolved.config.id.uuidString + ) + let userTurn = ChatTurn( + role: .user, + blocks: toolResultBlocks.map { .toolResult($0) } + ) + let nextAssistant = ChatTurn( + role: .assistant, + blocks: [], + modelId: resolved.model, + providerId: resolved.config.id.uuidString + ) + self?.messages.append(userTurn) + self?.messages.append(nextAssistant) + self?.streamingState = .streaming(assistantID: nextAssistant.id) + return ToolRoundtripContinuation( + nextAssistantID: nextAssistant.id, + assistantTurn: assistantTurn, + userTurn: userTurn + ) + } + } + + func flushPending(content: String, usage: AITokenUsage?, into assistantID: UUID) async { + guard !content.isEmpty || usage != nil else { return } + await MainActor.run { [weak self] in + guard let self, + let idx = self.messages.firstIndex(where: { $0.id == assistantID }) + else { return } + if !content.isEmpty { + self.messages[idx].appendText(content) + } + if let usage { + self.messages[idx].usage = usage + } + } + } + + func preflightCheck(systemPrompt: String?, turns: [ChatTurn], assistantID: UUID) async -> Bool { + let totalSize = ((systemPrompt ?? "") as NSString).length + + turns.reduce(0) { $0 + ($1.plainText as NSString).length } + guard totalSize > 100_000 else { return true } + await MainActor.run { [weak self] in + guard let self else { return } + self.errorMessage = String( + localized: "Message too large. Try disabling 'Include schema' or 'Include query results' in AI settings." + ) + if let idx = self.messages.firstIndex(where: { $0.id == assistantID }) { + self.messages.remove(at: idx) + } + self.streamingState = .idle + } + return false + } + + nonisolated static func assembleToolUseBlocks( + order: [String], + names: [String: String], + inputs: [String: String] + ) -> [ToolUseBlock] { + order.compactMap { id -> ToolUseBlock? in + guard let name = names[id] else { return nil } + let inputString = inputs[id] ?? "{}" + let inputValue: JSONValue + if inputString.isEmpty { + inputValue = .object([:]) + } else if let data = inputString.data(using: .utf8), + let decoded = try? JSONDecoder().decode(JSONValue.self, from: data) { + inputValue = decoded + } else { + inputValue = .object([:]) + } + return ToolUseBlock(id: id, name: name, input: inputValue) + } + } + + nonisolated static func executeToolUses( + _ blocks: [ToolUseBlock], + mode: AIChatMode, + context: ChatToolContext, + registry: ChatToolRegistry? = nil + ) async -> [ToolResultBlock] { + await withTaskGroup(of: (Int, ToolResultBlock).self) { group in + for (index, block) in blocks.enumerated() { + group.addTask { + (index, await runToolUse(block, mode: mode, context: context, registry: registry)) + } + } + var indexed: [(Int, ToolResultBlock)] = [] + for await pair in group { indexed.append(pair) } + return indexed.sorted(by: { $0.0 < $1.0 }).map(\.1) + } + } + + nonisolated private static func runToolUse( + _ block: ToolUseBlock, + mode: AIChatMode, + context: ChatToolContext, + registry: ChatToolRegistry? + ) async -> ToolResultBlock { + if Task.isCancelled { + return ToolResultBlock(toolUseId: block.id, content: "Cancelled", isError: true) + } + guard ChatToolRegistry.isToolAllowed(name: block.name, in: mode) else { + logger.warning( + "Tool '\(block.name, privacy: .public)' blocked in \(mode.rawValue, privacy: .public) mode" + ) + return ToolResultBlock( + toolUseId: block.id, + content: "Tool '\(block.name)' is not available in \(mode.displayName) mode", + isError: true + ) + } + let tool = await MainActor.run { + (registry ?? ChatToolRegistry.shared).tool(named: block.name, in: mode) + } + guard let tool else { + logger.warning("Tool '\(block.name, privacy: .public)' not registered; returning error") + return ToolResultBlock( + toolUseId: block.id, + content: "Tool '\(block.name)' is not available", + isError: true + ) + } + do { + let result = try await tool.execute(input: block.input, context: context) + return ToolResultBlock( + toolUseId: block.id, + content: result.content, + isError: result.isError + ) + } catch { + logger.warning( + "Tool \(block.name, privacy: .public) execution failed: \(error.localizedDescription, privacy: .public)" + ) + return ToolResultBlock( + toolUseId: block.id, + content: "Error: \(error.localizedDescription)", + isError: true + ) + } + } +} diff --git a/TablePro/ViewModels/AIChatViewModel+ToolApproval.swift b/TablePro/ViewModels/AIChatViewModel+ToolApproval.swift index f208bbfde..f68fbe2ac 100644 --- a/TablePro/ViewModels/AIChatViewModel+ToolApproval.swift +++ b/TablePro/ViewModels/AIChatViewModel+ToolApproval.swift @@ -6,6 +6,23 @@ import Foundation extension AIChatViewModel { + func confirmAIAccess() { + if let connectionID = connection?.id { + sessionApprovedConnections.insert(connectionID) + } + guard case .awaitingApproval = streamingState else { return } + streamingState = .idle + startStreaming() + } + + func denyAIAccess() { + guard case .awaitingApproval = streamingState else { return } + streamingState = .idle + if let last = messages.last, last.role == .user { + messages.removeLast() + } + } + func resolveAndAwaitApprovals( assembledBlocks: [ToolUseBlock], assistantID: UUID diff --git a/TablePro/ViewModels/AIChatViewModel.swift b/TablePro/ViewModels/AIChatViewModel.swift index 6d43484fd..99c51ba3d 100644 --- a/TablePro/ViewModels/AIChatViewModel.swift +++ b/TablePro/ViewModels/AIChatViewModel.swift @@ -2,37 +2,28 @@ // AIChatViewModel.swift // TablePro // -// View model for AI chat panel - manages conversation, streaming, and provider resolution. -// import Foundation import Observation import os import TableProPluginKit -/// View model for the AI chat panel @MainActor @Observable final class AIChatViewModel { - private static let logger = Logger(subsystem: "com.TablePro", category: "AIChatViewModel") + static let logger = Logger(subsystem: "com.TablePro", category: "AIChatViewModel") - // MARK: - Published State + enum StreamingState { + case idle + case loading + case streaming(assistantID: UUID) + case awaitingApproval + case failed(AIProviderError?) + } var messages: [ChatTurn] = [] var inputText: String = "" - var isStreaming: Bool = false - var errorMessage: String? { - didSet { - if errorMessage == nil { - lastError = nil - } - } - } - var lastError: AIProviderError? - var lastMessageFailed: Bool = false - - var canRetryLastFailure: Bool { - lastError?.isRetryable ?? true - } + private(set) var streamingState: StreamingState = .idle + var errorMessage: String? var conversations: [AIConversation] = [] var activeConversationID: UUID? var showAIAccessConfirmation = false @@ -40,228 +31,54 @@ final class AIChatViewModel { var selectedModel: String? var availableModels: [UUID: [String]] = [:] var attachedContext: [ContextItem] = [] + var savedQueries: [SQLFavorite] = [] - // MARK: - Context Properties - - /// Current database connection (set by parent view) var connection: DatabaseConnection? - /// Tables for the current connection. Always derived live from `SchemaService`, - /// so reads stay in sync with schema reloads without any push-from-upstream plumbing. var tables: [TableInfo] { guard let id = connection?.id else { return [] } return SchemaService.shared.tables(for: id) } - /// Column info cache populated on-demand when chips are attached or - /// schema is auto-included. Keyed by table name within the active connection. var columnsByTable: [String: [ColumnInfo]] = [:] - - /// Foreign keys cache populated alongside columns. var foreignKeysByTable: [String: [ForeignKeyInfo]] = [:] - @ObservationIgnored private var inFlightColumnFetches: [String: Task] = [:] - @ObservationIgnored private var inFlightSchemaLoad: Task? - - /// Current query text from the active editor tab var currentQuery: String? - - /// Query results summary from the active tab var queryResults: String? - // MARK: - AI Action Dispatch - - func loadAvailableModels() async { - let settings = AppSettingsManager.shared.ai - let pending = settings.providers.filter { availableModels[$0.id] == nil } - guard !pending.isEmpty else { return } - - let results = await withTaskGroup(of: (UUID, [String]?).self) { group in - for config in pending { - let apiKey: String? - switch config.type.authStyle { - case .apiKey: - apiKey = AIKeyStorage.shared.loadAPIKey(for: config.id) - case .oauth, .none: - apiKey = nil - } - group.addTask { - let transport = await AIProviderFactory.createProvider(for: config, apiKey: apiKey) - do { - let models = try await transport.fetchAvailableModels() - return (config.id, models) - } catch is CancellationError { - return (config.id, nil) - } catch { - return (config.id, []) - } - } - } - - var collected: [(UUID, [String]?)] = [] - for await result in group { - collected.append(result) - } - return collected - } - - guard !Task.isCancelled else { return } - - for (id, models) in results { - guard let models else { continue } - if models.isEmpty { - let fallback = pending.first(where: { $0.id == id })?.model - availableModels[id] = (fallback?.isEmpty == false) ? [fallback ?? ""] : [] - } else { - availableModels[id] = models - } + var isStreaming: Bool { + switch streamingState { + case .loading, .streaming: + return true + case .idle, .awaitingApproval, .failed: + return false } } - func runSlashCommand(_ command: SlashCommand, body: String = "") { - inputText = "" - errorMessage = nil - - let invocationText = body.isEmpty ? "/\(command.name)" : "/\(command.name) \(body)" - let databaseType = connection?.type ?? .mysql - - switch command { - case .help: - let helpMarkdown = Self.helpMarkdown - if let last = messages.last, last.role == .assistant, last.plainText == helpMarkdown { - return - } - messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) - messages.append(ChatTurn(role: .assistant, blocks: [.text(helpMarkdown)])) - case .explain: - guard let query = resolveQuery(body: body, command: command) else { return } - messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) - sendWithContext(prompt: AIPromptTemplates.explainQuery(query, databaseType: databaseType)) - case .optimize: - guard let query = resolveQuery(body: body, command: command) else { return } - messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) - sendWithContext(prompt: AIPromptTemplates.optimizeQuery(query, databaseType: databaseType)) - case .fix: - guard let query = resolveQuery(body: body, command: command) else { return } - messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) - let lastError = queryResults ?? "" - sendWithContext(prompt: AIPromptTemplates.fixError(query: query, error: lastError, databaseType: databaseType)) - } - } - - func runCustomSlashCommand(_ command: CustomSlashCommand, body: String = "") async { - guard command.isValid else { - Self.logger.warning("runCustomSlashCommand called with invalid command: name=\(command.name, privacy: .public)") - return - } - inputText = "" - errorMessage = nil - let invocationText = body.isEmpty ? "/\(command.name)" : "/\(command.name) \(body)" - let needsSchema = command.promptTemplate.contains(CustomSlashCommandVariable.schema.placeholder) - if needsSchema { - await ensureSchemaLoaded() - } - let renderingContext = CustomSlashCommandRenderer.Context( - query: currentQuery, - schema: needsSchema ? renderedSchemaSection() : nil, - database: connection.flatMap { DatabaseManager.shared.activeDatabaseName(for: $0) }, - body: body - ) - let prompt = CustomSlashCommandRenderer.render(command, context: renderingContext) - messages.append(ChatTurn(role: .user, blocks: [.text(invocationText)])) - sendWithContext(prompt: prompt) - } - - private func renderedSchemaSection() -> String? { - guard !tables.isEmpty else { return nil } - let settings = AppSettingsManager.shared.ai - let identifierQuote = connection.flatMap { - PluginManager.shared.sqlDialect(for: $0.type)?.identifierQuote - } ?? "\"" - let section = AISchemaContext.buildSchemaSection( - tables: tables, - columnsByTable: columnsByTable, - foreignKeys: foreignKeysByTable, - maxTables: settings.maxSchemaTables, - identifierQuote: identifierQuote - ) - return section.isEmpty ? nil : section + var lastMessageFailed: Bool { + if case .failed = streamingState { return true } + return false } - private func resolveQuery(body: String, command: SlashCommand) -> String? { - if !body.isEmpty { - return body - } - if let editorQuery = currentQuery, !editorQuery.isEmpty { - return editorQuery - } - errorMessage = String( - format: String(localized: "/%@ needs a query: type one in the editor or after the command."), - command.name - ) + var lastError: AIProviderError? { + if case .failed(let error) = streamingState { return error } return nil } - private static let helpMarkdown: String = { - let lines = SlashCommand.allCommands - .map { "- `/\($0.name)` · \($0.description)" } - .joined(separator: "\n") - return String(localized: "**Available commands:**") + "\n\n" + lines - }() - - func handleFixError(query: String, error: String) { - startNewConversation() - let databaseType = connection?.type ?? .mysql - let prompt = AIPromptTemplates.fixError(query: query, error: error, databaseType: databaseType) - sendWithContext(prompt: prompt) - } - - func handleExplainSelection(_ selectedText: String) { - guard !selectedText.isEmpty else { return } - startNewConversation() - let databaseType = connection?.type ?? .mysql - let prompt = AIPromptTemplates.explainQuery(selectedText, databaseType: databaseType) - sendWithContext(prompt: prompt) - } - - func handleOptimizeSelection(_ selectedText: String) { - guard !selectedText.isEmpty else { return } - startNewConversation() - let databaseType = connection?.type ?? .mysql - let prompt = AIPromptTemplates.optimizeQuery(selectedText, databaseType: databaseType) - sendWithContext(prompt: prompt) - } - - func editMessage(_ message: ChatTurn) { - guard message.role == .user, !isStreaming else { return } - guard let idx = messages.firstIndex(where: { $0.id == message.id }) else { return } - - inputText = message.plainText - attachedContext = message.blocks.compactMap { block in - if case .attachment(let item) = block { return item } - return nil - } - messages.removeSubrange(idx...) - persistCurrentConversation() + var canRetryLastFailure: Bool { + lastError?.isRetryable ?? true } - // MARK: - Constants + @ObservationIgnored var inFlightColumnFetches: [String: Task] = [:] + @ObservationIgnored var inFlightSchemaLoad: Task? + @ObservationIgnored nonisolated(unsafe) var streamingTask: Task? + @ObservationIgnored var prepTask: Task? - /// Maximum number of messages to keep in memory to prevent unbounded growth - private static let maxMessageCount = 200 + let chatStorage = AIChatStorage.shared + var sessionApprovedConnections: Set = [] + @ObservationIgnored var cachedSavedQueries: [UUID: SQLFavorite] = [:] - // MARK: - Private - - /// nonisolated(unsafe) is required because deinit is not @MainActor-isolated, - /// so accessing a @MainActor property from deinit requires opting out of isolation. - @ObservationIgnored nonisolated(unsafe) private var streamingTask: Task? - @ObservationIgnored private var prepTask: Task? - private var streamingAssistantID: UUID? - private let chatStorage = AIChatStorage.shared - private var sessionApprovedConnections: Set = [] - private var pendingApproval: Bool = false - - // MARK: - Init + static let maxMessageCount = 200 init() { loadConversations() @@ -271,9 +88,6 @@ final class AIChatViewModel { streamingTask?.cancel() } - // MARK: - Actions - - /// Send the current input text as a user message func sendMessage() { let text = inputText.trimmingCharacters(in: .whitespacesAndNewlines) guard !text.isEmpty else { return } @@ -290,297 +104,45 @@ final class AIChatViewModel { trimMessagesIfNeeded() inputText = "" attachedContext = [] - errorMessage = nil + clearError() startStreaming() } + func sendWithContext(prompt: String) { + let userMessage = ChatTurn(role: .user, blocks: [.text(prompt)]) + messages.append(userMessage) + trimMessagesIfNeeded() + clearError() + startStreaming() + } + func attach(_ item: ContextItem) { guard !attachedContext.contains(where: { $0.stableKey == item.stableKey }) else { return } attachedContext.append(item) Task { await primeAttachmentData(for: item) } } - private func primeAttachmentData(for item: ContextItem) async { - switch item { - case .schema: - await ensureSchemaLoaded() - case .table(_, let name): - await ensureColumnsLoaded(forTable: name) - case .savedQuery(let id, _): - await ensureSavedQueryLoaded(id: id) - case .currentQuery, .queryResult, .file: - break - } - } - - /// Loaded `SQLFavorite` instances keyed by id, populated when saved-query - /// chips are attached so `resolveSavedQueryAttachment` can serialize them. - @ObservationIgnored private var cachedSavedQueries: [UUID: SQLFavorite] = [:] - - /// Saved queries available as `@`-mention candidates for the active connection. - /// Refreshed on connection change via `loadSavedQueries()`. - var savedQueries: [SQLFavorite] = [] - - func loadSavedQueries() async { - guard let connectionId = connection?.id else { - savedQueries = [] - return - } - let favorites = await SQLFavoriteManager.shared.fetchFavorites(connectionId: connectionId) - savedQueries = favorites - for favorite in favorites { - cachedSavedQueries[favorite.id] = favorite - } - } - - private func ensureSavedQueryLoaded(id: UUID) async { - if cachedSavedQueries[id] != nil { return } - if let favorite = await SQLFavoriteManager.shared.fetchFavorite(id: id) { - cachedSavedQueries[id] = favorite - } - } - - /// Ensure column + foreign-key data for `tableName` is in `columnsByTable`. - /// Idempotent and dedups concurrent calls so chip attach + send-time resolve - /// share a single fetch. - func ensureColumnsLoaded(forTable tableName: String) async { - if let existing = columnsByTable[tableName], !existing.isEmpty { return } - if let inFlight = inFlightColumnFetches[tableName] { - await inFlight.value - return - } - guard let connection, - let driver = DatabaseManager.shared.driver(for: connection.id) else { return } - let task: Task = Task { [weak self] in - let columns: [ColumnInfo] - do { - columns = try await driver.fetchColumns(table: tableName) - } catch { - Self.logger.warning("Column fetch failed for \(tableName, privacy: .public): \(error.localizedDescription, privacy: .public)") - columns = [] - } - let fkMap: [String: [ForeignKeyInfo]] - do { - fkMap = try await driver.fetchForeignKeys(forTables: [tableName]) - } catch { - Self.logger.warning("Foreign key fetch failed for \(tableName, privacy: .public): \(error.localizedDescription, privacy: .public)") - fkMap = [:] - } - guard !Task.isCancelled, let self else { return } - self.columnsByTable[tableName] = columns - if let fks = fkMap[tableName] { - self.foreignKeysByTable[tableName] = fks - } - self.inFlightColumnFetches[tableName] = nil - } - inFlightColumnFetches[tableName] = task - await task.value - } - - /// Ensure column data is loaded for all tables in the live schema (capped by - /// `maxSchemaTables`). Used by `@Schema` chip resolution and the - /// auto-include-schema system-prompt path. - func ensureSchemaLoaded() async { - if let inFlight = inFlightSchemaLoad { - await inFlight.value - return - } - let task: Task = Task { [weak self] in - guard let self else { return } - await self.runSchemaLoad() - } - inFlightSchemaLoad = task - await task.value - inFlightSchemaLoad = nil - } - - private func runSchemaLoad() async { - guard let connection, - let driver = DatabaseManager.shared.driver(for: connection.id) else { return } - let settings = AppSettingsManager.shared.ai - let tablesToFetch = Array(tables.prefix(settings.maxSchemaTables)) - guard !tablesToFetch.isEmpty else { return } - - await withTaskGroup(of: (String, [ColumnInfo]).self) { group in - for table in tablesToFetch where (columnsByTable[table.name] ?? []).isEmpty { - let name = table.name - group.addTask { - do { - let cols = try await driver.fetchColumns(table: name) - return (name, cols) - } catch { - Self.logger.warning("Schema column fetch failed for \(name, privacy: .public): \(error.localizedDescription, privacy: .public)") - return (name, []) - } - } - } - for await (name, cols) in group { - columnsByTable[name] = cols - } - } - - guard !Task.isCancelled else { return } - - let needsFKFetch = tablesToFetch.contains { foreignKeysByTable[$0.name] == nil } - guard needsFKFetch else { return } - do { - let fkMap = try await driver.fetchForeignKeys(forTables: tablesToFetch.map(\.name)) - for (name, fks) in fkMap { - foreignKeysByTable[name] = fks - } - } catch { - Self.logger.warning("Foreign key bulk fetch failed: \(error.localizedDescription, privacy: .public)") - } - } - func detach(_ item: ContextItem) { attachedContext.removeAll { $0.stableKey == item.stableKey } } - /// Produce a wire-ready copy of a turn with `.attachment` blocks expanded - /// into appended text. Awaits any uncached column/foreign-key data so the - /// AI receives real schema instead of a "(columns not loaded)" placeholder. - /// The stored `messages` array keeps the raw form so `editMessage` can - /// recover the typed text and attachments cleanly. - func resolveTurnForWire(_ turn: ChatTurn) async -> ChatTurn { - let attachments = turn.blocks.compactMap { block -> ContextItem? in - if case .attachment(let item) = block { return item } - return nil - } - guard !attachments.isEmpty else { return turn } - - for item in attachments { - await primeAttachmentData(for: item) - } - - let typed = turn.blocks.compactMap { block -> String? in - if case .text(let value) = block { return value } - return nil - }.joined() - - let resolved = attachments - .compactMap { resolveAttachment($0) } - .joined(separator: "\n\n") - if resolved.isEmpty { return turn } - - let combined = typed.isEmpty ? resolved : typed + "\n\n---\n\n" + resolved - return ChatTurn( - id: turn.id, - role: turn.role, - blocks: [.text(combined)], - timestamp: turn.timestamp, - usage: turn.usage, - modelId: turn.modelId, - providerId: turn.providerId - ) - } - - private func resolveAttachment(_ item: ContextItem) -> String? { - switch item { - case .schema: - return resolveSchemaAttachment() - case .table(_, let name): - return resolveTableAttachment(name: name) - case .currentQuery(let text): - let snapshot = text.isEmpty ? (currentQuery ?? "") : text - guard !snapshot.isEmpty else { return nil } - return "## Current Query\n```\n\(snapshot)\n```" - case .queryResult(let summary): - let snapshot = summary.isEmpty ? (queryResults ?? "") : summary - guard !snapshot.isEmpty else { return nil } - return "## Query Results\n\(snapshot)" - case .savedQuery(let id, let name): - return resolveSavedQueryAttachment(id: id, fallbackName: name) - case .file: - return nil - } - } - - private func resolveSavedQueryAttachment(id: UUID, fallbackName: String) -> String? { - guard let favorite = cachedSavedQueries[id] else { return nil } - let displayName = favorite.name.isEmpty ? fallbackName : favorite.name - let header = displayName.isEmpty - ? String(localized: "Saved Query") - : "\(String(localized: "Saved Query")): \(displayName)" - return "## \(header)\n```sql\n\(favorite.query)\n```" - } - - private func resolveSchemaAttachment() -> String? { - guard !tables.isEmpty else { return nil } - let settings = AppSettingsManager.shared.ai - let identifierQuote = connection.flatMap { - PluginManager.shared.sqlDialect(for: $0.type)?.identifierQuote - } ?? "\"" - let section = AISchemaContext.buildSchemaSection( - tables: tables, - columnsByTable: columnsByTable, - foreignKeys: foreignKeysByTable, - maxTables: settings.maxSchemaTables, - identifierQuote: identifierQuote - ) - guard !section.isEmpty else { return nil } - return "## Schema\n\(section)" - } - - private func resolveTableAttachment(name: String) -> String? { - let columns = columnsByTable[name] ?? [] - guard !columns.isEmpty else { return nil } - let foreignKeys = foreignKeysByTable[name] ?? [] - var lines: [String] = ["## Table \(name)"] - for column in columns { - lines.append("- \(column.name): \(column.dataType)") - } - if !foreignKeys.isEmpty { - lines.append("Foreign keys:") - for foreign in foreignKeys { - lines.append("- \(foreign.column) -> \(foreign.referencedTable).\(foreign.referencedColumn)") - } - } - return lines.joined(separator: "\n") - } - - /// Send a pre-filled prompt - func sendWithContext(prompt: String) { - let userMessage = ChatTurn(role: .user, blocks: [.text(prompt)]) - messages.append(userMessage) - trimMessagesIfNeeded() - errorMessage = nil - - startStreaming() - } - - /// Cancel the current streaming response func cancelStream() { prepTask?.cancel() prepTask = nil streamingTask?.cancel() streamingTask = nil ToolApprovalCenter.shared.cancelAll() - isStreaming = false - // Remove empty assistant placeholder left by cancelled stream - if let assistantID = streamingAssistantID, + if case .streaming(let assistantID) = streamingState, let idx = messages.firstIndex(where: { $0.id == assistantID }), messages[idx].plainText.isEmpty { messages.remove(at: idx) } - streamingAssistantID = nil + streamingState = .idle persistCurrentConversation() } - /// Clear all recent conversations - func clearConversation() { - cancelStream() - AIProviderFactory.resetCopilotConversation() - Task { await chatStorage.deleteAll() } - conversations.removeAll() - messages.removeAll() - activeConversationID = nil - errorMessage = nil - } - - /// Retry the last failed message func retry() { guard lastMessageFailed else { return } @@ -590,12 +152,11 @@ final class AIChatViewModel { guard messages.last?.role == .user else { return } - lastMessageFailed = false + streamingState = .idle errorMessage = nil startStreaming() } - /// Regenerate the last assistant response func regenerate() { guard !isStreaming, let lastAssistantIndex = messages.lastIndex(where: { $0.role == .assistant }) @@ -603,56 +164,25 @@ final class AIChatViewModel { AIProviderFactory.copilotDeleteLastTurn() messages.remove(at: lastAssistantIndex) - errorMessage = nil - startStreaming() - } - - /// User confirmed AI access for the current connection - func confirmAIAccess() { - if let connectionID = connection?.id { - sessionApprovedConnections.insert(connectionID) - } - guard pendingApproval else { return } - pendingApproval = false + clearError() startStreaming() } - /// User denied AI access for the current connection - func denyAIAccess() { - pendingApproval = false - if let last = messages.last, last.role == .user { - messages.removeLast() - } - } - - // MARK: - Conversation Management - - /// Load saved conversations from disk - func loadConversations() { - let storage = chatStorage - Task.detached(priority: .utility) { [weak self] in - let loaded = await storage.loadAll() - await MainActor.run { - guard let self else { return } - self.conversations = loaded - if let mostRecent = loaded.first { - self.activeConversationID = mostRecent.id - self.messages = mostRecent.messages - } - } + func clearError() { + errorMessage = nil + if case .failed = streamingState { + streamingState = .idle } } - /// Start a new conversation func startNewConversation() { cancelStream() persistCurrentConversation() messages.removeAll() activeConversationID = nil - errorMessage = nil + clearError() } - /// Switch to an existing conversation func switchConversation(to id: UUID) { guard let conversation = conversations.first(where: { $0.id == id }) else { return } AIProviderFactory.resetCopilotConversation() @@ -660,11 +190,9 @@ final class AIChatViewModel { persistCurrentConversation() messages = conversation.messages activeConversationID = conversation.id - errorMessage = nil + clearError() } - /// Release all session-specific data to free memory on disconnect. - /// Unlike `clearConversation()`, this does not delete persisted history. func clearSessionData() { AIProviderFactory.resetCopilotConversation() prepTask?.cancel() @@ -683,567 +211,83 @@ final class AIChatViewModel { queryResults = nil messages = [] errorMessage = nil - lastMessageFailed = false activeConversationID = nil sessionApprovedConnections = [] - isStreaming = false - streamingAssistantID = nil - pendingApproval = false - } - - /// Delete a conversation - func deleteConversation(_ id: UUID) { - if activeConversationID == id { - AIProviderFactory.resetCopilotConversation() - } - Task { await chatStorage.delete(id) } - conversations.removeAll { $0.id == id } - if activeConversationID == id { - activeConversationID = nil - messages.removeAll() - } - } - - /// Persist the current conversation to disk - func persistCurrentConversation() { - guard !messages.isEmpty else { return } - - if let existingID = activeConversationID, - var conversation = conversations.first(where: { $0.id == existingID }) { - // Update existing conversation - conversation.messages = messages - conversation.updatedAt = Date() - conversation.updateTitle() - conversation.connectionName = connection?.name - Task { await chatStorage.save(conversation) } - - if let index = conversations.firstIndex(where: { $0.id == existingID }) { - conversations[index] = conversation - } - } else { - // Create new conversation - var conversation = AIConversation( - messages: messages, - connectionName: connection?.name - ) - conversation.updateTitle() - Task { await chatStorage.save(conversation) } - activeConversationID = conversation.id - conversations.insert(conversation, at: 0) - } + streamingState = .idle } - // MARK: - Private Methods - - /// Trims the messages array to stay within `maxMessageCount`, removing oldest messages first. - private func trimMessagesIfNeeded() { - if messages.count > Self.maxMessageCount { - messages.removeFirst(messages.count - Self.maxMessageCount) - } - // Ensure conversation starts with a user message (required by some providers) - while messages.first?.role == .assistant { - messages.removeFirst() - } + func handleFixError(query: String, error: String) { + startNewConversation() + let databaseType = connection?.type ?? .mysql + let prompt = AIPromptTemplates.fixError(query: query, error: error, databaseType: databaseType) + sendWithContext(prompt: prompt) } - private func startStreaming() { - prepTask?.cancel() - prepTask = nil - if streamingTask != nil { - streamingTask?.cancel() - streamingTask = nil - if let id = streamingAssistantID, - let idx = messages.firstIndex(where: { $0.id == id }), - messages[idx].plainText.isEmpty { - messages.remove(at: idx) - } - streamingAssistantID = nil - isStreaming = false - } - - lastMessageFailed = false - + func loadAvailableModels() async { let settings = AppSettingsManager.shared.ai + let pending = settings.providers.filter { availableModels[$0.id] == nil } + guard !pending.isEmpty else { return } - let resolved = AIProviderFactory.resolve(settings: settings, overrideProviderId: selectedProviderId, overrideModel: selectedModel) - guard let resolved else { - errorMessage = String(localized: "No AI provider configured. Go to Settings > AI to add one.") - return - } - - if connection != nil { - if let policy = resolveConnectionPolicy(settings: settings) { - if policy == .never { - errorMessage = String(localized: "AI is disabled for this connection.") - if let last = messages.last, last.role == .user { - messages.removeLast() - } - return - } - if policy == .askEachTime { - pendingApproval = true - showAIAccessConfirmation = true - return - } - } - } - - let assistantMessage = ChatTurn(role: .assistant, blocks: [], modelId: resolved.model, providerId: resolved.config.id.uuidString) - messages.append(assistantMessage) - trimMessagesIfNeeded() - let assistantID = assistantMessage.id - streamingAssistantID = assistantID - - isStreaming = true - - prepTask?.cancel() - prepTask = Task { [weak self] in - guard let self else { return } - if settings.includeSchema { - await self.ensureSchemaLoaded() - } - guard !Task.isCancelled else { return } - let promptContext = self.capturePromptContext(settings: settings) - var chatMessages: [ChatTurn] = [] - for turn in self.messages.dropLast() { - chatMessages.append(await self.resolveTurnForWire(turn)) - } - guard !Task.isCancelled else { return } - self.runStream( - chatMessages: chatMessages, - promptContext: promptContext, - resolved: resolved, - assistantID: assistantID, - settings: settings - ) - self.prepTask = nil - } - } - - private static let maxToolRoundtrips = 10 - - private func runStream( - chatMessages: [ChatTurn], - promptContext: PromptContext?, - resolved: AIProviderFactory.ResolvedProvider, - assistantID: UUID, - settings: AISettings - ) { - let chatMode = settings.chatMode - streamingTask = Task.detached(priority: .userInitiated) { [weak self] in - do { - let systemPrompt = Self.buildSystemPrompt(promptContext, mode: chatMode) - guard let self else { return } - let preflightOK = await self.preflightCheck( - systemPrompt: systemPrompt, - turns: chatMessages, - assistantID: assistantID - ) - guard preflightOK else { return } - - let toolSpecs = await MainActor.run { ChatToolRegistry.shared.allSpecs(for: chatMode) } - var workingTurns = chatMessages - var currentAssistantID = assistantID - let flushInterval: ContinuousClock.Duration = .milliseconds(150) - - for roundtrip in 0..= flushInterval { - await self.flushPending( - content: pendingContent, - usage: pendingUsage, - into: assistantIDForRound - ) - pendingContent = "" - pendingUsage = nil - lastFlushTime = .now - } - } - - if !Task.isCancelled, !pendingContent.isEmpty || pendingUsage != nil { - await self.flushPending( - content: pendingContent, - usage: pendingUsage, - into: assistantIDForRound - ) - } - - guard !Task.isCancelled else { return } - - if toolUseOrder.isEmpty { break } - - - if roundtrip == Self.maxToolRoundtrips - 1 { - await MainActor.run { [weak self] in - guard let self else { return } - self.errorMessage = String( - localized: "AI made too many tool calls in one response. Try simplifying the request." - ) - if let idx = self.messages.firstIndex(where: { $0.id == currentAssistantID }), - self.messages[idx].plainText.isEmpty { - self.messages.remove(at: idx) - } - self.lastMessageFailed = true - } - break - } - - let assembledBlocks = Self.assembleToolUseBlocks( - order: toolUseOrder, - names: toolUseNames, - inputs: toolUseInputs - ) - let context = await MainActor.run { - ChatToolContext( - connectionId: self.connection?.id, - bridge: ChatToolBootstrap.bridge, - authPolicy: ChatToolBootstrap.authPolicy - ) - } - let toolUseBlocks = await self.resolveAndAwaitApprovals( - assembledBlocks: assembledBlocks, - assistantID: assistantIDForRound - ) - guard !Task.isCancelled else { return } - - let approvedBlocks = toolUseBlocks.filter { - if case .approved = $0.approvalState { return true } - return false - } - let executedResults = await Self.executeToolUses( - approvedBlocks, mode: chatMode, context: context - ) - guard !Task.isCancelled else { return } - - let toolResultBlocks = Self.synthesizeResults( - for: toolUseBlocks, - executed: executedResults - ) - - let continuation = await self.completeToolRoundtrip( - assistantIDForRound: assistantIDForRound, - toolUseBlocks: toolUseBlocks, - toolResultBlocks: toolResultBlocks, - resolved: resolved - ) - currentAssistantID = continuation.nextAssistantID - workingTurns.append(continuation.assistantTurn) - workingTurns.append(continuation.userTurn) - } - - guard !Task.isCancelled else { return } - await MainActor.run { [weak self] in - guard let self else { return } - self.isStreaming = false - self.streamingTask = nil - self.streamingAssistantID = nil - self.persistCurrentConversation() + let results = await withTaskGroup(of: (UUID, [String]?).self) { group in + for config in pending { + let apiKey: String? + switch config.type.authStyle { + case .apiKey: + apiKey = AIKeyStorage.shared.loadAPIKey(for: config.id) + case .oauth, .none: + apiKey = nil } - } catch { - await MainActor.run { [weak self] in - guard let self else { return } - if !Task.isCancelled { - Self.logger.error("Streaming failed: \(error.localizedDescription)") - self.lastMessageFailed = true - self.errorMessage = error.localizedDescription - self.lastError = error as? AIProviderError - - // Remove empty assistant message on error - if let idx = self.messages.firstIndex(where: { $0.id == assistantID }), - self.messages[idx].plainText.isEmpty { - self.messages.remove(at: idx) - } + group.addTask { + let transport = await AIProviderFactory.createProvider(for: config, apiKey: apiKey) + do { + let models = try await transport.fetchAvailableModels() + return (config.id, models) + } catch is CancellationError { + return (config.id, nil) + } catch { + return (config.id, []) } - self.isStreaming = false - self.streamingTask = nil - self.streamingAssistantID = nil } } - } - } - - nonisolated private static func buildSystemPrompt(_ promptContext: PromptContext?, mode: AIChatMode) -> String? { - let schemaPrompt = promptContext.map { - AISchemaContext.buildSystemPrompt( - databaseType: $0.databaseType, - databaseName: $0.databaseName, - tables: $0.tables, - columnsByTable: $0.columnsByTable, - foreignKeys: $0.foreignKeys, - currentQuery: $0.currentQuery, - queryResults: $0.queryResults, - settings: $0.settings, - identifierQuote: $0.identifierQuote, - editorLanguage: $0.editorLanguage, - queryLanguageName: $0.queryLanguageName, - connectionRules: $0.connectionRules - ) - } - let modeNote = mode.systemPromptNote - guard let schemaPrompt, !schemaPrompt.isEmpty else { return modeNote } - return "\(schemaPrompt)\n\n\(modeNote)" - } - private struct ToolRoundtripContinuation { - let nextAssistantID: UUID - let assistantTurn: ChatTurn - let userTurn: ChatTurn - } - - private func completeToolRoundtrip( - assistantIDForRound: UUID, - toolUseBlocks: [ToolUseBlock], - toolResultBlocks: [ToolResultBlock], - resolved: AIProviderFactory.ResolvedProvider - ) async -> ToolRoundtripContinuation { - await MainActor.run { [weak self] () -> ToolRoundtripContinuation in - let assistantText: String = { - guard let self, - let idx = self.messages.firstIndex(where: { $0.id == assistantIDForRound }) - else { return "" } - return self.messages[idx].plainText - }() - var assistantBlocks: [ChatContentBlock] = [] - if !assistantText.isEmpty { assistantBlocks.append(.text(assistantText)) } - assistantBlocks.append(contentsOf: toolUseBlocks.map { .toolUse($0) }) - let assistantTurn = ChatTurn( - id: assistantIDForRound, - role: .assistant, - blocks: assistantBlocks, - modelId: resolved.model, - providerId: resolved.config.id.uuidString - ) - let userTurn = ChatTurn( - role: .user, - blocks: toolResultBlocks.map { .toolResult($0) } - ) - let nextAssistant = ChatTurn( - role: .assistant, - blocks: [], - modelId: resolved.model, - providerId: resolved.config.id.uuidString - ) - self?.messages.append(userTurn) - self?.messages.append(nextAssistant) - self?.streamingAssistantID = nextAssistant.id - return ToolRoundtripContinuation( - nextAssistantID: nextAssistant.id, - assistantTurn: assistantTurn, - userTurn: userTurn - ) - } - } - - private func flushPending(content: String, usage: AITokenUsage?, into assistantID: UUID) async { - guard !content.isEmpty || usage != nil else { return } - await MainActor.run { [weak self] in - guard let self, - let idx = self.messages.firstIndex(where: { $0.id == assistantID }) - else { return } - if !content.isEmpty { - self.messages[idx].appendText(content) - } - if let usage { - self.messages[idx].usage = usage + var collected: [(UUID, [String]?)] = [] + for await result in group { + collected.append(result) } + return collected } - } - private func preflightCheck(systemPrompt: String?, turns: [ChatTurn], assistantID: UUID) async -> Bool { - let totalSize = ((systemPrompt ?? "") as NSString).length - + turns.reduce(0) { $0 + ($1.plainText as NSString).length } - guard totalSize > 100_000 else { return true } - await MainActor.run { [weak self] in - guard let self else { return } - self.errorMessage = String( - localized: "Message too large. Try disabling 'Include schema' or 'Include query results' in AI settings." - ) - if let idx = self.messages.firstIndex(where: { $0.id == assistantID }) { - self.messages.remove(at: idx) - } - self.isStreaming = false - self.streamingAssistantID = nil - } - return false - } + guard !Task.isCancelled else { return } - nonisolated static func assembleToolUseBlocks( - order: [String], - names: [String: String], - inputs: [String: String] - ) -> [ToolUseBlock] { - order.compactMap { id -> ToolUseBlock? in - guard let name = names[id] else { return nil } - let inputString = inputs[id] ?? "{}" - let inputValue: JSONValue - if inputString.isEmpty { - inputValue = .object([:]) - } else if let data = inputString.data(using: .utf8), - let decoded = try? JSONDecoder().decode(JSONValue.self, from: data) { - inputValue = decoded + for (id, models) in results { + guard let models else { continue } + if models.isEmpty { + let fallback = pending.first(where: { $0.id == id })?.model + availableModels[id] = (fallback?.isEmpty == false) ? [fallback ?? ""] : [] } else { - inputValue = .object([:]) - } - return ToolUseBlock(id: id, name: name, input: inputValue) - } - } - - /// Execute the given tool-use blocks in parallel via a `withTaskGroup`, - /// returning result blocks in the same order. The `registry` parameter - /// defaults to the shared singleton; tests inject a fresh instance to - /// avoid polluting global state. - nonisolated static func executeToolUses( - _ blocks: [ToolUseBlock], - mode: AIChatMode, - context: ChatToolContext, - registry: ChatToolRegistry? = nil - ) async -> [ToolResultBlock] { - await withTaskGroup(of: (Int, ToolResultBlock).self) { group in - for (index, block) in blocks.enumerated() { - group.addTask { - (index, await runToolUse(block, mode: mode, context: context, registry: registry)) - } + availableModels[id] = models } - var indexed: [(Int, ToolResultBlock)] = [] - for await pair in group { indexed.append(pair) } - return indexed.sorted(by: { $0.0 < $1.0 }).map(\.1) } } - nonisolated private static func runToolUse( - _ block: ToolUseBlock, - mode: AIChatMode, - context: ChatToolContext, - registry: ChatToolRegistry? - ) async -> ToolResultBlock { - if Task.isCancelled { - return ToolResultBlock(toolUseId: block.id, content: "Cancelled", isError: true) - } - guard ChatToolRegistry.isToolAllowed(name: block.name, in: mode) else { - Self.logger.warning( - "Tool '\(block.name, privacy: .public)' blocked in \(mode.rawValue, privacy: .public) mode" - ) - return ToolResultBlock( - toolUseId: block.id, - content: "Tool '\(block.name)' is not available in \(mode.displayName) mode", - isError: true - ) - } - let tool = await MainActor.run { - (registry ?? ChatToolRegistry.shared).tool(named: block.name, in: mode) - } - guard let tool else { - Self.logger.warning("Tool '\(block.name, privacy: .public)' not registered; returning error") - return ToolResultBlock( - toolUseId: block.id, - content: "Tool '\(block.name)' is not available", - isError: true - ) + func loadSavedQueries() async { + guard let connectionId = connection?.id else { + savedQueries = [] + return } - do { - let result = try await tool.execute(input: block.input, context: context) - return ToolResultBlock( - toolUseId: block.id, - content: result.content, - isError: result.isError - ) - } catch { - Self.logger.warning( - "Tool \(block.name, privacy: .public) execution failed: \(error.localizedDescription, privacy: .public)" - ) - return ToolResultBlock( - toolUseId: block.id, - content: "Error: \(error.localizedDescription)", - isError: true - ) + let favorites = await SQLFavoriteManager.shared.fetchFavorites(connectionId: connectionId) + savedQueries = favorites + for favorite in favorites { + cachedSavedQueries[favorite.id] = favorite } } - private func resolveConnectionPolicy(settings: AISettings) -> AIConnectionPolicy? { - let policy = connection?.aiPolicy ?? settings.defaultConnectionPolicy - - if policy == .askEachTime { - // If already approved this session, treat as always allow - if let connectionID = connection?.id, sessionApprovedConnections.contains(connectionID) { - return .alwaysAllow - } - return .askEachTime + func trimMessagesIfNeeded() { + if messages.count > Self.maxMessageCount { + messages.removeFirst(messages.count - Self.maxMessageCount) + } + while messages.first?.role == .assistant { + messages.removeFirst() } - - return policy - } - - private struct PromptContext: Sendable { - let databaseType: DatabaseType - let databaseName: String - let tables: [TableInfo] - let columnsByTable: [String: [ColumnInfo]] - let foreignKeys: [String: [ForeignKeyInfo]] - let currentQuery: String? - let queryResults: String? - let settings: AISettings - let identifierQuote: String - let editorLanguage: EditorLanguage - let queryLanguageName: String - let connectionRules: String? - } - - private func capturePromptContext(settings: AISettings) -> PromptContext? { - guard let connection else { return nil } - return PromptContext( - databaseType: connection.type, - databaseName: DatabaseManager.shared.activeDatabaseName(for: connection), - tables: tables, - columnsByTable: columnsByTable, - foreignKeys: foreignKeysByTable, - currentQuery: settings.includeCurrentQuery ? currentQuery : nil, - queryResults: settings.includeQueryResults ? queryResults : nil, - settings: settings, - identifierQuote: PluginManager.shared.sqlDialect(for: connection.type)?.identifierQuote ?? "\"", - editorLanguage: PluginManager.shared.editorLanguage(for: connection.type), - queryLanguageName: PluginManager.shared.queryLanguageName(for: connection.type), - connectionRules: connection.aiRules - ) } } diff --git a/TablePro/Views/AIChat/AIChatPanelView.swift b/TablePro/Views/AIChat/AIChatPanelView.swift index e1c28610d..8b2a03292 100644 --- a/TablePro/Views/AIChat/AIChatPanelView.swift +++ b/TablePro/Views/AIChat/AIChatPanelView.swift @@ -191,7 +191,7 @@ struct AIChatPanelView: View { .lineLimit(2) Spacer() Button { - viewModel.errorMessage = nil + viewModel.clearError() } label: { Image(systemName: "xmark") .frame(width: 24, height: 24) From 012ee0b8ada9f2e3594a386ebdec5048f14562e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:19:13 +0700 Subject: [PATCH 03/16] refactor(ai-providers): extract endpoint normalization and JSON encoding --- TablePro/Core/AI/GeminiProvider.swift | 68 +++++++-------------------- 1 file changed, 18 insertions(+), 50 deletions(-) diff --git a/TablePro/Core/AI/GeminiProvider.swift b/TablePro/Core/AI/GeminiProvider.swift index ebe7f3980..32efe46aa 100644 --- a/TablePro/Core/AI/GeminiProvider.swift +++ b/TablePro/Core/AI/GeminiProvider.swift @@ -15,7 +15,7 @@ final class GeminiProvider: ChatTransport { private let session: URLSession init(endpoint: String, apiKey: String, maxOutputTokens: Int = 8_192) { - self.endpoint = endpoint.hasSuffix("/") ? String(endpoint.dropLast()) : endpoint + self.endpoint = endpoint.normalizedEndpoint() self.apiKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines) self.maxOutputTokens = maxOutputTokens self.session = URLSession(configuration: .ephemeral) @@ -37,9 +37,10 @@ final class GeminiProvider: ChatTransport { guard httpResponse.statusCode == 200 else { let errorBody = try await collectErrorBody(from: bytes) - throw mapHTTPError( + throw AIProviderError.mapHTTPError( statusCode: httpResponse.statusCode, - body: errorBody + body: errorBody, + treatForbiddenAsAuthFailure: true ) } @@ -85,6 +86,7 @@ final class GeminiProvider: ChatTransport { var request = URLRequest(url: url) request.httpMethod = "GET" + request.timeoutInterval = AIProvider.modelListTimeout request.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") let data: Data @@ -92,20 +94,16 @@ final class GeminiProvider: ChatTransport { do { (data, response) = try await session.data(for: request) } catch { + Self.logger.warning("Gemini model fetch failed; using known models: \(error.localizedDescription, privacy: .public)") return Self.knownModels } - guard let httpResponse = response as? HTTPURLResponse else { - return Self.knownModels - } - - guard httpResponse.statusCode == 200 else { - return Self.knownModels - } - - guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + guard let httpResponse = response as? HTTPURLResponse, + httpResponse.statusCode == 200, + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], let models = json["models"] as? [[String: Any]] else { + Self.logger.warning("Gemini model fetch returned unexpected response; using known models") return Self.knownModels } @@ -146,7 +144,11 @@ final class GeminiProvider: ChatTransport { guard statusCode == 200 else { let body = String(data: data, encoding: .utf8) ?? "" - throw mapHTTPError(statusCode: statusCode, body: body) + throw AIProviderError.mapHTTPError( + statusCode: statusCode, + body: body, + treatForbiddenAsAuthFailure: true + ) } return true @@ -177,14 +179,12 @@ final class GeminiProvider: ChatTransport { } if !options.tools.isEmpty { - let declarations = options.tools.map { tool -> [String: Any] in + let declarations = try options.tools.map { tool -> [String: Any] in var entry: [String: Any] = [ "name": tool.name, "description": tool.description ] - if let parameters = jsonValueToAny(tool.inputSchema) { - entry["parameters"] = parameters - } + entry["parameters"] = try tool.inputSchema.asJSONObject() return entry } body["tools"] = [["functionDeclarations": declarations]] @@ -218,7 +218,7 @@ final class GeminiProvider: ChatTransport { case .attachment: continue case .toolUse(let useBlock): - let argsObject = jsonValueToAny(useBlock.input) ?? [String: Any]() + let argsObject = (try? useBlock.input.asJSONObject()) ?? [String: Any]() parts.append([ "functionCall": [ "name": useBlock.name, @@ -259,38 +259,6 @@ final class GeminiProvider: ChatTransport { return nil } - func jsonValueToAny(_ value: JSONValue) -> Any? { - switch value { - case .null: - return NSNull() - case .bool(let bool): - return bool - case .integer(let int): - return int - case .number(let double): - return double - case .string(let string): - return string - case .array(let array): - return array.map { jsonValueToAny($0) ?? NSNull() } - case .object(let object): - var dict: [String: Any] = [:] - for (key, child) in object { - dict[key] = jsonValueToAny(child) ?? NSNull() - } - return dict - } - } - - - func mapHTTPError(statusCode: Int, body: String) -> AIProviderError { - if statusCode == 403 { - let message = AIProviderError.parseErrorMessage(from: body) ?? body - return .authenticationFailed(message) - } - return AIProviderError.mapHTTPError(statusCode: statusCode, body: body) - } - /// Decodes one Gemini SSE line. Returns nil for non-data lines. static func decodeStreamLine(_ line: String) -> [String: Any]? { guard line.hasPrefix("data: ") else { return nil } From 13a7ae3722bbab95962df2f87cc450b2888695a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:19:20 +0700 Subject: [PATCH 04/16] refactor(mcp): split HTTP transport, migrate pairing store to actor, add SSE keep-alive --- CHANGELOG.md | 1 + TablePro/Core/MCP/MCPPairingService.swift | 25 +- .../Transport/MCPHttpConnectionContext.swift | 245 ++++++ .../MCP/Transport/MCPHttpRequestRouter.swift | 489 +++++++++++ .../Transport/MCPHttpServerTransport.swift | 789 ++---------------- .../Core/MCP/Transport/MCPSseWriter.swift | 61 ++ .../Core/MCP/MCPPairingServiceTests.swift | 84 +- .../MCPHttpServerTransportPairingTests.swift | 22 +- 8 files changed, 930 insertions(+), 786 deletions(-) create mode 100644 TablePro/Core/MCP/Transport/MCPHttpConnectionContext.swift create mode 100644 TablePro/Core/MCP/Transport/MCPHttpRequestRouter.swift create mode 100644 TablePro/Core/MCP/Transport/MCPSseWriter.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index c886bbe55..b306b3e8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - AI inline suggestions: debounce now uses structured Swift concurrency, and the delay is configurable via the `inlineSuggestionDebounceMs` setting (default 500ms) - Copilot LSP shutdown caps at 10 seconds, closes pipes explicitly, and strips the quarantine attribute from the downloaded binary - AI Chat: streaming view model split into focused extensions backed by a single `streamingState` enum +- MCP HTTP server: split transport into connection, router, and SSE writer files; pairing exchange store moved to a Swift actor; SSE streams send a 30-second keep-alive ## [0.39.1] - 2026-05-08 diff --git a/TablePro/Core/MCP/MCPPairingService.swift b/TablePro/Core/MCP/MCPPairingService.swift index 1b6729767..700e7904e 100644 --- a/TablePro/Core/MCP/MCPPairingService.swift +++ b/TablePro/Core/MCP/MCPPairingService.swift @@ -9,16 +9,13 @@ struct PairingExchangeRecord: Sendable, Equatable { let expiresAt: Date } -final class PairingExchangeStore: @unchecked Sendable { +actor PairingExchangeStore { static let exchangeWindow: TimeInterval = 300 static let maxPendingCodes = 50 - private let lock = NSLock() private var pending: [String: PairingExchangeRecord] = [:] func insert(code: String, record: PairingExchangeRecord) throws { - lock.lock() - defer { lock.unlock() } prune(now: Date.now) guard pending.count < Self.maxPendingCodes else { throw MCPDataLayerError.forbidden( @@ -29,8 +26,6 @@ final class PairingExchangeStore: @unchecked Sendable { } func consume(code: String, verifier: String, now: Date = .now) throws -> String { - lock.lock() - defer { lock.unlock() } prune(now: now) guard let entry = pending[code] else { @@ -53,21 +48,15 @@ final class PairingExchangeStore: @unchecked Sendable { } func pruneExpired(now: Date = .now) { - lock.lock() - defer { lock.unlock() } prune(now: now) } func count() -> Int { - lock.lock() - defer { lock.unlock() } - return pending.count + pending.count } func contains(code: String) -> Bool { - lock.lock() - defer { lock.unlock() } - return pending[code] != nil + pending[code] != nil } private func prune(now: Date) { @@ -148,7 +137,7 @@ final class MCPPairingService { let code = UUID().uuidString do { - try store.insert( + try await store.insert( code: code, record: PairingExchangeRecord( plaintextToken: result.plaintext, @@ -171,8 +160,8 @@ final class MCPPairingService { NSWorkspace.shared.open(redirect) } - func exchange(_ exchange: PairingExchange) throws -> String { - try store.consume(code: exchange.code, verifier: exchange.verifier) + func exchange(_ exchange: PairingExchange) async throws -> String { + try await store.consume(code: exchange.code, verifier: exchange.verifier) } private static func revokeExistingTokens(named name: String, in store: MCPTokenStore) async { @@ -188,7 +177,7 @@ final class MCPPairingService { while !Task.isCancelled { try? await Task.sleep(for: Self.pruneInterval) guard !Task.isCancelled else { return } - store.pruneExpired() + await store.pruneExpired() } } } diff --git a/TablePro/Core/MCP/Transport/MCPHttpConnectionContext.swift b/TablePro/Core/MCP/Transport/MCPHttpConnectionContext.swift new file mode 100644 index 000000000..bae0ba1e0 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPHttpConnectionContext.swift @@ -0,0 +1,245 @@ +import Foundation +import Network +import os + +actor HttpConnectionContext { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.HttpServer") + + nonisolated let id: UUID + private let connection: NWConnection + private var receiveBuffer = Data() + private var requestComplete = false + private var cancelled = false + private var sseActive = false + private var origin: String? + + init(id: UUID, connection: NWConnection) { + self.id = id + self.connection = connection + } + + func setOrigin(_ value: String?) { + origin = value + } + + private func corsHeaders() -> [(String, String)] { + MCPCorsHeaders.headers(forOrigin: origin) + } + + func start( + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) { + let nwConnection = connection + nwConnection.stateUpdateHandler = { [weak self] state in + guard let self else { return } + switch state { + case .ready: + Task { await self.beginReading(onData: onData, onClosed: onClosed) } + case .failed: + Task { await self.handleClosed(onClosed: onClosed) } + case .cancelled: + Task { await self.handleClosed(onClosed: onClosed) } + default: + break + } + } + nwConnection.start(queue: .global(qos: .userInitiated)) + } + + private func beginReading( + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) { + scheduleReceive(onData: onData, onClosed: onClosed) + } + + private func scheduleReceive( + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) { + if cancelled || requestComplete { return } + connection.receive(minimumIncompleteLength: 1, maximumLength: 65_536) { [weak self] content, _, isComplete, error in + guard let self else { return } + Task { + await self.handleReceive( + content: content, + isComplete: isComplete, + error: error, + onData: onData, + onClosed: onClosed + ) + } + } + } + + private func handleReceive( + content: Data?, + isComplete: Bool, + error: NWError?, + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) async { + if let error { + Self.logger.debug("Receive error: \(error.localizedDescription, privacy: .public)") + cancel() + await onClosed() + return + } + + if let content { + receiveBuffer.append(content) + await onData(receiveBuffer) + } + + if isComplete { + cancel() + await onClosed() + return + } + + if !requestComplete, !cancelled { + scheduleReceive(onData: onData, onClosed: onClosed) + } + } + + private func handleClosed(onClosed: @escaping @Sendable () async -> Void) async { + if !cancelled { + cancelled = true + } + await onClosed() + } + + func markRequestComplete() { + requestComplete = true + } + + func clientAddress() -> MCPClientAddress { + guard let endpoint = connection.currentPath?.remoteEndpoint, + case .hostPort(let host, _) = endpoint else { + return .loopback + } + let hostString = "\(host)" + if hostString == "127.0.0.1" || hostString == "::1" || hostString.lowercased() == "localhost" { + return .loopback + } + return .remote(hostString) + } + + func writeJsonResponse( + data: Data, + status: HttpStatus, + sessionId: MCPSessionId?, + extraHeaders: [(String, String)] + ) async { + if cancelled { return } + var headers: [(String, String)] = [ + ("Content-Type", "application/json"), + ("Connection", "close") + ] + if let sessionId { + headers.append(("Mcp-Session-Id", sessionId.rawValue)) + } + headers.append(contentsOf: extraHeaders) + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: status, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: data) + await send(payload) + } + + func writePlainJsonResponse(status: HttpStatus, body: Data) async { + if cancelled { return } + var headers: [(String, String)] = [ + ("Content-Type", "application/json"), + ("Connection", "close") + ] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: status, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: body) + await send(payload) + } + + func writePlainJsonError(status: HttpStatus, message: String) async { + struct ErrorBody: Encodable { let error: String } + let payload = (try? JSONEncoder().encode(ErrorBody(error: message))) ?? Data() + await writePlainJsonResponse(status: status, body: payload) + } + + func writeOptions204() async { + if cancelled { return } + var headers: [(String, String)] = [("Connection", "close")] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .noContent, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeNoContent() async { + if cancelled { return } + var headers: [(String, String)] = [("Connection", "close")] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .noContent, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeAccepted() async { + if cancelled { return } + var headers: [(String, String)] = [("Connection", "close")] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .accepted, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeSseStreamHeaders(sessionId: MCPSessionId) async { + if cancelled { return } + sseActive = true + var headers: [(String, String)] = [ + ("Content-Type", "text/event-stream"), + ("Cache-Control", "no-cache"), + ("Connection", "keep-alive"), + ("Mcp-Session-Id", sessionId.rawValue) + ] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .ok, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeSseFrame(_ frame: SseFrame) async { + if cancelled { return } + let data = SseEncoder.encode(frame) + await send(data) + } + + func writeRaw(_ data: Data) async { + if cancelled { return } + await send(data) + } + + func cancel() { + if cancelled { return } + cancelled = true + connection.cancel() + } + + func isSseActive() -> Bool { + sseActive + } + + func isCancelled() -> Bool { + cancelled + } + + private func send(_ data: Data) async { + await withCheckedContinuation { (continuation: CheckedContinuation) in + connection.send(content: data, completion: .contentProcessed { error in + if let error { + Self.logger.debug("Send error: \(error.localizedDescription, privacy: .public)") + } + continuation.resume() + }) + } + } +} diff --git a/TablePro/Core/MCP/Transport/MCPHttpRequestRouter.swift b/TablePro/Core/MCP/Transport/MCPHttpRequestRouter.swift new file mode 100644 index 000000000..a2196f47b --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPHttpRequestRouter.swift @@ -0,0 +1,489 @@ +import Foundation +import os + +struct MCPHttpRequestRouter: Sendable { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.HttpRouter") + + typealias InboundEmitter = @Sendable (MCPInboundExchange) -> AsyncStream.Continuation.YieldResult + typealias SseStarter = @Sendable (UUID, MCPSessionId, HttpConnectionContext) async -> Void + typealias ResponderSinkFactory = @Sendable (HttpConnectionContext) -> any MCPResponderSink + + let configuration: MCPHttpServerConfiguration + let sessionStore: MCPSessionStore + let authenticator: any MCPAuthenticator + let clock: any MCPClock + let emitInbound: InboundEmitter + let startSse: SseStarter + let makeResponderSink: ResponderSinkFactory + + func dispatch(head: HttpRequestHead, body: Data, context: HttpConnectionContext) async { + let clientAddress: MCPClientAddress = await context.clientAddress() + let now = await clock.now() + + await context.setOrigin(head.headers.value(for: "Origin")) + + if head.method == .post, stripQueryString(head.path) == "/v1/integrations/exchange" { + await handleIntegrationsExchange(body: body, context: context) + return + } + + switch head.method { + case .options: + await context.writeOptions204() + await context.cancel() + case .get: + await handleGetMcp(head: head, context: context, clientAddress: clientAddress) + case .post: + await handlePostMcp(head: head, body: body, context: context, clientAddress: clientAddress, now: now) + case .delete: + await handleDeleteMcp(head: head, context: context, clientAddress: clientAddress) + default: + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not allowed", + httpStatus: .methodNotAllowed + ), + requestId: nil + ) + } + } + + private func handleIntegrationsExchange(body: Data, context: HttpConnectionContext) async { + struct ExchangeBody: Decodable { + let code: String + let codeVerifier: String + enum CodingKeys: String, CodingKey { + case code + case codeVerifier = "code_verifier" + } + } + struct ExchangeResponse: Encodable { + let token: String + } + + Self.logger.info("Integrations exchange request received (\(body.count, privacy: .public) bytes)") + let ip = Self.ipString(for: await context.clientAddress()) + + let parsed: ExchangeBody + do { + parsed = try JSONDecoder().decode(ExchangeBody.self, from: body) + } catch { + Self.logger.warning("Integrations exchange decode failed: \(error.localizedDescription, privacy: .public)") + MCPAuditLogger.logPairingExchange(outcome: .denied, ip: ip, details: "invalid JSON body") + await context.writePlainJsonError(status: .badRequest, message: "Invalid JSON body") + await context.cancel() + return + } + + guard !parsed.code.isEmpty, !parsed.codeVerifier.isEmpty else { + Self.logger.warning("Integrations exchange missing code or verifier") + MCPAuditLogger.logPairingExchange( + outcome: .denied, + ip: ip, + details: "missing code or code_verifier" + ) + await context.writePlainJsonError(status: .badRequest, message: "Missing code or code_verifier") + await context.cancel() + return + } + + guard parsed.code.utf8.count <= 1_024, parsed.codeVerifier.utf8.count <= 1_024 else { + Self.logger.warning("Integrations exchange field exceeds size cap") + MCPAuditLogger.logPairingExchange( + outcome: .denied, + ip: ip, + details: "field exceeds 1_024 bytes" + ) + await context.writePlainJsonError(status: .badRequest, message: "Field exceeds size limit") + await context.cancel() + return + } + + let exchange = PairingExchange(code: parsed.code, verifier: parsed.codeVerifier) + let outcome: Result + do { + let token = try await MCPPairingService.shared.exchange(exchange) + outcome = .success(token) + } catch { + outcome = .failure(error) + } + + switch outcome { + case .success(let token): + Self.logger.info("Integrations exchange succeeded (token len=\(token.count, privacy: .public))") + let label = await Self.resolveTokenLabel(for: token) + MCPAuditLogger.logPairingExchange(outcome: .success, tokenName: label, ip: ip) + let payload = (try? JSONEncoder().encode(ExchangeResponse(token: token))) ?? Data() + await context.writePlainJsonResponse(status: .ok, body: payload) + await context.cancel() + case .failure(let error): + let mapped = Self.mapExchangeError(error) + Self.logger.warning("Integrations exchange failed: status=\(mapped.status.code, privacy: .public) reason=\(mapped.message, privacy: .public)") + MCPAuditLogger.logPairingExchange( + outcome: .denied, + ip: ip, + details: mapped.message + ) + await context.writePlainJsonError(status: mapped.status, message: mapped.message) + await context.cancel() + } + } + + private static func ipString(for address: MCPClientAddress) -> String { + switch address { + case .loopback: + return "127.0.0.1" + case .remote(let host): + return host + } + } + + private static func resolveTokenLabel(for plaintext: String) async -> String? { + let store: MCPTokenStore? = await MainActor.run { MCPServerManager.shared.tokenStore } + guard let store else { return nil } + return await store.validate(bearerToken: plaintext)?.name + } + + private static func mapExchangeError(_ error: Error) -> (status: HttpStatus, message: String) { + guard let domainError = error as? MCPDataLayerError else { + return (.internalServerError, "Internal error") + } + switch domainError { + case .notFound: + return (.notFound, "Pairing code not found") + case .expired: + return (HttpStatus(code: 410, reasonPhrase: "Gone"), "Pairing code expired") + case .forbidden: + return (.forbidden, "Challenge mismatch") + default: + return (.internalServerError, "Internal error") + } + } + + private func handleGetMcp( + head: HttpRequestHead, + context: HttpConnectionContext, + clientAddress: MCPClientAddress + ) async { + guard pathMatchesMcp(head.path) else { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found", + httpStatus: .notFound + ), + requestId: nil + ) + return + } + + guard let sessionIdRaw = head.headers.value(for: "Mcp-Session-Id") else { + await respondTopLevel(context: context, error: .missingSessionId(), requestId: nil) + return + } + + if head.headers.value(for: "Last-Event-ID") != nil { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.serverError, + message: "SSE event replay is not supported", + httpStatus: .notImplemented + ), + requestId: nil + ) + return + } + + if let accept = head.headers.value(for: "Accept"), + !accept.lowercased().contains("text/event-stream"), + !accept.contains("*/*") { + await respondTopLevel(context: context, error: .notAcceptable(), requestId: nil) + return + } + + let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) + guard case .allow = authResult else { + if case .deny(let error) = authResult { + await respondTopLevel(context: context, error: error, requestId: nil) + } + return + } + + let sessionId = MCPSessionId(sessionIdRaw) + guard await sessionStore.session(id: sessionId) != nil else { + await respondTopLevel(context: context, error: .sessionNotFound(), requestId: nil) + return + } + + await sessionStore.touch(id: sessionId) + + await startSse(context.id, sessionId, context) + Self.logger.info("Registered SSE notification stream for session \(sessionId.rawValue, privacy: .public)") + } + + private func handlePostMcp( + head: HttpRequestHead, + body: Data, + context: HttpConnectionContext, + clientAddress: MCPClientAddress, + now: Date + ) async { + guard pathMatchesMcp(head.path) else { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found", + httpStatus: .notFound + ), + requestId: nil + ) + return + } + + if body.count > configuration.limits.maxRequestBodyBytes { + await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) + return + } + + let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) + guard case .allow(let principal) = authResult else { + if case .deny(let error) = authResult { + await respondTopLevel(context: context, error: error, requestId: nil) + } + return + } + + let message: JsonRpcMessage + do { + message = try JsonRpcCodec.decode(body) + } catch { + await respondTopLevel( + context: context, + error: .parseError(detail: String(describing: error)), + requestId: nil + ) + return + } + + let requestId = extractRequestId(from: message) + let methodName = extractMethod(from: message) + let mcpProtocolVersion = head.headers.value(for: "mcp-protocol-version") + + let sessionId: MCPSessionId? + if methodName == "initialize" { + do { + let session = try await sessionStore.create() + sessionId = session.id + } catch { + await respondTopLevel( + context: context, + error: .serviceUnavailable(), + requestId: requestId + ) + return + } + } else { + guard let raw = head.headers.value(for: "Mcp-Session-Id") else { + await respondTopLevel(context: context, error: .missingSessionId(), requestId: requestId) + return + } + let candidate = MCPSessionId(raw) + guard let session = await sessionStore.session(id: candidate) else { + await respondTopLevel(context: context, error: .sessionNotFound(), requestId: requestId) + return + } + if let mismatch = await Self.protocolVersionMismatch( + session: session, + headerValue: mcpProtocolVersion + ) { + await respondTopLevel(context: context, error: mismatch, requestId: requestId) + return + } + sessionId = candidate + await sessionStore.touch(id: candidate) + } + + let sink = makeResponderSink(context) + let responder = MCPExchangeResponder(sink: sink, requestId: requestId) + + let exchangeContext = MCPInboundContext( + sessionId: sessionId, + principal: principal, + clientAddress: clientAddress, + receivedAt: now, + mcpProtocolVersion: mcpProtocolVersion + ) + let exchange = MCPInboundExchange( + message: message, + context: exchangeContext, + responder: responder + ) + let yieldResult = emitInbound(exchange) + if case .dropped = yieldResult { + Self.logger.warning("exchanges buffer full, dropped inbound message; dispatcher is falling behind") + } + } + + private func handleDeleteMcp( + head: HttpRequestHead, + context: HttpConnectionContext, + clientAddress: MCPClientAddress + ) async { + guard pathMatchesMcp(head.path) else { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found", + httpStatus: .notFound + ), + requestId: nil + ) + return + } + + let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) + guard case .allow = authResult else { + if case .deny(let error) = authResult { + await respondTopLevel(context: context, error: error, requestId: nil) + } + return + } + + guard let raw = head.headers.value(for: "Mcp-Session-Id") else { + await respondTopLevel(context: context, error: .missingSessionId(), requestId: nil) + return + } + + let sessionId = MCPSessionId(raw) + guard await sessionStore.session(id: sessionId) != nil else { + await respondTopLevel(context: context, error: .sessionNotFound(), requestId: nil) + return + } + + await sessionStore.terminate(id: sessionId, reason: .clientRequested) + await context.writeNoContent() + await context.cancel() + } + + private func authenticate( + headers: HttpHeaders, + clientAddress: MCPClientAddress + ) async -> AuthResult { + let authHeader = headers.value(for: "Authorization") + let decision = await authenticator.authenticate( + authorizationHeader: authHeader, + clientAddress: clientAddress + ) + switch decision { + case .allow(let principal): + return .allow(principal) + case .deny(let reason): + let mcpError = mapDenialToProtocolError(reason) + return .deny(mcpError) + } + } + + private func mapDenialToProtocolError(_ reason: MCPAuthDenialReason) -> MCPProtocolError { + switch reason.httpStatus { + case 401: + if let challenge = reason.challenge { + if challenge.contains("invalid_token") { + if challenge.contains("token_expired") || challenge.contains("token expired") { + return .tokenExpired() + } + return .tokenInvalid() + } + return .unauthenticated(challenge: challenge) + } + return .unauthenticated() + case 403: + return .forbidden(reason: reason.logMessage) + case 429: + return .rateLimited(retryAfterSeconds: reason.retryAfterSeconds) + default: + return MCPProtocolError( + code: JsonRpcErrorCode.serverError, + message: reason.logMessage, + httpStatus: HttpStatus(code: reason.httpStatus, reasonPhrase: "Error"), + extraHeaders: reason.challenge.map { [("WWW-Authenticate", $0)] } ?? [] + ) + } + } + + private func respondTopLevel( + context: HttpConnectionContext, + error: MCPProtocolError, + requestId: JsonRpcId? + ) async { + let envelope = error.toJsonRpcErrorResponse(id: requestId) + let data = (try? JSONEncoder().encode(envelope)) ?? Data() + await context.writeJsonResponse( + data: data, + status: error.httpStatus, + sessionId: nil, + extraHeaders: error.extraHeaders + ) + await context.cancel() + } + + private func pathMatchesMcp(_ path: String) -> Bool { + let trimmed = stripQueryString(path) + return trimmed == "/mcp" || trimmed == "/mcp/" + } + + private static func protocolVersionMismatch( + session: MCPSession, + headerValue: String? + ) async -> MCPProtocolError? { + let state = await session.state + guard case .ready = state else { return nil } + guard let negotiated = await session.negotiatedProtocolVersion else { return nil } + guard let headerValue, !headerValue.isEmpty else { return nil } + if headerValue == negotiated { return nil } + return .invalidRequest( + detail: "MCP-Protocol-Version mismatch: client sent \(headerValue), session negotiated \(negotiated)" + ) + } + + private func stripQueryString(_ path: String) -> String { + if let questionIndex = path.firstIndex(of: "?") { + return String(path[path.startIndex.. JsonRpcId? { + switch message { + case .request(let request): + return request.id + case .successResponse(let response): + return response.id + case .errorResponse(let response): + return response.id + case .notification: + return nil + } + } + + private func extractMethod(from message: JsonRpcMessage) -> String? { + switch message { + case .request(let request): + return request.method + case .notification(let notification): + return notification.method + case .successResponse, .errorResponse: + return nil + } + } + + enum AuthResult { + case allow(MCPPrincipal) + case deny(MCPProtocolError) + } +} diff --git a/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift b/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift index 285b94516..fa09cefa2 100644 --- a/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift +++ b/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift @@ -21,6 +21,7 @@ public actor MCPHttpServerTransport { private var listener: NWListener? private var connections: [UUID: HttpConnectionContext] = [:] + private var sseWriters: [UUID: MCPSseWriter] = [:] private var sseConnectionsBySession: [MCPSessionId: UUID] = [:] private var sessionEventsTask: Task? @@ -63,7 +64,7 @@ public actor MCPHttpServerTransport { Self.logger.info("Starting MCP HTTP server: bind=\(String(describing: self.configuration.bindAddress)) port=\(self.configuration.port) tls=\(self.configuration.tls != nil)") if configuration.bindAddress == .anyInterface, configuration.tls == nil { - Self.logger.error("Remote access requested without TLS — refusing to start") + Self.logger.error("Remote access requested without TLS, refusing to start") throw MCPHttpServerError.tlsRequiredForRemoteAccess } @@ -100,6 +101,11 @@ public actor MCPHttpServerTransport { sessionEventsTask?.cancel() sessionEventsTask = nil + for (_, writer) in sseWriters { + await writer.stop() + } + sseWriters.removeAll() + for (_, context) in connections { await context.cancel() } @@ -125,14 +131,14 @@ public actor MCPHttpServerTransport { public func sendNotification(_ notification: JsonRpcNotification, toSession sessionId: MCPSessionId) async { guard let connectionId = sseConnectionsBySession[sessionId], - let context = connections[connectionId] else { + let writer = sseWriters[connectionId] else { return } let message = JsonRpcMessage.notification(notification) guard let data = try? JsonRpcCodec.encode(message), let text = String(data: data, encoding: .utf8) else { return } - await context.writeSseFrame(SseFrame(data: text)) + await writer.writeFrame(SseFrame(data: text)) } public func broadcastNotification(_ notification: JsonRpcNotification) async { @@ -210,8 +216,7 @@ public actor MCPHttpServerTransport { } private func handleSessionTerminated(_ sessionId: MCPSessionId, reason: MCPSessionTerminationReason) async { - guard let connectionId = sseConnectionsBySession.removeValue(forKey: sessionId), - let context = connections[connectionId] else { + guard let connectionId = sseConnectionsBySession.removeValue(forKey: sessionId) else { return } @@ -228,8 +233,14 @@ public actor MCPHttpServerTransport { case .capacityEvicted: comment = "capacity-evicted" } - await context.writeRaw(Data("\u{003A} \(comment)\n\n".utf8)) - await context.cancel() + + if let writer = sseWriters.removeValue(forKey: connectionId) { + await writer.writeComment(comment) + await writer.stop() + } else if let context = connections[connectionId] { + await context.writeRaw(Data("\u{003A} \(comment)\n\n".utf8)) + await context.cancel() + } connections.removeValue(forKey: connectionId) } @@ -238,41 +249,61 @@ public actor MCPHttpServerTransport { Self.logger.debug("Accepted connection \(connectionId, privacy: .public)") let context = HttpConnectionContext(id: connectionId, connection: connection) connections[connectionId] = context + let router = makeRouter() await context.start { [weak self] data in guard let self else { return } - await self.handleReceivedData(connectionId: connectionId, data: data) + await self.handleReceivedData(connectionId: connectionId, data: data, router: router) } onClosed: { [weak self] in guard let self else { return } await self.removeConnection(connectionId: connectionId) } } + private func makeRouter() -> MCPHttpRequestRouter { + let exchangesContinuation = self.exchangesContinuation + let transport = self + return MCPHttpRequestRouter( + configuration: configuration, + sessionStore: sessionStore, + authenticator: authenticator, + clock: clock, + emitInbound: { exchange in + exchangesContinuation.yield(exchange) + }, + startSse: { connectionId, sessionId, context in + await transport.attachSseWriter(connectionId: connectionId, sessionId: sessionId, context: context) + }, + makeResponderSink: { context in + TransportResponderSink(transport: transport, context: context) + } + ) + } + private func removeConnection(connectionId: UUID) async { connections.removeValue(forKey: connectionId) + if let writer = sseWriters.removeValue(forKey: connectionId) { + await writer.stop() + } let pairs = sseConnectionsBySession.filter { $0.value == connectionId } for (sessionId, _) in pairs { sseConnectionsBySession.removeValue(forKey: sessionId) } } - private func handleReceivedData(connectionId: UUID, data: Data) async { + private func handleReceivedData(connectionId: UUID, data: Data, router: MCPHttpRequestRouter) async { guard let context = connections[connectionId] else { return } let parseResult: HttpRequestParseResult do { parseResult = try HttpRequestParser.parse(data) } catch HttpRequestParseError.bodyTooLarge { - await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) + await respondParseFailure(context: context, status: .payloadTooLarge) return } catch HttpRequestParseError.headerTooLarge { - await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) + await respondParseFailure(context: context, status: .payloadTooLarge) return } catch { - await respondTopLevel( - context: context, - error: .invalidRequest(detail: "Malformed HTTP"), - requestId: nil - ) + await respondParseFailure(context: context, status: .badRequest, detail: "Malformed HTTP") return } @@ -281,418 +312,18 @@ public actor MCPHttpServerTransport { return case .complete(let head, let body, _): await context.markRequestComplete() - await dispatch(head: head, body: body, context: context) - } - } - - private func dispatch(head: HttpRequestHead, body: Data, context: HttpConnectionContext) async { - let clientAddress: MCPClientAddress = await context.clientAddress() - let now = await clock.now() - - await context.setOrigin(head.headers.value(for: "Origin")) - - if head.method == .post, stripQueryString(head.path) == "/v1/integrations/exchange" { - await handleIntegrationsExchange(body: body, context: context) - return - } - - switch head.method { - case .options: - await context.writeOptions204() - await context.cancel() - return - case .get: - await handleGetMcp(head: head, context: context, clientAddress: clientAddress) - case .post: - await handlePostMcp(head: head, body: body, context: context, clientAddress: clientAddress, now: now) - case .delete: - await handleDeleteMcp(head: head, context: context, clientAddress: clientAddress) - default: - await respondTopLevel( - context: context, - error: MCPProtocolError( - code: JsonRpcErrorCode.methodNotFound, - message: "Method not allowed", - httpStatus: .methodNotAllowed - ), - requestId: nil - ) - } - } - - private func handleIntegrationsExchange(body: Data, context: HttpConnectionContext) async { - struct ExchangeBody: Decodable { - let code: String - let codeVerifier: String - enum CodingKeys: String, CodingKey { - case code - case codeVerifier = "code_verifier" - } - } - struct ExchangeResponse: Encodable { - let token: String - } - - Self.logger.info("Integrations exchange request received (\(body.count, privacy: .public) bytes)") - let ip = Self.ipString(for: await context.clientAddress()) - - let parsed: ExchangeBody - do { - parsed = try JSONDecoder().decode(ExchangeBody.self, from: body) - } catch { - Self.logger.warning("Integrations exchange decode failed: \(error.localizedDescription, privacy: .public)") - MCPAuditLogger.logPairingExchange(outcome: .denied, ip: ip, details: "invalid JSON body") - await context.writePlainJsonError(status: .badRequest, message: "Invalid JSON body") - await context.cancel() - return - } - - guard !parsed.code.isEmpty, !parsed.codeVerifier.isEmpty else { - Self.logger.warning("Integrations exchange missing code or verifier") - MCPAuditLogger.logPairingExchange( - outcome: .denied, - ip: ip, - details: "missing code or code_verifier" - ) - await context.writePlainJsonError(status: .badRequest, message: "Missing code or code_verifier") - await context.cancel() - return - } - - guard parsed.code.utf8.count <= 1_024, parsed.codeVerifier.utf8.count <= 1_024 else { - Self.logger.warning("Integrations exchange field exceeds size cap") - MCPAuditLogger.logPairingExchange( - outcome: .denied, - ip: ip, - details: "field exceeds 1_024 bytes" - ) - await context.writePlainJsonError(status: .badRequest, message: "Field exceeds size limit") - await context.cancel() - return - } - - let exchange = PairingExchange(code: parsed.code, verifier: parsed.codeVerifier) - let outcome: Result = await MainActor.run { - do { - return .success(try MCPPairingService.shared.exchange(exchange)) - } catch { - return .failure(error) - } - } - - switch outcome { - case .success(let token): - Self.logger.info("Integrations exchange succeeded (token len=\(token.count, privacy: .public))") - let label = await Self.resolveTokenLabel(for: token) - MCPAuditLogger.logPairingExchange(outcome: .success, tokenName: label, ip: ip) - let payload = (try? JSONEncoder().encode(ExchangeResponse(token: token))) ?? Data() - await context.writePlainJsonResponse(status: .ok, body: payload) - await context.cancel() - case .failure(let error): - let mapped = Self.mapExchangeError(error) - Self.logger.warning("Integrations exchange failed: status=\(mapped.status.code, privacy: .public) reason=\(mapped.message, privacy: .public)") - MCPAuditLogger.logPairingExchange( - outcome: .denied, - ip: ip, - details: mapped.message - ) - await context.writePlainJsonError(status: mapped.status, message: mapped.message) - await context.cancel() - } - } - - private static func ipString(for address: MCPClientAddress) -> String { - switch address { - case .loopback: - return "127.0.0.1" - case .remote(let host): - return host - } - } - - private static func resolveTokenLabel(for plaintext: String) async -> String? { - let store: MCPTokenStore? = await MainActor.run { MCPServerManager.shared.tokenStore } - guard let store else { return nil } - return await store.validate(bearerToken: plaintext)?.name - } - - private static func mapExchangeError(_ error: Error) -> (status: HttpStatus, message: String) { - guard let domainError = error as? MCPDataLayerError else { - return (.internalServerError, "Internal error") - } - switch domainError { - case .notFound: - return (.notFound, "Pairing code not found") - case .expired: - return (HttpStatus(code: 410, reasonPhrase: "Gone"), "Pairing code expired") - case .forbidden: - return (.forbidden, "Challenge mismatch") - default: - return (.internalServerError, "Internal error") + await router.dispatch(head: head, body: body, context: context) } } - private func handleGetMcp( - head: HttpRequestHead, - context: HttpConnectionContext, - clientAddress: MCPClientAddress - ) async { - guard pathMatchesMcp(head.path) else { - await respondTopLevel( - context: context, - error: MCPProtocolError( - code: JsonRpcErrorCode.methodNotFound, - message: "Method not found", - httpStatus: .notFound - ), - requestId: nil - ) - return - } - - guard let sessionIdRaw = head.headers.value(for: "Mcp-Session-Id") else { - await respondTopLevel(context: context, error: .missingSessionId(), requestId: nil) - return - } - - if head.headers.value(for: "Last-Event-ID") != nil { - await respondTopLevel( - context: context, - error: MCPProtocolError( - code: JsonRpcErrorCode.serverError, - message: "SSE event replay is not supported", - httpStatus: .notImplemented - ), - requestId: nil - ) - return - } - - if let accept = head.headers.value(for: "Accept"), - !accept.lowercased().contains("text/event-stream"), - !accept.contains("*/*") { - await respondTopLevel(context: context, error: .notAcceptable(), requestId: nil) - return - } - - let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) - guard case .allow = authResult else { - if case .deny(let error) = authResult { - await respondTopLevel(context: context, error: error, requestId: nil) - } - return - } - - let sessionId = MCPSessionId(sessionIdRaw) - guard await sessionStore.session(id: sessionId) != nil else { - await respondTopLevel(context: context, error: .sessionNotFound(), requestId: nil) - return - } - - await sessionStore.touch(id: sessionId) - - registerSseConnection(connectionId: context.id, sessionId: sessionId) - await context.writeSseStreamHeaders(sessionId: sessionId) - Self.logger.info("Registered SSE notification stream for session \(sessionId.rawValue, privacy: .public)") - } - - private func handlePostMcp( - head: HttpRequestHead, - body: Data, - context: HttpConnectionContext, - clientAddress: MCPClientAddress, - now: Date - ) async { - guard pathMatchesMcp(head.path) else { - await respondTopLevel( - context: context, - error: MCPProtocolError( - code: JsonRpcErrorCode.methodNotFound, - message: "Method not found", - httpStatus: .notFound - ), - requestId: nil - ) - return - } - - if body.count > configuration.limits.maxRequestBodyBytes { - await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) - return - } - - let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) - guard case .allow(let principal) = authResult else { - if case .deny(let error) = authResult { - await respondTopLevel(context: context, error: error, requestId: nil) - } - return - } - - let message: JsonRpcMessage - do { - message = try JsonRpcCodec.decode(body) - } catch { - await respondTopLevel( - context: context, - error: .parseError(detail: String(describing: error)), - requestId: nil - ) - return - } - - let requestId = extractRequestId(from: message) - let methodName = extractMethod(from: message) - let mcpProtocolVersion = head.headers.value(for: "mcp-protocol-version") - - let sessionId: MCPSessionId? - if methodName == "initialize" { - do { - let session = try await sessionStore.create() - sessionId = session.id - } catch { - await respondTopLevel( - context: context, - error: .serviceUnavailable(), - requestId: requestId - ) - return - } + private func respondParseFailure(context: HttpConnectionContext, status: HttpStatus, detail: String? = nil) async { + let error: MCPProtocolError + if status.code == HttpStatus.payloadTooLarge.code { + error = .payloadTooLarge() } else { - guard let raw = head.headers.value(for: "Mcp-Session-Id") else { - await respondTopLevel(context: context, error: .missingSessionId(), requestId: requestId) - return - } - let candidate = MCPSessionId(raw) - guard let session = await sessionStore.session(id: candidate) else { - await respondTopLevel(context: context, error: .sessionNotFound(), requestId: requestId) - return - } - if let mismatch = await Self.protocolVersionMismatch( - session: session, - headerValue: mcpProtocolVersion - ) { - await respondTopLevel(context: context, error: mismatch, requestId: requestId) - return - } - sessionId = candidate - await sessionStore.touch(id: candidate) + error = .invalidRequest(detail: detail ?? "Bad request") } - - let sink = TransportResponderSink(transport: self, context: context) - let responder = MCPExchangeResponder(sink: sink, requestId: requestId) - - let exchangeContext = MCPInboundContext( - sessionId: sessionId, - principal: principal, - clientAddress: clientAddress, - receivedAt: now, - mcpProtocolVersion: mcpProtocolVersion - ) - let exchange = MCPInboundExchange( - message: message, - context: exchangeContext, - responder: responder - ) - let yieldResult = exchangesContinuation.yield(exchange) - if case .dropped = yieldResult { - Self.logger.warning("exchanges buffer full, dropped inbound message — dispatcher is falling behind") - } - } - - private func handleDeleteMcp( - head: HttpRequestHead, - context: HttpConnectionContext, - clientAddress: MCPClientAddress - ) async { - guard pathMatchesMcp(head.path) else { - await respondTopLevel( - context: context, - error: MCPProtocolError( - code: JsonRpcErrorCode.methodNotFound, - message: "Method not found", - httpStatus: .notFound - ), - requestId: nil - ) - return - } - - let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) - guard case .allow = authResult else { - if case .deny(let error) = authResult { - await respondTopLevel(context: context, error: error, requestId: nil) - } - return - } - - guard let raw = head.headers.value(for: "Mcp-Session-Id") else { - await respondTopLevel(context: context, error: .missingSessionId(), requestId: nil) - return - } - - let sessionId = MCPSessionId(raw) - guard await sessionStore.session(id: sessionId) != nil else { - await respondTopLevel(context: context, error: .sessionNotFound(), requestId: nil) - return - } - - await sessionStore.terminate(id: sessionId, reason: .clientRequested) - await context.writeNoContent() - await context.cancel() - } - - private func authenticate( - headers: HttpHeaders, - clientAddress: MCPClientAddress - ) async -> AuthResult { - let authHeader = headers.value(for: "Authorization") - let decision = await authenticator.authenticate( - authorizationHeader: authHeader, - clientAddress: clientAddress - ) - switch decision { - case .allow(let principal): - return .allow(principal) - case .deny(let reason): - let mcpError = mapDenialToProtocolError(reason) - return .deny(mcpError) - } - } - - private func mapDenialToProtocolError(_ reason: MCPAuthDenialReason) -> MCPProtocolError { - switch reason.httpStatus { - case 401: - if let challenge = reason.challenge { - if challenge.contains("invalid_token") { - if challenge.contains("token_expired") || challenge.contains("token expired") { - return .tokenExpired() - } - return .tokenInvalid() - } - return .unauthenticated(challenge: challenge) - } - return .unauthenticated() - case 403: - return .forbidden(reason: reason.logMessage) - case 429: - return .rateLimited(retryAfterSeconds: reason.retryAfterSeconds) - default: - return MCPProtocolError( - code: JsonRpcErrorCode.serverError, - message: reason.logMessage, - httpStatus: HttpStatus(code: reason.httpStatus, reasonPhrase: "Error"), - extraHeaders: reason.challenge.map { [("WWW-Authenticate", $0)] } ?? [] - ) - } - } - - private func respondTopLevel( - context: HttpConnectionContext, - error: MCPProtocolError, - requestId: JsonRpcId? - ) async { - let envelope = error.toJsonRpcErrorResponse(id: requestId) + let envelope = error.toJsonRpcErrorResponse(id: nil) let data = (try? JSONEncoder().encode(envelope)) ?? Data() await context.writeJsonResponse( data: data, @@ -703,306 +334,28 @@ public actor MCPHttpServerTransport { await context.cancel() } - private func pathMatchesMcp(_ path: String) -> Bool { - let trimmed = stripQueryString(path) - return trimmed == "/mcp" || trimmed == "/mcp/" - } - - private static func protocolVersionMismatch( - session: MCPSession, - headerValue: String? - ) async -> MCPProtocolError? { - let state = await session.state - guard case .ready = state else { return nil } - guard let negotiated = await session.negotiatedProtocolVersion else { return nil } - guard let headerValue, !headerValue.isEmpty else { return nil } - if headerValue == negotiated { return nil } - return .invalidRequest( - detail: "MCP-Protocol-Version mismatch: client sent \(headerValue), session negotiated \(negotiated)" - ) - } - - private func stripQueryString(_ path: String) -> String { - if let questionIndex = path.firstIndex(of: "?") { - return String(path[path.startIndex.. JsonRpcId? { - switch message { - case .request(let request): - return request.id - case .successResponse(let response): - return response.id - case .errorResponse(let response): - return response.id - case .notification: - return nil - } - } - - private func extractMethod(from message: JsonRpcMessage) -> String? { - switch message { - case .request(let request): - return request.method - case .notification(let notification): - return notification.method - case .successResponse, .errorResponse: - return nil - } - } - - fileprivate func registerSseConnection(connectionId: UUID, sessionId: MCPSessionId) { - if let previous = sseConnectionsBySession[sessionId], previous != connectionId, - let oldContext = connections[previous] { - Task { await oldContext.cancel() } + fileprivate func attachSseWriter( + connectionId: UUID, + sessionId: MCPSessionId, + context: HttpConnectionContext + ) async { + if let previous = sseConnectionsBySession[sessionId], previous != connectionId { + if let oldWriter = sseWriters.removeValue(forKey: previous) { + await oldWriter.stop() + } else if let oldContext = connections[previous] { + await oldContext.cancel() + } connections.removeValue(forKey: previous) } + let writer = MCPSseWriter(context: context) + sseWriters[connectionId] = writer sseConnectionsBySession[sessionId] = connectionId + await writer.startStream(sessionId: sessionId) } - private enum AuthResult { - case allow(MCPPrincipal) - case deny(MCPProtocolError) - } -} - -actor HttpConnectionContext { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.HttpServer") - - nonisolated let id: UUID - private let connection: NWConnection - private var receiveBuffer = Data() - private var requestComplete = false - private var cancelled = false - private var sseActive = false - private var origin: String? - - init(id: UUID, connection: NWConnection) { - self.id = id - self.connection = connection - } - - func setOrigin(_ value: String?) { - origin = value - } - - private func corsHeaders() -> [(String, String)] { - MCPCorsHeaders.headers(forOrigin: origin) - } - - func start( - onData: @escaping @Sendable (Data) async -> Void, - onClosed: @escaping @Sendable () async -> Void - ) { - let nwConnection = connection - nwConnection.stateUpdateHandler = { [weak self] state in - guard let self else { return } - switch state { - case .ready: - Task { await self.beginReading(onData: onData, onClosed: onClosed) } - case .failed: - Task { await self.handleClosed(onClosed: onClosed) } - case .cancelled: - Task { await self.handleClosed(onClosed: onClosed) } - default: - break - } - } - nwConnection.start(queue: .global(qos: .userInitiated)) - } - - private func beginReading( - onData: @escaping @Sendable (Data) async -> Void, - onClosed: @escaping @Sendable () async -> Void - ) { - scheduleReceive(onData: onData, onClosed: onClosed) - } - - private func scheduleReceive( - onData: @escaping @Sendable (Data) async -> Void, - onClosed: @escaping @Sendable () async -> Void - ) { - if cancelled || requestComplete { return } - connection.receive(minimumIncompleteLength: 1, maximumLength: 65_536) { [weak self] content, _, isComplete, error in - guard let self else { return } - Task { - await self.handleReceive( - content: content, - isComplete: isComplete, - error: error, - onData: onData, - onClosed: onClosed - ) - } - } - } - - private func handleReceive( - content: Data?, - isComplete: Bool, - error: NWError?, - onData: @escaping @Sendable (Data) async -> Void, - onClosed: @escaping @Sendable () async -> Void - ) async { - if let error { - Self.logger.debug("Receive error: \(error.localizedDescription, privacy: .public)") - cancel() - await onClosed() - return - } - - if let content { - receiveBuffer.append(content) - await onData(receiveBuffer) - } - - if isComplete { - cancel() - await onClosed() - return - } - - if !requestComplete, !cancelled { - scheduleReceive(onData: onData, onClosed: onClosed) - } - } - - private func handleClosed(onClosed: @escaping @Sendable () async -> Void) async { - if !cancelled { - cancelled = true - } - await onClosed() - } - - func markRequestComplete() { - requestComplete = true - } - - func clientAddress() -> MCPClientAddress { - guard let endpoint = connection.currentPath?.remoteEndpoint, - case .hostPort(let host, _) = endpoint else { - return .loopback - } - let hostString = "\(host)" - if hostString == "127.0.0.1" || hostString == "::1" || hostString.lowercased() == "localhost" { - return .loopback - } - return .remote(hostString) - } - - func writeJsonResponse( - data: Data, - status: HttpStatus, - sessionId: MCPSessionId?, - extraHeaders: [(String, String)] - ) async { - if cancelled { return } - var headers: [(String, String)] = [ - ("Content-Type", "application/json"), - ("Connection", "close") - ] - if let sessionId { - headers.append(("Mcp-Session-Id", sessionId.rawValue)) - } - headers.append(contentsOf: extraHeaders) - headers.append(contentsOf: self.corsHeaders()) - let head = HttpResponseHead(status: status, headers: HttpHeaders(headers)) - let payload = HttpResponseEncoder.encode(head, body: data) - await send(payload) - } - - func writePlainJsonResponse(status: HttpStatus, body: Data) async { - if cancelled { return } - var headers: [(String, String)] = [ - ("Content-Type", "application/json"), - ("Connection", "close") - ] - headers.append(contentsOf: self.corsHeaders()) - let head = HttpResponseHead(status: status, headers: HttpHeaders(headers)) - let payload = HttpResponseEncoder.encode(head, body: body) - await send(payload) - } - - func writePlainJsonError(status: HttpStatus, message: String) async { - struct ErrorBody: Encodable { let error: String } - let payload = (try? JSONEncoder().encode(ErrorBody(error: message))) ?? Data() - await writePlainJsonResponse(status: status, body: payload) - } - - func writeOptions204() async { - if cancelled { return } - var headers: [(String, String)] = [("Connection", "close")] - headers.append(contentsOf: self.corsHeaders()) - let head = HttpResponseHead(status: .noContent, headers: HttpHeaders(headers)) - let payload = HttpResponseEncoder.encode(head, body: nil) - await send(payload) - } - - func writeNoContent() async { - if cancelled { return } - var headers: [(String, String)] = [("Connection", "close")] - headers.append(contentsOf: self.corsHeaders()) - let head = HttpResponseHead(status: .noContent, headers: HttpHeaders(headers)) - let payload = HttpResponseEncoder.encode(head, body: nil) - await send(payload) - } - - func writeAccepted() async { - if cancelled { return } - var headers: [(String, String)] = [("Connection", "close")] - headers.append(contentsOf: self.corsHeaders()) - let head = HttpResponseHead(status: .accepted, headers: HttpHeaders(headers)) - let payload = HttpResponseEncoder.encode(head, body: nil) - await send(payload) - } - - func writeSseStreamHeaders(sessionId: MCPSessionId) async { - if cancelled { return } - sseActive = true - var headers: [(String, String)] = [ - ("Content-Type", "text/event-stream"), - ("Cache-Control", "no-cache"), - ("Connection", "keep-alive"), - ("Mcp-Session-Id", sessionId.rawValue) - ] - headers.append(contentsOf: self.corsHeaders()) - let head = HttpResponseHead(status: .ok, headers: HttpHeaders(headers)) - let payload = HttpResponseEncoder.encode(head, body: nil) - await send(payload) - } - - func writeSseFrame(_ frame: SseFrame) async { - if cancelled { return } - let data = SseEncoder.encode(frame) - await send(data) - } - - func writeRaw(_ data: Data) async { - if cancelled { return } - await send(data) - } - - func cancel() { - if cancelled { return } - cancelled = true - connection.cancel() - } - - func isSseActive() -> Bool { - sseActive - } - - private func send(_ data: Data) async { - await withCheckedContinuation { (continuation: CheckedContinuation) in - connection.send(content: data, completion: .contentProcessed { error in - if let error { - Self.logger.debug("Send error: \(error.localizedDescription, privacy: .public)") - } - continuation.resume() - }) - } + fileprivate func registerSseConnection(connectionId: UUID, sessionId: MCPSessionId) async { + guard let context = connections[connectionId] else { return } + await attachSseWriter(connectionId: connectionId, sessionId: sessionId, context: context) } } diff --git a/TablePro/Core/MCP/Transport/MCPSseWriter.swift b/TablePro/Core/MCP/Transport/MCPSseWriter.swift new file mode 100644 index 000000000..af991347a --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPSseWriter.swift @@ -0,0 +1,61 @@ +import Foundation +import os + +actor MCPSseWriter { + static let keepAliveInterval: Duration = .seconds(30) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.SseWriter") + + private let context: HttpConnectionContext + private var keepAliveTask: Task? + private var stopped = false + + init(context: HttpConnectionContext) { + self.context = context + } + + func startStream(sessionId: MCPSessionId) async { + await context.writeSseStreamHeaders(sessionId: sessionId) + startKeepAlive() + } + + func writeFrame(_ frame: SseFrame) async { + guard !stopped else { return } + await context.writeSseFrame(frame) + } + + func writeComment(_ text: String) async { + guard !stopped else { return } + await context.writeRaw(Data("\u{003A} \(text)\n\n".utf8)) + } + + func stop() async { + if stopped { return } + stopped = true + keepAliveTask?.cancel() + keepAliveTask = nil + await context.cancel() + } + + private func startKeepAlive() { + keepAliveTask?.cancel() + keepAliveTask = Task { [weak self] in + while !Task.isCancelled { + try? await Task.sleep(for: Self.keepAliveInterval) + guard !Task.isCancelled, let self else { return } + await self.emitKeepAlive() + } + } + } + + private func emitKeepAlive() async { + guard !stopped else { return } + if await context.isCancelled() { + keepAliveTask?.cancel() + keepAliveTask = nil + stopped = true + return + } + await context.writeRaw(Data("\u{003A} keep-alive\n\n".utf8)) + } +} diff --git a/TableProTests/Core/MCP/MCPPairingServiceTests.swift b/TableProTests/Core/MCP/MCPPairingServiceTests.swift index 1ac1ccb61..2d375e90e 100644 --- a/TableProTests/Core/MCP/MCPPairingServiceTests.swift +++ b/TableProTests/Core/MCP/MCPPairingServiceTests.swift @@ -23,49 +23,50 @@ struct MCPPairingServiceTests { } @Test("consume returns stored token when challenge and verifier match") - func consumeReturnsTokenForValidVerifier() throws { + func consumeReturnsTokenForValidVerifier() async throws { let verifier = "test-verifier-1" let challenge = base64UrlSha256(of: verifier) let store = makeStore() - try store.insert(code: "code-1", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) + try await store.insert(code: "code-1", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) - let token = try store.consume(code: "code-1", verifier: verifier) + let token = try await store.consume(code: "code-1", verifier: verifier) #expect(token == "tp_secret") } @Test("consume removes the entry after success (single-use)") - func consumeIsSingleUse() throws { + func consumeIsSingleUse() async throws { let verifier = "test-verifier-2" let challenge = base64UrlSha256(of: verifier) let store = makeStore() - try store.insert(code: "code-2", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) + try await store.insert(code: "code-2", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) - _ = try store.consume(code: "code-2", verifier: verifier) + _ = try await store.consume(code: "code-2", verifier: verifier) - #expect(store.contains(code: "code-2") == false) + let contains = await store.contains(code: "code-2") + #expect(contains == false) } @Test("second consume of the same code returns notFound") - func duplicateConsumeReturnsNotFound() throws { + func duplicateConsumeReturnsNotFound() async throws { let verifier = "test-verifier-3" let challenge = base64UrlSha256(of: verifier) let store = makeStore() - try store.insert(code: "code-3", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) + try await store.insert(code: "code-3", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) - _ = try store.consume(code: "code-3", verifier: verifier) + _ = try await store.consume(code: "code-3", verifier: verifier) - #expect(throws: MCPDataLayerError.self) { - try store.consume(code: "code-3", verifier: verifier) + await #expect(throws: MCPDataLayerError.self) { + try await store.consume(code: "code-3", verifier: verifier) } } @Test("consume returns notFound for unknown code") - func consumeUnknownCodeReturnsNotFound() { + func consumeUnknownCodeReturnsNotFound() async { let store = makeStore() do { - _ = try store.consume(code: "missing", verifier: "any") + _ = try await store.consume(code: "missing", verifier: "any") Issue.record("Expected notFound error") } catch let error as MCPDataLayerError { guard case .notFound = error else { @@ -78,14 +79,14 @@ struct MCPPairingServiceTests { } @Test("consume returns expired when entry has expired") - func consumeExpiredEntryReturnsExpired() throws { + func consumeExpiredEntryReturnsExpired() async throws { let verifier = "test-verifier-4" let challenge = base64UrlSha256(of: verifier) let store = makeStore() - try store.insert(code: "code-4", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: -1)) + try await store.insert(code: "code-4", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: -1)) do { - _ = try store.consume(code: "code-4", verifier: verifier, now: Date.now) + _ = try await store.consume(code: "code-4", verifier: verifier, now: Date.now) Issue.record("Expected expired error") } catch let error as MCPDataLayerError { guard case .expired = error else { @@ -98,13 +99,13 @@ struct MCPPairingServiceTests { } @Test("consume returns forbidden when challenge does not match the verifier") - func consumeMismatchedChallengeReturnsForbidden() throws { + func consumeMismatchedChallengeReturnsForbidden() async throws { let store = makeStore() let challenge = base64UrlSha256(of: "intended-verifier") - try store.insert(code: "code-5", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) + try await store.insert(code: "code-5", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: 60)) do { - _ = try store.consume(code: "code-5", verifier: "attacker-verifier") + _ = try await store.consume(code: "code-5", verifier: "attacker-verifier") Issue.record("Expected forbidden error") } catch let error as MCPDataLayerError { guard case .forbidden = error else { @@ -117,39 +118,44 @@ struct MCPPairingServiceTests { } @Test("consume on expired code removes the entry") - func consumeOnExpiredCodeRemovesEntry() throws { + func consumeOnExpiredCodeRemovesEntry() async throws { let verifier = "test-verifier-6" let challenge = base64UrlSha256(of: verifier) let store = makeStore() - try store.insert(code: "code-6", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: -1)) + try await store.insert(code: "code-6", record: record(plaintext: "tp_secret", challenge: challenge, expiresIn: -1)) - _ = try? store.consume(code: "code-6", verifier: verifier) + _ = try? await store.consume(code: "code-6", verifier: verifier) - #expect(store.contains(code: "code-6") == false) + let contains = await store.contains(code: "code-6") + #expect(contains == false) } @Test("pruneExpired removes only expired entries") - func pruneRemovesOnlyExpiredEntries() throws { + func pruneRemovesOnlyExpiredEntries() async throws { let store = makeStore() - try store.insert( + try await store.insert( code: "alive", record: record(plaintext: "tp_a", challenge: "challenge", expiresIn: 60) ) - try store.insert( + try await store.insert( code: "stale-1", record: record(plaintext: "tp_b", challenge: "challenge", expiresIn: -1) ) - try store.insert( + try await store.insert( code: "stale-2", record: record(plaintext: "tp_c", challenge: "challenge", expiresIn: -10) ) - store.pruneExpired() + await store.pruneExpired() - #expect(store.count() == 1) - #expect(store.contains(code: "alive")) - #expect(store.contains(code: "stale-1") == false) - #expect(store.contains(code: "stale-2") == false) + let count = await store.count() + let containsAlive = await store.contains(code: "alive") + let containsStale1 = await store.contains(code: "stale-1") + let containsStale2 = await store.contains(code: "stale-2") + #expect(count == 1) + #expect(containsAlive) + #expect(containsStale1 == false) + #expect(containsStale2 == false) } @Test("sha256Base64Url matches CryptoKit output without padding") @@ -180,17 +186,17 @@ struct MCPPairingServiceTests { } @Test("insert throws after maxPendingCodes consecutive inserts") - func insertThrowsWhenPendingCapReached() throws { + func insertThrowsWhenPendingCapReached() async throws { let store = makeStore() for index in 0.. String { From d0c38d82e13436e8a8d158243a0d676952acc6fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:21:39 +0700 Subject: [PATCH 05/16] docs(changelog): log AI provider extraction --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b306b3e8d..3d3816e49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Copilot LSP shutdown caps at 10 seconds, closes pipes explicitly, and strips the quarantine attribute from the downloaded binary - AI Chat: streaming view model split into focused extensions backed by a single `streamingState` enum - MCP HTTP server: split transport into connection, router, and SSE writer files; pairing exchange store moved to a Swift actor; SSE streams send a 30-second keep-alive +- AI providers: shared endpoint normalization and JSON encoding helpers; consistent 5s timeout and known-model fallback when listing models ## [0.39.1] - 2026-05-08 From e0a6c13377f7f07d4e92b0fd6caa0f87a312f944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:27:52 +0700 Subject: [PATCH 06/16] build(ai-chat): fix imports and visibility for AIChatViewModel split --- TablePro/ViewModels/AIChatViewModel+SchemaContext.swift | 1 + TablePro/ViewModels/AIChatViewModel+Streaming.swift | 7 ++++--- TablePro/ViewModels/AIChatViewModel.swift | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift b/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift index f76b1982c..507aebe80 100644 --- a/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift +++ b/TablePro/ViewModels/AIChatViewModel+SchemaContext.swift @@ -4,6 +4,7 @@ // import Foundation +import os import TableProPluginKit extension AIChatViewModel { diff --git a/TablePro/ViewModels/AIChatViewModel+Streaming.swift b/TablePro/ViewModels/AIChatViewModel+Streaming.swift index e39fc9a06..42839c85e 100644 --- a/TablePro/ViewModels/AIChatViewModel+Streaming.swift +++ b/TablePro/ViewModels/AIChatViewModel+Streaming.swift @@ -4,6 +4,7 @@ // import Foundation +import os import TableProPluginKit extension AIChatViewModel { @@ -426,7 +427,7 @@ extension AIChatViewModel { return ToolResultBlock(toolUseId: block.id, content: "Cancelled", isError: true) } guard ChatToolRegistry.isToolAllowed(name: block.name, in: mode) else { - logger.warning( + AIChatViewModel.logger.warning( "Tool '\(block.name, privacy: .public)' blocked in \(mode.rawValue, privacy: .public) mode" ) return ToolResultBlock( @@ -439,7 +440,7 @@ extension AIChatViewModel { (registry ?? ChatToolRegistry.shared).tool(named: block.name, in: mode) } guard let tool else { - logger.warning("Tool '\(block.name, privacy: .public)' not registered; returning error") + AIChatViewModel.logger.warning("Tool '\(block.name, privacy: .public)' not registered; returning error") return ToolResultBlock( toolUseId: block.id, content: "Tool '\(block.name)' is not available", @@ -454,7 +455,7 @@ extension AIChatViewModel { isError: result.isError ) } catch { - logger.warning( + AIChatViewModel.logger.warning( "Tool \(block.name, privacy: .public) execution failed: \(error.localizedDescription, privacy: .public)" ) return ToolResultBlock( diff --git a/TablePro/ViewModels/AIChatViewModel.swift b/TablePro/ViewModels/AIChatViewModel.swift index 99c51ba3d..a89a2009b 100644 --- a/TablePro/ViewModels/AIChatViewModel.swift +++ b/TablePro/ViewModels/AIChatViewModel.swift @@ -22,7 +22,7 @@ final class AIChatViewModel { var messages: [ChatTurn] = [] var inputText: String = "" - private(set) var streamingState: StreamingState = .idle + var streamingState: StreamingState = .idle var errorMessage: String? var conversations: [AIConversation] = [] var activeConversationID: UUID? From 456be0407003d93b4187b9e67a0713c7dcea79c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:28:20 +0700 Subject: [PATCH 07/16] refactor(ai-chat): unify JSONValue across chat layer and MCP wire --- TablePro/Core/AI/AnthropicProvider.swift | 4 +- TablePro/Core/AI/Chat/ChatTool.swift | 4 +- .../AI/Chat/ChatToolArgumentDecoder.swift | 18 ++--- .../Core/AI/Chat/ChatToolJSONFormatter.swift | 17 ----- .../Core/AI/Chat/ChatToolSpec+Copilot.swift | 2 +- TablePro/Core/AI/Chat/ChatTransport.swift | 2 +- TablePro/Core/AI/Chat/ChatTurn.swift | 6 +- TablePro/Core/AI/Chat/JSONValue.swift | 75 ------------------- .../ConfirmDestructiveOperationChatTool.swift | 6 +- .../AI/Chat/Tools/DescribeTableChatTool.swift | 6 +- .../AI/Chat/Tools/ExecuteQueryChatTool.swift | 6 +- .../Tools/GetConnectionStatusChatTool.swift | 8 +- .../AI/Chat/Tools/GetTableDDLChatTool.swift | 6 +- .../Chat/Tools/ListConnectionsChatTool.swift | 6 +- .../AI/Chat/Tools/ListDatabasesChatTool.swift | 6 +- .../AI/Chat/Tools/ListSchemasChatTool.swift | 6 +- .../AI/Chat/Tools/ListTablesChatTool.swift | 6 +- TablePro/Core/AI/GeminiProvider.swift | 4 +- TablePro/Core/AI/JSONValue+Encoding.swift | 22 ------ .../Core/AI/OpenAICompatibleProvider.swift | 4 +- TablePro/Core/LSP/LSPTypes.swift | 6 +- TablePro/Core/MCP/Wire/JsonValue.swift | 18 +++++ .../AIChatViewModel+Streaming.swift | 4 +- .../AI/ChatToolArgumentDecoderTests.swift | 26 +++---- .../Core/AI/ChatToolRegistryModeTests.swift | 4 +- .../Core/AI/ChatToolRegistryTests.swift | 4 +- .../Core/AI/ExecuteToolUsesTests.swift | 14 ++-- 27 files changed, 97 insertions(+), 193 deletions(-) delete mode 100644 TablePro/Core/AI/Chat/ChatToolJSONFormatter.swift delete mode 100644 TablePro/Core/AI/Chat/JSONValue.swift delete mode 100644 TablePro/Core/AI/JSONValue+Encoding.swift diff --git a/TablePro/Core/AI/AnthropicProvider.swift b/TablePro/Core/AI/AnthropicProvider.swift index 8a9be6f9c..4b191d21c 100644 --- a/TablePro/Core/AI/AnthropicProvider.swift +++ b/TablePro/Core/AI/AnthropicProvider.swift @@ -249,7 +249,7 @@ final class AnthropicProvider: ChatTransport { [ "name": spec.name, "description": spec.description, - "input_schema": try spec.inputSchema.asJSONObject() + "input_schema": try spec.inputSchema.jsonObject() ] } @@ -285,7 +285,7 @@ final class AnthropicProvider: ChatTransport { "type": "tool_use", "id": toolUse.id, "name": toolUse.name, - "input": try toolUse.input.asJSONObject() + "input": try toolUse.input.jsonObject() ] case .toolResult(let result): var encoded: [String: Any] = [ diff --git a/TablePro/Core/AI/Chat/ChatTool.swift b/TablePro/Core/AI/Chat/ChatTool.swift index f3ab31c02..959a72379 100644 --- a/TablePro/Core/AI/Chat/ChatTool.swift +++ b/TablePro/Core/AI/Chat/ChatTool.swift @@ -10,9 +10,9 @@ import Foundation protocol ChatTool: Sendable { var name: String { get } var description: String { get } - var inputSchema: JSONValue { get } + var inputSchema: JsonValue { get } - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult } struct ChatToolResult: Sendable, Equatable, Codable { diff --git a/TablePro/Core/AI/Chat/ChatToolArgumentDecoder.swift b/TablePro/Core/AI/Chat/ChatToolArgumentDecoder.swift index d9479e15d..1fe9b145c 100644 --- a/TablePro/Core/AI/Chat/ChatToolArgumentDecoder.swift +++ b/TablePro/Core/AI/Chat/ChatToolArgumentDecoder.swift @@ -5,25 +5,25 @@ import Foundation -/// Typed decoders for `JSONValue` input arguments coming from the AI. +/// Typed decoders for `JsonValue` input arguments coming from the AI. /// Mirrors `MCPArgumentDecoder` for the MCP protocol but operates on the -/// chat-side `JSONValue` enum. +/// chat-side `JsonValue` enum. enum ChatToolArgumentDecoder { - static func requireString(_ args: JSONValue, key: String) throws -> String { + static func requireString(_ args: JsonValue, key: String) throws -> String { guard case .object(let dict) = args, let value = dict[key], case .string(let str) = value else { throw ChatToolArgumentError.missingOrInvalid(key: key, expected: "string") } return str } - static func optionalString(_ args: JSONValue, key: String) -> String? { + static func optionalString(_ args: JsonValue, key: String) -> String? { guard case .object(let dict) = args, let value = dict[key], case .string(let str) = value else { return nil } return str } - static func requireUUID(_ args: JSONValue, key: String) throws -> UUID { + static func requireUUID(_ args: JsonValue, key: String) throws -> UUID { let str = try requireString(args, key: key) guard let uuid = UUID(uuidString: str) else { throw ChatToolArgumentError.missingOrInvalid(key: key, expected: "UUID string") @@ -31,7 +31,7 @@ enum ChatToolArgumentDecoder { return uuid } - static func optionalBool(_ args: JSONValue, key: String, default fallback: Bool = false) -> Bool { + static func optionalBool(_ args: JsonValue, key: String, default fallback: Bool = false) -> Bool { guard case .object(let dict) = args, let value = dict[key], case .bool(let bool) = value else { return fallback } @@ -39,7 +39,7 @@ enum ChatToolArgumentDecoder { } static func optionalInt( - _ args: JSONValue, + _ args: JsonValue, key: String, default fallback: Int, clamp: ClosedRange? = nil @@ -47,8 +47,8 @@ enum ChatToolArgumentDecoder { guard case .object(let dict) = args, let value = dict[key] else { return fallback } let raw: Int? switch value { - case .integer(let int): raw = Int(int) - case .number(let double): raw = Int(double) + case .int(let int): raw = int + case .double(let double): raw = Int(double) default: raw = nil } guard let raw else { return fallback } diff --git a/TablePro/Core/AI/Chat/ChatToolJSONFormatter.swift b/TablePro/Core/AI/Chat/ChatToolJSONFormatter.swift deleted file mode 100644 index a433285f3..000000000 --- a/TablePro/Core/AI/Chat/ChatToolJSONFormatter.swift +++ /dev/null @@ -1,17 +0,0 @@ -// -// ChatToolJSONFormatter.swift -// TablePro -// - -import Foundation - -/// JSON-encode a `JsonValue` (MCP wire type) as a string for inclusion in a -/// `ChatToolResult`. The chat layer needs strings; MCP bridges return `JsonValue`. -enum ChatToolJSONFormatter { - static func string(from value: JsonValue) throws -> String { - let encoder = JSONEncoder() - encoder.outputFormatting = [.prettyPrinted, .sortedKeys, .withoutEscapingSlashes] - let data = try encoder.encode(value) - return String(data: data, encoding: .utf8) ?? "{}" - } -} diff --git a/TablePro/Core/AI/Chat/ChatToolSpec+Copilot.swift b/TablePro/Core/AI/Chat/ChatToolSpec+Copilot.swift index 2c642ccad..cfdb991f6 100644 --- a/TablePro/Core/AI/Chat/ChatToolSpec+Copilot.swift +++ b/TablePro/Core/AI/Chat/ChatToolSpec+Copilot.swift @@ -14,7 +14,7 @@ extension ChatToolSpec { ) } - private static func normalizeForCopilot(_ schema: JSONValue) -> JSONValue { + private static func normalizeForCopilot(_ schema: JsonValue) -> JsonValue { guard case .object(var dict) = schema else { return schema } if dict["required"] == nil { dict["required"] = .array([]) diff --git a/TablePro/Core/AI/Chat/ChatTransport.swift b/TablePro/Core/AI/Chat/ChatTransport.swift index e47d7837a..e7e6f8322 100644 --- a/TablePro/Core/AI/Chat/ChatTransport.swift +++ b/TablePro/Core/AI/Chat/ChatTransport.swift @@ -41,7 +41,7 @@ struct ChatTransportOptions: Sendable { struct ChatToolSpec: Codable, Equatable, Sendable { let name: String let description: String - let inputSchema: JSONValue + let inputSchema: JsonValue } enum ChatStreamEvent: Sendable { diff --git a/TablePro/Core/AI/Chat/ChatTurn.swift b/TablePro/Core/AI/Chat/ChatTurn.swift index 573c02c88..8d3f55159 100644 --- a/TablePro/Core/AI/Chat/ChatTurn.swift +++ b/TablePro/Core/AI/Chat/ChatTurn.swift @@ -139,10 +139,10 @@ enum ChatContentBlock: Codable, Equatable, Sendable { struct ToolUseBlock: Codable, Equatable, Sendable { let id: String let name: String - let input: JSONValue + let input: JsonValue var approvalState: ToolApprovalState - init(id: String, name: String, input: JSONValue, approvalState: ToolApprovalState = .approved) { + init(id: String, name: String, input: JsonValue, approvalState: ToolApprovalState = .approved) { self.id = id self.name = name self.input = input @@ -153,7 +153,7 @@ struct ToolUseBlock: Codable, Equatable, Sendable { let container = try decoder.container(keyedBy: CodingKeys.self) id = try container.decode(String.self, forKey: .id) name = try container.decode(String.self, forKey: .name) - input = try container.decode(JSONValue.self, forKey: .input) + input = try container.decode(JsonValue.self, forKey: .input) approvalState = try container.decodeIfPresent(ToolApprovalState.self, forKey: .approvalState) ?? .approved } diff --git a/TablePro/Core/AI/Chat/JSONValue.swift b/TablePro/Core/AI/Chat/JSONValue.swift deleted file mode 100644 index aaadf3cc7..000000000 --- a/TablePro/Core/AI/Chat/JSONValue.swift +++ /dev/null @@ -1,75 +0,0 @@ -// -// JSONValue.swift -// TablePro -// - -import Foundation - -enum JSONValue: Codable, Equatable, Sendable, Hashable { - case null - case bool(Bool) - case number(Double) - case integer(Int64) - case string(String) - case array([JSONValue]) - case object([String: JSONValue]) - - init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - if container.decodeNil() { - self = .null - return - } - if let value = try? container.decode(Bool.self) { - self = .bool(value) - return - } - if let value = try? container.decode(Int64.self) { - self = .integer(value) - return - } - if let value = try? container.decode(Double.self) { - self = .number(value) - return - } - if let value = try? container.decode(String.self) { - self = .string(value) - return - } - if let value = try? container.decode([JSONValue].self) { - self = .array(value) - return - } - if let value = try? container.decode([String: JSONValue].self) { - self = .object(value) - return - } - throw DecodingError.dataCorruptedError( - in: container, - debugDescription: "Unsupported JSON value" - ) - } - - func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .null: try container.encodeNil() - case .bool(let value): try container.encode(value) - case .integer(let value): try container.encode(value) - case .number(let value): try container.encode(value) - case .string(let value): try container.encode(value) - case .array(let value): try container.encode(value) - case .object(let value): try container.encode(value) - } - } - - func decoded(as type: T.Type) throws -> T { - let data = try JSONEncoder().encode(self) - return try JSONDecoder().decode(type, from: data) - } - - static func encoded(_ value: some Encodable) throws -> JSONValue { - let data = try JSONEncoder().encode(value) - return try JSONDecoder().decode(JSONValue.self, from: data) - } -} diff --git a/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift b/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift index c6e87a691..e6398dd1e 100644 --- a/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift @@ -21,7 +21,7 @@ struct ConfirmDestructiveOperationChatTool: ChatTool { Execute a destructive DDL query (DROP, TRUNCATE, ALTER...DROP) after explicit confirmation.\ Pass confirmation_phrase exactly as: I understand this is irreversible """) - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -44,7 +44,7 @@ struct ConfirmDestructiveOperationChatTool: ChatTool { ]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let query = try ChatToolArgumentDecoder.requireString(input, key: "query") let confirmationPhrase = try ChatToolArgumentDecoder.requireString(input, key: "confirmation_phrase") @@ -82,6 +82,6 @@ struct ConfirmDestructiveOperationChatTool: ChatTool { timeoutSeconds: mcpSettings.queryTimeoutSeconds, principalLabel: String(localized: "AI Chat") ) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/DescribeTableChatTool.swift b/TablePro/Core/AI/Chat/Tools/DescribeTableChatTool.swift index dcb8a9d41..5e6dc8e5e 100644 --- a/TablePro/Core/AI/Chat/Tools/DescribeTableChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/DescribeTableChatTool.swift @@ -8,7 +8,7 @@ import Foundation struct DescribeTableChatTool: ChatTool { let name = "describe_table" let description = String(localized: "Describe the columns of a table or view.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -27,7 +27,7 @@ struct DescribeTableChatTool: ChatTool { "required": .array([.string("table")]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let table = try ChatToolArgumentDecoder.requireString(input, key: "table") let schema = ChatToolArgumentDecoder.optionalString(input, key: "schema") @@ -36,6 +36,6 @@ struct DescribeTableChatTool: ChatTool { table: table, schema: schema ) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift b/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift index 390fa2fff..8dff31e88 100644 --- a/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift @@ -16,7 +16,7 @@ struct ExecuteQueryChatTool: ChatTool { Multi-statement queries are rejected. Destructive operations (DROP, TRUNCATE, ALTER...DROP)\ are blocked here; use confirm_destructive_operation instead. """) - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -47,7 +47,7 @@ struct ExecuteQueryChatTool: ChatTool { "required": .array([.string("connection_id"), .string("query")]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let query = try ChatToolArgumentDecoder.requireString(input, key: "query") let database = ChatToolArgumentDecoder.optionalString(input, key: "database") @@ -107,6 +107,6 @@ struct ExecuteQueryChatTool: ChatTool { timeoutSeconds: timeoutSeconds, principalLabel: String(localized: "AI Chat") ) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/GetConnectionStatusChatTool.swift b/TablePro/Core/AI/Chat/Tools/GetConnectionStatusChatTool.swift index 5fdcef3e2..c6ef2c7c8 100644 --- a/TablePro/Core/AI/Chat/Tools/GetConnectionStatusChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/GetConnectionStatusChatTool.swift @@ -8,7 +8,7 @@ import Foundation struct GetConnectionStatusChatTool: ChatTool { let name = "get_connection_status" let description = String(localized: "Get detailed status for a specific database connection.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -19,14 +19,14 @@ struct GetConnectionStatusChatTool: ChatTool { "required": .array([.string("connection_id")]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let payload = try await context.bridge.getConnectionStatus(connectionId: connectionId) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } -func resolveConnectionId(input: JSONValue, context: ChatToolContext) throws -> UUID { +func resolveConnectionId(input: JsonValue, context: ChatToolContext) throws -> UUID { if let connectionId = try? ChatToolArgumentDecoder.requireUUID(input, key: "connection_id") { return connectionId } diff --git a/TablePro/Core/AI/Chat/Tools/GetTableDDLChatTool.swift b/TablePro/Core/AI/Chat/Tools/GetTableDDLChatTool.swift index b8ad9d0d1..c4f6b0fcf 100644 --- a/TablePro/Core/AI/Chat/Tools/GetTableDDLChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/GetTableDDLChatTool.swift @@ -8,7 +8,7 @@ import Foundation struct GetTableDDLChatTool: ChatTool { let name = "get_table_ddl" let description = String(localized: "Get the DDL (CREATE statement) for a table.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -27,7 +27,7 @@ struct GetTableDDLChatTool: ChatTool { "required": .array([.string("table")]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let table = try ChatToolArgumentDecoder.requireString(input, key: "table") let schema = ChatToolArgumentDecoder.optionalString(input, key: "schema") @@ -36,6 +36,6 @@ struct GetTableDDLChatTool: ChatTool { table: table, schema: schema ) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/ListConnectionsChatTool.swift b/TablePro/Core/AI/Chat/Tools/ListConnectionsChatTool.swift index 9504df91b..8de1cab1e 100644 --- a/TablePro/Core/AI/Chat/Tools/ListConnectionsChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ListConnectionsChatTool.swift @@ -8,13 +8,13 @@ import Foundation struct ListConnectionsChatTool: ChatTool { let name = "list_connections" let description = String(localized: "List all saved database connections with their current status.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([:]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let payload = await context.bridge.listConnections() - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/ListDatabasesChatTool.swift b/TablePro/Core/AI/Chat/Tools/ListDatabasesChatTool.swift index 1872c5e71..21d20e210 100644 --- a/TablePro/Core/AI/Chat/Tools/ListDatabasesChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ListDatabasesChatTool.swift @@ -8,7 +8,7 @@ import Foundation struct ListDatabasesChatTool: ChatTool { let name = "list_databases" let description = String(localized: "List databases available on a connection.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -18,9 +18,9 @@ struct ListDatabasesChatTool: ChatTool { ]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let payload = try await context.bridge.listDatabases(connectionId: connectionId) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/ListSchemasChatTool.swift b/TablePro/Core/AI/Chat/Tools/ListSchemasChatTool.swift index 99a105174..f9e3fc099 100644 --- a/TablePro/Core/AI/Chat/Tools/ListSchemasChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ListSchemasChatTool.swift @@ -8,7 +8,7 @@ import Foundation struct ListSchemasChatTool: ChatTool { let name = "list_schemas" let description = String(localized: "List schemas available in the active database of a connection.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -18,9 +18,9 @@ struct ListSchemasChatTool: ChatTool { ]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let payload = try await context.bridge.listSchemas(connectionId: connectionId) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/Chat/Tools/ListTablesChatTool.swift b/TablePro/Core/AI/Chat/Tools/ListTablesChatTool.swift index 2f05032e1..0afb1c942 100644 --- a/TablePro/Core/AI/Chat/Tools/ListTablesChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ListTablesChatTool.swift @@ -8,7 +8,7 @@ import Foundation struct ListTablesChatTool: ChatTool { let name = "list_tables" let description = String(localized: "List tables and views in the active database of a connection.") - let inputSchema: JSONValue = .object([ + let inputSchema: JsonValue = .object([ "type": .string("object"), "properties": .object([ "connection_id": .object([ @@ -30,7 +30,7 @@ struct ListTablesChatTool: ChatTool { ]) ]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { let connectionId = try resolveConnectionId(input: input, context: context) let database = ChatToolArgumentDecoder.optionalString(input, key: "database") let schema = ChatToolArgumentDecoder.optionalString(input, key: "schema") @@ -47,6 +47,6 @@ struct ListTablesChatTool: ChatTool { connectionId: connectionId, includeRowCounts: includeRowCounts ) - return ChatToolResult(content: try ChatToolJSONFormatter.string(from: payload)) + return ChatToolResult(content: payload.jsonString(prettyPrinted: true)) } } diff --git a/TablePro/Core/AI/GeminiProvider.swift b/TablePro/Core/AI/GeminiProvider.swift index 32efe46aa..dac07ce49 100644 --- a/TablePro/Core/AI/GeminiProvider.swift +++ b/TablePro/Core/AI/GeminiProvider.swift @@ -184,7 +184,7 @@ final class GeminiProvider: ChatTransport { "name": tool.name, "description": tool.description ] - entry["parameters"] = try tool.inputSchema.asJSONObject() + entry["parameters"] = try tool.inputSchema.jsonObject() return entry } body["tools"] = [["functionDeclarations": declarations]] @@ -218,7 +218,7 @@ final class GeminiProvider: ChatTransport { case .attachment: continue case .toolUse(let useBlock): - let argsObject = (try? useBlock.input.asJSONObject()) ?? [String: Any]() + let argsObject = (try? useBlock.input.jsonObject()) ?? [String: Any]() parts.append([ "functionCall": [ "name": useBlock.name, diff --git a/TablePro/Core/AI/JSONValue+Encoding.swift b/TablePro/Core/AI/JSONValue+Encoding.swift deleted file mode 100644 index 6e3f4570b..000000000 --- a/TablePro/Core/AI/JSONValue+Encoding.swift +++ /dev/null @@ -1,22 +0,0 @@ -// -// JSONValue+Encoding.swift -// TablePro -// - -import Foundation - -extension JSONValue { - func asJSONObject() throws -> Any { - let data = try JSONEncoder().encode(self) - return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) - } - - func asJSONString() -> String { - guard let data = try? JSONEncoder().encode(self), - let string = String(data: data, encoding: .utf8) - else { - return "{}" - } - return string - } -} diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index 672422ef3..3ee7babdb 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -375,7 +375,7 @@ final class OpenAICompatibleProvider: ChatTransport { "type": "function", "function": [ "name": block.name, - "arguments": block.input.asJSONString() + "arguments": block.input.jsonString() ] ] } @@ -407,7 +407,7 @@ final class OpenAICompatibleProvider: ChatTransport { } func encodeTool(_ tool: ChatToolSpec) throws -> [String: Any] { - let parameters = try tool.inputSchema.asJSONObject() + let parameters = try tool.inputSchema.jsonObject() return [ "type": "function", "function": [ diff --git a/TablePro/Core/LSP/LSPTypes.swift b/TablePro/Core/LSP/LSPTypes.swift index 2679d24b6..3d5c7c7fe 100644 --- a/TablePro/Core/LSP/LSPTypes.swift +++ b/TablePro/Core/LSP/LSPTypes.swift @@ -352,7 +352,7 @@ enum AnyCodableValue: Sendable, Equatable { struct CopilotLanguageModelToolInformation: Codable, Sendable { let name: String let description: String - let inputSchema: JSONValue? + let inputSchema: JsonValue? } struct CopilotRegisterToolsParams: Codable, Sendable { @@ -361,7 +361,7 @@ struct CopilotRegisterToolsParams: Codable, Sendable { struct CopilotInvokeClientToolParams: Codable, Sendable { let name: String - let input: JSONValue? + let input: JsonValue? let conversationId: String let turnId: String } @@ -373,7 +373,7 @@ enum CopilotToolInvocationStatus: String, Codable, Sendable { } struct CopilotLanguageModelToolResultContent: Codable, Sendable { - let value: JSONValue + let value: JsonValue } struct CopilotLanguageModelToolResult: Codable, Sendable { diff --git a/TablePro/Core/MCP/Wire/JsonValue.swift b/TablePro/Core/MCP/Wire/JsonValue.swift index 9ff2d6502..91128f095 100644 --- a/TablePro/Core/MCP/Wire/JsonValue.swift +++ b/TablePro/Core/MCP/Wire/JsonValue.swift @@ -114,6 +114,24 @@ extension JsonValue: ExpressibleByDictionaryLiteral { } public extension JsonValue { + func jsonObject() throws -> Any { + let data = try JSONEncoder().encode(self) + return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) + } + + func jsonString(prettyPrinted: Bool = false) -> String { + let encoder = JSONEncoder() + if prettyPrinted { + encoder.outputFormatting = [.prettyPrinted, .sortedKeys, .withoutEscapingSlashes] + } + guard let data = try? encoder.encode(self), + let string = String(data: data, encoding: .utf8) + else { + return "{}" + } + return string + } + subscript(key: String) -> JsonValue? { guard case .object(let dict) = self else { return nil } return dict[key] diff --git a/TablePro/ViewModels/AIChatViewModel+Streaming.swift b/TablePro/ViewModels/AIChatViewModel+Streaming.swift index 42839c85e..b81ef6430 100644 --- a/TablePro/ViewModels/AIChatViewModel+Streaming.swift +++ b/TablePro/ViewModels/AIChatViewModel+Streaming.swift @@ -386,11 +386,11 @@ extension AIChatViewModel { order.compactMap { id -> ToolUseBlock? in guard let name = names[id] else { return nil } let inputString = inputs[id] ?? "{}" - let inputValue: JSONValue + let inputValue: JsonValue if inputString.isEmpty { inputValue = .object([:]) } else if let data = inputString.data(using: .utf8), - let decoded = try? JSONDecoder().decode(JSONValue.self, from: data) { + let decoded = try? JSONDecoder().decode(JsonValue.self, from: data) { inputValue = decoded } else { inputValue = .object([:]) diff --git a/TableProTests/Core/AI/ChatToolArgumentDecoderTests.swift b/TableProTests/Core/AI/ChatToolArgumentDecoderTests.swift index ee411297f..791df8456 100644 --- a/TableProTests/Core/AI/ChatToolArgumentDecoderTests.swift +++ b/TableProTests/Core/AI/ChatToolArgumentDecoderTests.swift @@ -11,13 +11,13 @@ import Testing struct ChatToolArgumentDecoderTests { @Test("requireString returns value when key exists and is a string") func requireStringPresent() throws { - let args: JSONValue = .object(["name": .string("alpha")]) + let args: JsonValue = .object(["name": .string("alpha")]) #expect(try ChatToolArgumentDecoder.requireString(args, key: "name") == "alpha") } @Test("requireString throws when key is missing") func requireStringMissing() { - let args: JSONValue = .object([:]) + let args: JsonValue = .object([:]) #expect(throws: ChatToolArgumentError.self) { _ = try ChatToolArgumentDecoder.requireString(args, key: "name") } @@ -25,7 +25,7 @@ struct ChatToolArgumentDecoderTests { @Test("requireString throws when value is not a string") func requireStringWrongType() { - let args: JSONValue = .object(["count": .integer(42)]) + let args: JsonValue = .object(["count": .int(42)]) #expect(throws: ChatToolArgumentError.self) { _ = try ChatToolArgumentDecoder.requireString(args, key: "count") } @@ -33,20 +33,20 @@ struct ChatToolArgumentDecoderTests { @Test("optionalString returns nil for missing key") func optionalStringMissing() { - let args: JSONValue = .object([:]) + let args: JsonValue = .object([:]) #expect(ChatToolArgumentDecoder.optionalString(args, key: "name") == nil) } @Test("requireUUID parses a valid UUID string") func requireUUIDValid() throws { let id = UUID() - let args: JSONValue = .object(["connection_id": .string(id.uuidString)]) + let args: JsonValue = .object(["connection_id": .string(id.uuidString)]) #expect(try ChatToolArgumentDecoder.requireUUID(args, key: "connection_id") == id) } @Test("requireUUID throws for malformed UUID string") func requireUUIDInvalid() { - let args: JSONValue = .object(["connection_id": .string("not-a-uuid")]) + let args: JsonValue = .object(["connection_id": .string("not-a-uuid")]) #expect(throws: ChatToolArgumentError.self) { _ = try ChatToolArgumentDecoder.requireUUID(args, key: "connection_id") } @@ -54,38 +54,38 @@ struct ChatToolArgumentDecoderTests { @Test("optionalBool returns the default when key missing") func optionalBoolDefault() { - let args: JSONValue = .object([:]) + let args: JsonValue = .object([:]) #expect(ChatToolArgumentDecoder.optionalBool(args, key: "enabled", default: true) == true) #expect(ChatToolArgumentDecoder.optionalBool(args, key: "enabled", default: false) == false) } @Test("optionalBool returns the value when present") func optionalBoolPresent() { - let args: JSONValue = .object(["enabled": .bool(false)]) + let args: JsonValue = .object(["enabled": .bool(false)]) #expect(ChatToolArgumentDecoder.optionalBool(args, key: "enabled", default: true) == false) } @Test("optionalInt returns fallback when key is missing") func optionalIntMissing() { - let args: JSONValue = .object([:]) + let args: JsonValue = .object([:]) #expect(ChatToolArgumentDecoder.optionalInt(args, key: "max_rows", default: 500) == 500) } @Test("optionalInt accepts integer values") func optionalIntInteger() { - let args: JSONValue = .object(["max_rows": .integer(120)]) + let args: JsonValue = .object(["max_rows": .int(120)]) #expect(ChatToolArgumentDecoder.optionalInt(args, key: "max_rows", default: 500) == 120) } @Test("optionalInt coerces number (Double) to Int") func optionalIntFromDouble() { - let args: JSONValue = .object(["max_rows": .number(120.7)]) + let args: JsonValue = .object(["max_rows": .double(120.7)]) #expect(ChatToolArgumentDecoder.optionalInt(args, key: "max_rows", default: 500) == 120) } @Test("optionalInt clamps to the supplied range") func optionalIntClamps() { - let args: JSONValue = .object(["max_rows": .integer(50_000)]) + let args: JsonValue = .object(["max_rows": .int(50_000)]) #expect( ChatToolArgumentDecoder.optionalInt(args, key: "max_rows", default: 500, clamp: 1...10_000) == 10_000 @@ -94,7 +94,7 @@ struct ChatToolArgumentDecoderTests { @Test("optionalInt returns fallback for non-numeric value") func optionalIntWrongType() { - let args: JSONValue = .object(["max_rows": .string("ten")]) + let args: JsonValue = .object(["max_rows": .string("ten")]) #expect(ChatToolArgumentDecoder.optionalInt(args, key: "max_rows", default: 500) == 500) } } diff --git a/TableProTests/Core/AI/ChatToolRegistryModeTests.swift b/TableProTests/Core/AI/ChatToolRegistryModeTests.swift index 97663c19b..4f24b32de 100644 --- a/TableProTests/Core/AI/ChatToolRegistryModeTests.swift +++ b/TableProTests/Core/AI/ChatToolRegistryModeTests.swift @@ -13,9 +13,9 @@ struct ChatToolRegistryModeTests { private struct StubTool: ChatTool { let name: String let description = "" - let inputSchema: JSONValue = .object(["type": .string("object")]) + let inputSchema: JsonValue = .object(["type": .string("object")]) - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { ChatToolResult(content: "ok") } } diff --git a/TableProTests/Core/AI/ChatToolRegistryTests.swift b/TableProTests/Core/AI/ChatToolRegistryTests.swift index 0e82d6d6c..30616114c 100644 --- a/TableProTests/Core/AI/ChatToolRegistryTests.swift +++ b/TableProTests/Core/AI/ChatToolRegistryTests.swift @@ -13,7 +13,7 @@ struct ChatToolRegistryTests { private struct StubTool: ChatTool { let name: String let description: String - let inputSchema: JSONValue + let inputSchema: JsonValue let response: String init(name: String, description: String = "", response: String = "ok") { @@ -23,7 +23,7 @@ struct ChatToolRegistryTests { self.response = response } - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { ChatToolResult(content: response) } } diff --git a/TableProTests/Core/AI/ExecuteToolUsesTests.swift b/TableProTests/Core/AI/ExecuteToolUsesTests.swift index 95454c33d..23f425b9b 100644 --- a/TableProTests/Core/AI/ExecuteToolUsesTests.swift +++ b/TableProTests/Core/AI/ExecuteToolUsesTests.swift @@ -15,10 +15,10 @@ struct ExecuteToolUsesTests { private final class StubTool: ChatTool { let name: String let description: String - let inputSchema: JSONValue + let inputSchema: JsonValue let response: String let isError: Bool - private(set) var invocations: [JSONValue] = [] + private(set) var invocations: [JsonValue] = [] init(name: String, response: String = "ok", isError: Bool = false) { self.name = name @@ -28,7 +28,7 @@ struct ExecuteToolUsesTests { self.isError = isError } - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { invocations.append(input) return ChatToolResult(content: response, isError: isError) } @@ -39,9 +39,9 @@ struct ExecuteToolUsesTests { private struct ThrowingTool: ChatTool { let name: String let description = "" - let inputSchema: JSONValue = .object(["type": .string("object")]) + let inputSchema: JsonValue = .object(["type": .string("object")]) struct Boom: Error {} - func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + func execute(input: JsonValue, context: ChatToolContext) async throws -> ChatToolResult { throw Boom() } } @@ -158,12 +158,12 @@ struct ExecuteToolUsesTests { #expect(results[1].isError == true) } - @Test("Tool receives the input JSONValue from its ToolUseBlock") + @Test("Tool receives the input JsonValue from its ToolUseBlock") func inputForwarded() async { let registry = ChatToolRegistry() let stub = StubTool(name: "alpha") registry.register(stub) - let input: JSONValue = .object(["query": .string("SELECT 1")]) + let input: JsonValue = .object(["query": .string("SELECT 1")]) _ = await AIChatViewModel.executeToolUses( [ToolUseBlock(id: "u1", name: "alpha", input: input)], mode: .agent, From 06a0ef56efd30ee0b5dc77db73ba3fa41b734103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:28:43 +0700 Subject: [PATCH 08/16] build(ai-chat): import os in SlashCommands extension --- TablePro/ViewModels/AIChatViewModel+SlashCommands.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift b/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift index bb887b156..b5a38735f 100644 --- a/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift +++ b/TablePro/ViewModels/AIChatViewModel+SlashCommands.swift @@ -4,6 +4,7 @@ // import Foundation +import os extension AIChatViewModel { static let helpMarkdown: String = { From 7bbfacef032cd7691bc35078fd0bd71db47a530a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:32:27 +0700 Subject: [PATCH 09/16] refactor(ai-chat): add schema versioning to conversations and consistent settings defaults --- CHANGELOG.md | 3 + TablePro/Core/AI/Chat/ChatTurn.swift | 24 +++--- TablePro/Core/Storage/AIChatStorage.swift | 10 +++ .../Storage/CustomSlashCommandStorage.swift | 33 +++++++- TablePro/Models/AI/AIConversation.swift | 27 +++++-- TablePro/Models/AI/AIModels.swift | 8 +- .../Settings/CustomSlashCommandsSection.swift | 29 +++++-- .../CustomSlashCommandStorageTests.swift | 79 +++++++++++++++++++ .../Models/AIConversationTests.swift | 57 ++++++++++++- TableProTests/Models/AISettingsTests.swift | 40 ++++++---- 10 files changed, 259 insertions(+), 51 deletions(-) create mode 100644 TableProTests/Core/Storage/CustomSlashCommandStorageTests.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d3816e49..08326e919 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - AI Chat: streaming view model split into focused extensions backed by a single `streamingState` enum - MCP HTTP server: split transport into connection, router, and SSE writer files; pairing exchange store moved to a Swift actor; SSE streams send a 30-second keep-alive - AI providers: shared endpoint normalization and JSON encoding helpers; consistent 5s timeout and known-model fallback when listing models +- AI settings: include schema and current query default to on for new installs, matching the previous decoded fallback +- AI Chat: persisted conversations now carry a schema version so future migrations can read older files cleanly +- AI Chat: custom slash commands reject duplicate names, including case-insensitive collisions on rename ## [0.39.1] - 2026-05-08 diff --git a/TablePro/Core/AI/Chat/ChatTurn.swift b/TablePro/Core/AI/Chat/ChatTurn.swift index 8d3f55159..00ca24bce 100644 --- a/TablePro/Core/AI/Chat/ChatTurn.swift +++ b/TablePro/Core/AI/Chat/ChatTurn.swift @@ -49,26 +49,22 @@ struct ChatTurn: Codable, Equatable, Identifiable, Sendable { if let decodedBlocks = try container.decodeIfPresent([ChatContentBlock].self, forKey: .blocks) { blocks = decodedBlocks - } else if let legacyText = try container.decodeIfPresent(String.self, forKey: .content) { - blocks = [.text(legacyText)] } else { - blocks = [] + let legacyContainer = try decoder.container(keyedBy: LegacyKeys.self) + if let legacyText = try legacyContainer.decodeIfPresent(String.self, forKey: .content) { + blocks = [.text(legacyText)] + } else { + blocks = [] + } } } - func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(id, forKey: .id) - try container.encode(role, forKey: .role) - try container.encode(blocks, forKey: .blocks) - try container.encode(timestamp, forKey: .timestamp) - try container.encodeIfPresent(usage, forKey: .usage) - try container.encodeIfPresent(modelId, forKey: .modelId) - try container.encodeIfPresent(providerId, forKey: .providerId) + private enum CodingKeys: String, CodingKey { + case id, role, blocks, timestamp, usage, modelId, providerId } - private enum CodingKeys: String, CodingKey { - case id, role, blocks, content, timestamp, usage, modelId, providerId + private enum LegacyKeys: String, CodingKey { + case content } var plainText: String { diff --git a/TablePro/Core/Storage/AIChatStorage.swift b/TablePro/Core/Storage/AIChatStorage.swift index a7e622196..e7928b0e6 100644 --- a/TablePro/Core/Storage/AIChatStorage.swift +++ b/TablePro/Core/Storage/AIChatStorage.swift @@ -73,9 +73,19 @@ actor AIChatStorage { var data = try Self.encoder.encode(conversation) if data.count > Self.maxFileSize { + let originalSize = data.count + let originalCount = conversation.messages.count var trimmed = conversation trimmed.messages = Array(trimmed.messages.suffix(Self.trimmedMessageCount)) + let dropped = originalCount - trimmed.messages.count data = try Self.encoder.encode(trimmed) + Self.logger.warning( + """ + Trimmed conversation \(conversation.id, privacy: .public): \ + \(originalSize) bytes exceeded \(Self.maxFileSize), \ + dropped \(dropped) of \(originalCount) messages, kept \(trimmed.messages.count) + """ + ) } try data.write(to: fileURL, options: [.atomic, .completeFileProtectionUntilFirstUserAuthentication]) diff --git a/TablePro/Core/Storage/CustomSlashCommandStorage.swift b/TablePro/Core/Storage/CustomSlashCommandStorage.swift index 3c02ce8d5..278525b35 100644 --- a/TablePro/Core/Storage/CustomSlashCommandStorage.swift +++ b/TablePro/Core/Storage/CustomSlashCommandStorage.swift @@ -7,8 +7,20 @@ import Foundation import Observation import os -/// UserDefaults-backed store for `CustomSlashCommand`s. Observable so the -/// chat composer's slash menu and the Settings list rerender on edits. +enum CustomSlashCommandError: LocalizedError, Equatable { + case duplicateName(String) + + var errorDescription: String? { + switch self { + case .duplicateName(let name): + return String( + format: String(localized: "A command named \"/%@\" already exists."), + name + ) + } + } +} + @MainActor @Observable final class CustomSlashCommandStorage { @@ -25,13 +37,26 @@ final class CustomSlashCommandStorage { self.commands = Self.load(from: defaults) } - func add(_ command: CustomSlashCommand) { + func isDuplicate(_ name: String, excluding id: UUID? = nil) -> Bool { + commands.contains { existing in + if let id, existing.id == id { return false } + return existing.name.caseInsensitiveCompare(name) == .orderedSame + } + } + + func add(_ command: CustomSlashCommand) throws { + if isDuplicate(command.name) { + throw CustomSlashCommandError.duplicateName(command.name) + } commands.append(command) persist() } - func update(_ command: CustomSlashCommand) { + func update(_ command: CustomSlashCommand) throws { guard let idx = commands.firstIndex(where: { $0.id == command.id }) else { return } + if isDuplicate(command.name, excluding: command.id) { + throw CustomSlashCommandError.duplicateName(command.name) + } commands[idx] = command persist() } diff --git a/TablePro/Models/AI/AIConversation.swift b/TablePro/Models/AI/AIConversation.swift index 0a844034d..2323e481c 100644 --- a/TablePro/Models/AI/AIConversation.swift +++ b/TablePro/Models/AI/AIConversation.swift @@ -2,19 +2,19 @@ // AIConversation.swift // TablePro // -// Data model for a persisted AI chat conversation. -// import Foundation -/// A persisted AI chat conversation struct AIConversation: Codable, Equatable, Identifiable { + static let currentSchemaVersion = 1 + let id: UUID var title: String var messages: [ChatTurn] let createdAt: Date var updatedAt: Date var connectionName: String? + let schemaVersion: Int init( id: UUID = UUID(), @@ -22,7 +22,8 @@ struct AIConversation: Codable, Equatable, Identifiable { messages: [ChatTurn] = [], createdAt: Date = Date(), updatedAt: Date = Date(), - connectionName: String? = nil + connectionName: String? = nil, + schemaVersion: Int = AIConversation.currentSchemaVersion ) { self.id = id self.title = title @@ -30,9 +31,25 @@ struct AIConversation: Codable, Equatable, Identifiable { self.createdAt = createdAt self.updatedAt = updatedAt self.connectionName = connectionName + self.schemaVersion = schemaVersion + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + id = try container.decode(UUID.self, forKey: .id) + title = try container.decodeIfPresent(String.self, forKey: .title) ?? "" + messages = try container.decodeIfPresent([ChatTurn].self, forKey: .messages) ?? [] + createdAt = try container.decode(Date.self, forKey: .createdAt) + updatedAt = try container.decode(Date.self, forKey: .updatedAt) + connectionName = try container.decodeIfPresent(String.self, forKey: .connectionName) + let storedVersion = try container.decodeIfPresent(Int.self, forKey: .schemaVersion) ?? 0 + schemaVersion = max(storedVersion, AIConversation.currentSchemaVersion) + } + + private enum CodingKeys: String, CodingKey { + case id, title, messages, createdAt, updatedAt, connectionName, schemaVersion } - /// Derive title from the first user message (max 50 chars) mutating func updateTitle() { guard title.isEmpty, let firstUserMessage = messages.first(where: { $0.role == .user }) diff --git a/TablePro/Models/AI/AIModels.swift b/TablePro/Models/AI/AIModels.swift index 9ba51e5f8..48cf10d9a 100644 --- a/TablePro/Models/AI/AIModels.swift +++ b/TablePro/Models/AI/AIModels.swift @@ -198,8 +198,8 @@ struct AISettings: Codable, Equatable, Sendable { providers: [], activeProviderID: nil, inlineSuggestionsEnabled: false, - includeSchema: false, - includeCurrentQuery: false, + includeSchema: true, + includeCurrentQuery: true, includeQueryResults: false, maxSchemaTables: 20, defaultConnectionPolicy: .askEachTime, @@ -211,8 +211,8 @@ struct AISettings: Codable, Equatable, Sendable { providers: [AIProviderConfig] = [], activeProviderID: UUID? = nil, inlineSuggestionsEnabled: Bool = false, - includeSchema: Bool = false, - includeCurrentQuery: Bool = false, + includeSchema: Bool = true, + includeCurrentQuery: Bool = true, includeQueryResults: Bool = false, maxSchemaTables: Int = 20, defaultConnectionPolicy: AIConnectionPolicy = .askEachTime, diff --git a/TablePro/Views/Settings/CustomSlashCommandsSection.swift b/TablePro/Views/Settings/CustomSlashCommandsSection.swift index 0cf68aeb6..34426cdb0 100644 --- a/TablePro/Views/Settings/CustomSlashCommandsSection.swift +++ b/TablePro/Views/Settings/CustomSlashCommandsSection.swift @@ -9,6 +9,7 @@ struct CustomSlashCommandsSection: View { @Bindable var storage: CustomSlashCommandStorage @State private var editing: CustomSlashCommand? @State private var isCreating = false + @State private var saveError: String? var body: some View { Section { @@ -43,13 +44,17 @@ struct CustomSlashCommandsSection: View { initial: command, isCreating: isCreating, onSave: { updated in - if isCreating { - storage.add(updated) - } else { - storage.update(updated) + do { + if isCreating { + try storage.add(updated) + } else { + try storage.update(updated) + } + editing = nil + isCreating = false + } catch { + saveError = error.localizedDescription } - editing = nil - isCreating = false }, onCancel: { editing = nil @@ -57,6 +62,18 @@ struct CustomSlashCommandsSection: View { } ) } + .alert( + String(localized: "Cannot Save Command"), + isPresented: Binding( + get: { saveError != nil }, + set: { if !$0 { saveError = nil } } + ), + presenting: saveError + ) { _ in + Button(String(localized: "OK"), role: .cancel) { saveError = nil } + } message: { message in + Text(message) + } } @ViewBuilder diff --git a/TableProTests/Core/Storage/CustomSlashCommandStorageTests.swift b/TableProTests/Core/Storage/CustomSlashCommandStorageTests.swift new file mode 100644 index 000000000..ffe6c54da --- /dev/null +++ b/TableProTests/Core/Storage/CustomSlashCommandStorageTests.swift @@ -0,0 +1,79 @@ +// +// CustomSlashCommandStorageTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("CustomSlashCommandStorage") +@MainActor +struct CustomSlashCommandStorageTests { + private func makeStorage() -> CustomSlashCommandStorage { + let suiteName = "com.TablePro.tests.CustomSlashCommandStorage.\(UUID().uuidString)" + guard let defaults = UserDefaults(suiteName: suiteName) else { + fatalError("UserDefaults suite creation failed") + } + defaults.removePersistentDomain(forName: suiteName) + return CustomSlashCommandStorage(defaults: defaults) + } + + @Test("add stores a new command") + func addStoresCommand() throws { + let storage = makeStorage() + let command = CustomSlashCommand(name: "review", promptTemplate: "Review {{query}}") + try storage.add(command) + #expect(storage.commands.count == 1) + #expect(storage.commands.first?.name == "review") + } + + @Test("add throws on duplicate name regardless of case") + func addRejectsDuplicateName() throws { + let storage = makeStorage() + try storage.add(CustomSlashCommand(name: "review", promptTemplate: "x")) + + #expect(throws: CustomSlashCommandError.self) { + try storage.add(CustomSlashCommand(name: "REVIEW", promptTemplate: "y")) + } + #expect(storage.commands.count == 1) + } + + @Test("isDuplicate ignores the command being edited") + func isDuplicateExcludesSelf() throws { + let storage = makeStorage() + let command = CustomSlashCommand(name: "review", promptTemplate: "x") + try storage.add(command) + #expect(storage.isDuplicate("review", excluding: command.id) == false) + #expect(storage.isDuplicate("review") == true) + } + + @Test("update rejects rename to an existing command's name") + func updateRejectsCollidingRename() throws { + let storage = makeStorage() + try storage.add(CustomSlashCommand(name: "review", promptTemplate: "x")) + let second = CustomSlashCommand(name: "summarize", promptTemplate: "y") + try storage.add(second) + + var renamed = second + renamed.name = "REVIEW" + + #expect(throws: CustomSlashCommandError.self) { + try storage.update(renamed) + } + #expect(storage.command(named: "summarize") != nil) + } + + @Test("update preserves the same command across rename without collision") + func updateAllowsNonCollidingRename() throws { + let storage = makeStorage() + let original = CustomSlashCommand(name: "review", promptTemplate: "x") + try storage.add(original) + + var renamed = original + renamed.name = "audit" + try storage.update(renamed) + #expect(storage.command(named: "audit") != nil) + #expect(storage.command(named: "review") == nil) + } +} diff --git a/TableProTests/Models/AIConversationTests.swift b/TableProTests/Models/AIConversationTests.swift index 780103210..18cfe4f49 100644 --- a/TableProTests/Models/AIConversationTests.swift +++ b/TableProTests/Models/AIConversationTests.swift @@ -9,11 +9,15 @@ import Testing @Suite("AIConversation") struct AIConversationTests { + private func makeUserTurn(_ text: String) -> ChatTurn { + ChatTurn(role: .user, blocks: [.text(text)]) + } + @Test("updateTitle truncates long content") func updateTitleTruncatesLongContent() { var conv = AIConversation( title: "", - messages: [AIChatMessage(role: .user, content: String(repeating: "a", count: 60))] + messages: [makeUserTurn(String(repeating: "a", count: 60))] ) conv.updateTitle() #expect(conv.title.hasSuffix("...")) @@ -23,9 +27,58 @@ struct AIConversationTests { func updateTitleKeepsShortContent() { var conv = AIConversation( title: "", - messages: [AIChatMessage(role: .user, content: "Short query")] + messages: [makeUserTurn("Short query")] ) conv.updateTitle() #expect(conv.title == "Short query") } + + @Test("New conversations carry the current schema version") + func newConversationsUseCurrentSchemaVersion() { + let conv = AIConversation() + #expect(conv.schemaVersion == AIConversation.currentSchemaVersion) + } + + @Test("Decoding a legacy payload without schemaVersion upgrades to the current version") + func decodingLegacyPayloadUpgradesVersion() throws { + let id = UUID() + let now = ISO8601DateFormatter().string(from: Date()) + let json = """ + { + "id": "\(id.uuidString)", + "title": "Legacy", + "messages": [], + "createdAt": "\(now)", + "updatedAt": "\(now)" + } + """ + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + let conversation = try decoder.decode(AIConversation.self, from: Data(json.utf8)) + #expect(conversation.id == id) + #expect(conversation.schemaVersion == AIConversation.currentSchemaVersion) + } + + @Test("Round-trip encode and decode preserves the schema version") + func roundTripPreservesSchemaVersion() throws { + let original = AIConversation(messages: [makeUserTurn("hi")]) + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .iso8601 + let data = try encoder.encode(original) + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + let decoded = try decoder.decode(AIConversation.self, from: data) + #expect(decoded.schemaVersion == AIConversation.currentSchemaVersion) + } + + @Test("Encoded payload includes the schemaVersion field") + func encodedPayloadIncludesSchemaVersion() throws { + let conv = AIConversation() + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .iso8601 + let data = try encoder.encode(conv) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + let storedVersion = json?["schemaVersion"] as? Int + #expect(storedVersion == AIConversation.currentSchemaVersion) + } } diff --git a/TableProTests/Models/AISettingsTests.swift b/TableProTests/Models/AISettingsTests.swift index f74a46957..19f0d9c1f 100644 --- a/TableProTests/Models/AISettingsTests.swift +++ b/TableProTests/Models/AISettingsTests.swift @@ -30,31 +30,39 @@ struct AISettingsTests { #expect(settings.enabled == false) } - @Test("New installs default to opt-in context (no auto schema/query/results)") - func newInstallsAreOptIn() { + @Test("Default settings include schema and current query, exclude query results") + func defaultsForContextFlags() { let settings = AISettings.default - #expect(settings.includeSchema == false) - #expect(settings.includeCurrentQuery == false) + #expect(settings.includeSchema == true) + #expect(settings.includeCurrentQuery == true) #expect(settings.includeQueryResults == false) } - @Test("Existing users with stored true values keep their auto-context behavior") - func upgradedUsersKeepAutoContext() throws { - let json = #"{"includeSchema": true, "includeCurrentQuery": true, "includeQueryResults": true}"# - let data = Data(json.utf8) + @Test("Memberwise init uses the same context defaults as AISettings.default") + func memberwiseInitMatchesDefault() { + let settings = AISettings() + #expect(settings.includeSchema == AISettings.default.includeSchema) + #expect(settings.includeCurrentQuery == AISettings.default.includeCurrentQuery) + #expect(settings.includeQueryResults == AISettings.default.includeQueryResults) + } + + @Test("Decoding empty JSON yields the same context defaults as AISettings.default") + func decodingEmptyJSONMatchesDefault() throws { + let data = Data("{}".utf8) let settings = try JSONDecoder().decode(AISettings.self, from: data) - #expect(settings.includeSchema == true) - #expect(settings.includeCurrentQuery == true) - #expect(settings.includeQueryResults == true) + #expect(settings.includeSchema == AISettings.default.includeSchema) + #expect(settings.includeCurrentQuery == AISettings.default.includeCurrentQuery) + #expect(settings.includeQueryResults == AISettings.default.includeQueryResults) } - @Test("Decoding without context keys preserves backward-compat true defaults") - func decoderFallbacksAreBackwardCompatible() throws { - let json = "{}" + @Test("Stored false values for context flags are preserved on decode") + func storedFalseFlagsAreRespected() throws { + let json = #"{"includeSchema": false, "includeCurrentQuery": false, "includeQueryResults": false}"# let data = Data(json.utf8) let settings = try JSONDecoder().decode(AISettings.self, from: data) - #expect(settings.includeSchema == true) - #expect(settings.includeCurrentQuery == true) + #expect(settings.includeSchema == false) + #expect(settings.includeCurrentQuery == false) + #expect(settings.includeQueryResults == false) } } From acbab6e3826ef63496d063633f9c68bb72dc577d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:43:44 +0700 Subject: [PATCH 10/16] docs(changelog): log JSON value type unification --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08326e919..3d50ba6e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - AI settings: include schema and current query default to on for new installs, matching the previous decoded fallback - AI Chat: persisted conversations now carry a schema version so future migrations can read older files cleanly - AI Chat: custom slash commands reject duplicate names, including case-insensitive collisions on rename +- Internal: unify JSON value type used by AI tools and MCP wire ## [0.39.1] - 2026-05-08 From 0f6b30cc4c52743101ed99ad094e54afa2a844a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 10:54:25 +0700 Subject: [PATCH 11/16] refactor(inline-suggest): replace Timer debounce with Task.sleep and add source identity guard --- .../Core/AI/Copilot/CopilotInlineSource.swift | 19 ++++-- .../InlineSuggestion/AIChatInlineSource.swift | 3 +- .../InlineSuggestionManager.swift | 59 +++++++++---------- .../InlineSuggestionSource.swift | 24 +++++--- TablePro/Models/AI/AIModels.swift | 17 ++++++ TablePro/Views/Settings/AISettingsView.swift | 7 +++ 6 files changed, 84 insertions(+), 45 deletions(-) diff --git a/TablePro/Core/AI/Copilot/CopilotInlineSource.swift b/TablePro/Core/AI/Copilot/CopilotInlineSource.swift index 7f74e4a7b..133df813a 100644 --- a/TablePro/Core/AI/Copilot/CopilotInlineSource.swift +++ b/TablePro/Core/AI/Copilot/CopilotInlineSource.swift @@ -11,6 +11,7 @@ final class CopilotInlineSource: InlineSuggestionSource { private static let logger = Logger(subsystem: "com.TablePro", category: "CopilotInlineSource") private let documentSync: CopilotDocumentSync + private var pendingCommands: [UUID: LSPCommand] = [:] init(documentSync: CopilotDocumentSync) { self.documentSync = documentSync @@ -67,25 +68,33 @@ final class CopilotInlineSource: InlineSuggestionSource { guard !ghostText.isEmpty else { return nil } - return InlineSuggestion( + let suggestion = InlineSuggestion( text: ghostText, replacementRange: replacementRange, - replacementText: first.insertText, - acceptCommand: first.command + replacementText: first.insertText ) + + if let command = first.command { + pendingCommands[suggestion.id] = command + } + + return suggestion } func didAcceptSuggestion(_ suggestion: InlineSuggestion) { - guard let command = suggestion.acceptCommand else { return } + guard let command = pendingCommands.removeValue(forKey: suggestion.id) else { return } Task { guard let client = CopilotService.shared.client else { return } try? await client.executeCommand(command: command.command, arguments: command.arguments) } } + func didDismissSuggestion(_ suggestion: InlineSuggestion) { + pendingCommands.removeValue(forKey: suggestion.id) + } + // MARK: - Private - /// Convert LSP Position (line, character) to flat character offset in text. private static func offsetForPosition(_ position: LSPPosition, in text: NSString) -> Int { var offset = 0 var line = 0 diff --git a/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift b/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift index 8d29d15ad..a3315322a 100644 --- a/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift +++ b/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift @@ -57,8 +57,7 @@ final class AIChatInlineSource: InlineSuggestionSource { return InlineSuggestion( text: cleaned, replacementRange: nil, - replacementText: cleaned, - acceptCommand: nil + replacementText: cleaned ) } diff --git a/TablePro/Core/AI/InlineSuggestion/InlineSuggestionManager.swift b/TablePro/Core/AI/InlineSuggestion/InlineSuggestionManager.swift index 58ec1683f..8f9a7749c 100644 --- a/TablePro/Core/AI/InlineSuggestion/InlineSuggestionManager.swift +++ b/TablePro/Core/AI/InlineSuggestion/InlineSuggestionManager.swift @@ -8,8 +8,6 @@ import CodeEditSourceEditor import CodeEditTextView import os -/// Manages inline suggestions rendered as ghost text in the SQL editor. -/// Delegates actual suggestion fetching to an InlineSuggestionSource. @MainActor final class InlineSuggestionManager { // MARK: - Properties @@ -21,10 +19,8 @@ final class InlineSuggestionManager { private var sourceResolver: (@MainActor () -> InlineSuggestionSource?)? private var currentSuggestion: InlineSuggestion? private var suggestionOffset: Int = 0 - private var debounceTimer: Timer? - private var currentTask: Task? - private var generationID: UInt = 0 - private let debounceInterval: TimeInterval = 0.5 + private var debounceTask: Task? + private var requestTask: Task? private let _keyEventMonitor = OSAllocatedUnfairLock(initialState: nil) private(set) var isEditorFocused = false private var isUninstalled = false @@ -62,10 +58,10 @@ final class InlineSuggestionManager { isUninstalled = true isEditorFocused = false - debounceTimer?.invalidate() - debounceTimer = nil - currentTask?.cancel() - currentTask = nil + debounceTask?.cancel() + debounceTask = nil + requestTask?.cancel() + requestTask = nil renderer.uninstall() removeKeyEventMonitor() @@ -94,17 +90,19 @@ final class InlineSuggestionManager { // MARK: - Suggestion Scheduling private func scheduleSuggestion() { - debounceTimer?.invalidate() - + debounceTask?.cancel() guard isEnabled() else { return } - let timer = Timer(timeInterval: debounceInterval, repeats: false) { [weak self] _ in - Task { @MainActor [weak self] in - self?.requestSuggestion() + let delay = Duration.milliseconds(AppSettingsManager.shared.ai.clampedInlineSuggestionDebounceMs) + debounceTask = Task { @MainActor [weak self] in + do { + try await Task.sleep(for: delay) + } catch { + return } + guard !Task.isCancelled, let self else { return } + self.requestSuggestion() } - RunLoop.main.add(timer, forMode: .common) - debounceTimer = timer } private func isEnabled() -> Bool { @@ -136,10 +134,6 @@ final class InlineSuggestionManager { let nsText = fullText as NSString let textBefore = nsText.substring(to: min(cursorOffset, nsText.length)) - currentTask?.cancel() - generationID &+= 1 - let myGeneration = generationID - let (line, character) = Self.computeLineCharacter(text: nsText, offset: cursorOffset) let context = SuggestionContext( @@ -150,13 +144,18 @@ final class InlineSuggestionManager { cursorCharacter: character ) - currentTask = Task { @MainActor [weak self] in + let requestedFromIdentity = source.sourceIdentity + + requestTask?.cancel() + requestTask = Task { @MainActor [weak self] in guard let self else { return } self.suggestionOffset = cursorOffset do { guard let suggestion = try await source.requestSuggestion(context: context) else { return } - guard !Task.isCancelled, self.generationID == myGeneration else { return } + guard !Task.isCancelled else { return } + guard let activeIdentity = self.sourceResolver?()?.sourceIdentity, + activeIdentity == requestedFromIdentity else { return } guard !suggestion.text.isEmpty else { return } self.currentSuggestion = suggestion @@ -192,9 +191,10 @@ final class InlineSuggestionManager { } func dismissSuggestion() { - debounceTimer?.invalidate() - currentTask?.cancel() - currentTask = nil + debounceTask?.cancel() + debounceTask = nil + requestTask?.cancel() + requestTask = nil if let suggestion = currentSuggestion { sourceResolver?()?.didDismissSuggestion(suggestion) @@ -220,11 +220,11 @@ final class InlineSuggestionManager { textView.window?.firstResponder === textView else { return event } switch event.keyCode { - case 48: // Tab — accept suggestion + case 48: self.acceptSuggestion() return nil - case 53: // Escape — dismiss suggestion + case 53: self.dismissSuggestion() return event @@ -246,7 +246,6 @@ final class InlineSuggestionManager { // MARK: - Helpers - /// Convert a character offset to 0-based (line, character) pair. static func computeLineCharacter(text: NSString, offset: Int) -> (Int, Int) { var line = 0 var lineStart = 0 @@ -257,7 +256,7 @@ final class InlineSuggestionManager { while i < target { let ch = text.character(at: i) i += 1 - if ch == 0x0A { // newline + if ch == 0x0A { line += 1 lineStart = i } diff --git a/TablePro/Core/AI/InlineSuggestion/InlineSuggestionSource.swift b/TablePro/Core/AI/InlineSuggestion/InlineSuggestionSource.swift index 72f41b14a..02aa9b5e8 100644 --- a/TablePro/Core/AI/InlineSuggestion/InlineSuggestionSource.swift +++ b/TablePro/Core/AI/InlineSuggestion/InlineSuggestionSource.swift @@ -5,7 +5,6 @@ import Foundation -/// Context passed to an inline suggestion source struct SuggestionContext: Sendable { let textBefore: String let fullText: String @@ -14,20 +13,28 @@ struct SuggestionContext: Sendable { let cursorCharacter: Int } -/// A completed inline suggestion -struct InlineSuggestion: Sendable { - /// Text to show as ghost text (only the part after the cursor) +struct InlineSuggestion: Sendable, Identifiable { + let id: UUID let text: String - /// Range to replace on accept (nil = insert at cursor) let replacementRange: NSRange? - /// Full text to insert when accepted (replaces range) let replacementText: String - let acceptCommand: LSPCommand? + + init( + id: UUID = UUID(), + text: String, + replacementRange: NSRange? = nil, + replacementText: String + ) { + self.id = id + self.text = text + self.replacementRange = replacementRange + self.replacementText = replacementText + } } -/// Protocol for inline suggestion sources @MainActor protocol InlineSuggestionSource: AnyObject { + var sourceIdentity: ObjectIdentifier { get } var isAvailable: Bool { get } func requestSuggestion(context: SuggestionContext) async throws -> InlineSuggestion? func didShowSuggestion(_ suggestion: InlineSuggestion) @@ -36,6 +43,7 @@ protocol InlineSuggestionSource: AnyObject { } extension InlineSuggestionSource { + var sourceIdentity: ObjectIdentifier { ObjectIdentifier(self) } func didShowSuggestion(_ suggestion: InlineSuggestion) {} func didAcceptSuggestion(_ suggestion: InlineSuggestion) {} func didDismissSuggestion(_ suggestion: InlineSuggestion) {} diff --git a/TablePro/Models/AI/AIModels.swift b/TablePro/Models/AI/AIModels.swift index 48cf10d9a..7463b05a4 100644 --- a/TablePro/Models/AI/AIModels.swift +++ b/TablePro/Models/AI/AIModels.swift @@ -186,6 +186,7 @@ struct AISettings: Codable, Equatable, Sendable { var providers: [AIProviderConfig] var activeProviderID: UUID? var inlineSuggestionsEnabled: Bool + var inlineSuggestionDebounceMs: Int var includeSchema: Bool var includeCurrentQuery: Bool var includeQueryResults: Bool @@ -193,11 +194,15 @@ struct AISettings: Codable, Equatable, Sendable { var defaultConnectionPolicy: AIConnectionPolicy var chatMode: AIChatMode + static let defaultInlineSuggestionDebounceMs: Int = 500 + static let inlineSuggestionDebounceRange: ClosedRange = 100...3_000 + static let `default` = AISettings( enabled: true, providers: [], activeProviderID: nil, inlineSuggestionsEnabled: false, + inlineSuggestionDebounceMs: AISettings.defaultInlineSuggestionDebounceMs, includeSchema: true, includeCurrentQuery: true, includeQueryResults: false, @@ -211,6 +216,7 @@ struct AISettings: Codable, Equatable, Sendable { providers: [AIProviderConfig] = [], activeProviderID: UUID? = nil, inlineSuggestionsEnabled: Bool = false, + inlineSuggestionDebounceMs: Int = AISettings.defaultInlineSuggestionDebounceMs, includeSchema: Bool = true, includeCurrentQuery: Bool = true, includeQueryResults: Bool = false, @@ -222,6 +228,7 @@ struct AISettings: Codable, Equatable, Sendable { self.providers = providers self.activeProviderID = activeProviderID self.inlineSuggestionsEnabled = inlineSuggestionsEnabled + self.inlineSuggestionDebounceMs = inlineSuggestionDebounceMs self.includeSchema = includeSchema self.includeCurrentQuery = includeCurrentQuery self.includeQueryResults = includeQueryResults @@ -236,6 +243,9 @@ struct AISettings: Codable, Equatable, Sendable { providers = try container.decodeIfPresent([AIProviderConfig].self, forKey: .providers) ?? [] activeProviderID = try container.decodeIfPresent(UUID.self, forKey: .activeProviderID) inlineSuggestionsEnabled = try container.decodeIfPresent(Bool.self, forKey: .inlineSuggestionsEnabled) ?? false + inlineSuggestionDebounceMs = try container.decodeIfPresent( + Int.self, forKey: .inlineSuggestionDebounceMs + ) ?? AISettings.defaultInlineSuggestionDebounceMs includeSchema = try container.decodeIfPresent(Bool.self, forKey: .includeSchema) ?? true includeCurrentQuery = try container.decodeIfPresent(Bool.self, forKey: .includeCurrentQuery) ?? true includeQueryResults = try container.decodeIfPresent(Bool.self, forKey: .includeQueryResults) ?? false @@ -256,6 +266,13 @@ struct AISettings: Codable, Equatable, Sendable { var hasCopilotConfigured: Bool { providers.contains(where: { $0.type == .copilot }) } + + var clampedInlineSuggestionDebounceMs: Int { + min( + max(inlineSuggestionDebounceMs, AISettings.inlineSuggestionDebounceRange.lowerBound), + AISettings.inlineSuggestionDebounceRange.upperBound + ) + } } struct AITokenUsage: Codable, Equatable, Sendable { diff --git a/TablePro/Views/Settings/AISettingsView.swift b/TablePro/Views/Settings/AISettingsView.swift index cce281cb4..e5cd8c595 100644 --- a/TablePro/Views/Settings/AISettingsView.swift +++ b/TablePro/Views/Settings/AISettingsView.swift @@ -225,6 +225,13 @@ struct AISettingsView: View { .help(settings.hasActiveProvider ? "" : String(localized: "Configure an active provider to enable inline suggestions.")) + Stepper( + String(format: String(localized: "Debounce: %d ms"), settings.inlineSuggestionDebounceMs), + value: $settings.inlineSuggestionDebounceMs, + in: AISettings.inlineSuggestionDebounceRange, + step: 50 + ) + .disabled(!settings.inlineSuggestionsEnabled) } header: { Text("Inline Suggestions") } footer: { From a8da24d24d0f6d819017dfbdec43f16fa4422c75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 11:01:30 +0700 Subject: [PATCH 12/16] refactor(copilot): harden lifecycle and binary trust --- .../AI/Copilot/CopilotBinaryManager.swift | 13 ++++- .../Core/AI/Copilot/CopilotDocumentSync.swift | 12 ++--- .../Copilot/CopilotIdleStopController.swift | 14 ++++-- .../Core/AI/Copilot/CopilotInlineSource.swift | 2 +- ...ext.swift => CopilotPreambleBuilder.swift} | 13 ++--- TablePro/Core/AI/Copilot/CopilotService.swift | 14 +++++- TablePro/Core/LSP/LSPTransport.swift | 49 ++++++++++++------- .../Views/Editor/SQLEditorCoordinator.swift | 2 +- 8 files changed, 77 insertions(+), 42 deletions(-) rename TablePro/Core/AI/Copilot/{CopilotSchemaContext.swift => CopilotPreambleBuilder.swift} (80%) diff --git a/TablePro/Core/AI/Copilot/CopilotBinaryManager.swift b/TablePro/Core/AI/Copilot/CopilotBinaryManager.swift index 854cca987..a24553445 100644 --- a/TablePro/Core/AI/Copilot/CopilotBinaryManager.swift +++ b/TablePro/Core/AI/Copilot/CopilotBinaryManager.swift @@ -4,6 +4,7 @@ // import CryptoKit +import Darwin import Foundation import os @@ -122,7 +123,8 @@ actor CopilotBinaryManager { ofItemAtPath: binaryExecutablePath ) - // Store version for future reference + stripQuarantineAttribute(at: binaryExecutablePath) + if let version = json["version"] as? String { let versionFile = baseDirectory.appendingPathComponent("version.txt") try? version.write(to: versionFile, atomically: true, encoding: .utf8) @@ -141,6 +143,15 @@ actor CopilotBinaryManager { baseDirectory.appendingPathComponent("copilot-language-server").path } + private func stripQuarantineAttribute(at path: String) { + let removed = path.withCString { removexattr($0, "com.apple.quarantine", 0) } + guard removed != 0 else { return } + let err = errno + if err != ENOATTR { + Self.logger.warning("Failed to remove quarantine xattr: errno=\(err)") + } + } + private var platform: String { #if arch(arm64) return "darwin-arm64" diff --git a/TablePro/Core/AI/Copilot/CopilotDocumentSync.swift b/TablePro/Core/AI/Copilot/CopilotDocumentSync.swift index 62c223d0e..d6b9b48e5 100644 --- a/TablePro/Core/AI/Copilot/CopilotDocumentSync.swift +++ b/TablePro/Core/AI/Copilot/CopilotDocumentSync.swift @@ -13,7 +13,7 @@ final class CopilotDocumentSync { private static let logger = Logger(subsystem: "com.TablePro", category: "CopilotDocumentSync") private let documentManager = LSPDocumentManager() - let schemaContext = CopilotSchemaContext() + let preambleBuilder = CopilotPreambleBuilder() private var currentURI: String? private var serverSyncedURIs: Set = [] private var pendingText: [String: String] = [:] @@ -23,7 +23,7 @@ final class CopilotDocumentSync { func documentURI(for tabID: UUID) -> String { if let existing = uriMap[tabID] { return existing } - let fileURL = CopilotSchemaContext.contextDirectory.appendingPathComponent("query-\(nextID).sql") + let fileURL = CopilotPreambleBuilder.contextDirectory.appendingPathComponent("query-\(nextID).sql") nextID += 1 let uri = fileURL.absoluteString uriMap[tabID] = uri @@ -39,7 +39,7 @@ final class CopilotDocumentSync { /// Register document locally. Does NOT send to server. func ensureDocumentOpen(tabID: UUID, text: String, languageId: String = "sql") { let uri = documentURI(for: tabID) - let fullText = schemaContext.prependToText(text) + let fullText = preambleBuilder.prependToText(text) if !documentManager.isOpen(uri) { _ = documentManager.openDocument(uri: uri, languageId: languageId, text: fullText) } @@ -55,7 +55,7 @@ final class CopilotDocumentSync { } let uri = documentURI(for: tabID) - let fullText = schemaContext.prependToText(text) + let fullText = preambleBuilder.prependToText(text) ensureDocumentOpen(tabID: tabID, text: text, languageId: languageId) guard let client = CopilotService.shared.client else { return } @@ -70,7 +70,7 @@ final class CopilotDocumentSync { serverSyncedURIs.insert(uri) if let pending = pendingText.removeValue(forKey: uri) { - let pendingFull = schemaContext.prependToText(pending) + let pendingFull = preambleBuilder.prependToText(pending) if let (versioned, changes) = documentManager.changeDocument(uri: uri, newText: pendingFull) { await client.didChangeDocument(uri: versioned.uri, version: versioned.version, changes: changes) } @@ -92,7 +92,7 @@ final class CopilotDocumentSync { pendingText[uri] = newText return } - let fullText = schemaContext.prependToText(newText) + let fullText = preambleBuilder.prependToText(newText) guard let client = CopilotService.shared.client else { return } guard let (versioned, changes) = documentManager.changeDocument(uri: uri, newText: fullText) else { return } await client.didChangeDocument(uri: versioned.uri, version: versioned.version, changes: changes) diff --git a/TablePro/Core/AI/Copilot/CopilotIdleStopController.swift b/TablePro/Core/AI/Copilot/CopilotIdleStopController.swift index 47f2000e5..84bc7cdde 100644 --- a/TablePro/Core/AI/Copilot/CopilotIdleStopController.swift +++ b/TablePro/Core/AI/Copilot/CopilotIdleStopController.swift @@ -41,11 +41,19 @@ final class CopilotIdleStopController { task = nil return } + let timeout = self.timeout + let isAuthenticated = self.isAuthenticated + let isRunning = self.isRunning + let onStopRequest = self.onStopRequest task = Task { - try? await Task.sleep(for: self.timeout) + do { + try await Task.sleep(for: timeout) + } catch { + return + } guard !Task.isCancelled else { return } - guard !self.isAuthenticated(), self.isRunning() else { return } - await self.onStopRequest() + guard !isAuthenticated(), isRunning() else { return } + await onStopRequest() } } diff --git a/TablePro/Core/AI/Copilot/CopilotInlineSource.swift b/TablePro/Core/AI/Copilot/CopilotInlineSource.swift index 133df813a..9f666f00a 100644 --- a/TablePro/Core/AI/Copilot/CopilotInlineSource.swift +++ b/TablePro/Core/AI/Copilot/CopilotInlineSource.swift @@ -26,7 +26,7 @@ final class CopilotInlineSource: InlineSuggestionSource { guard let docInfo = documentSync.currentDocumentInfo() else { return nil } let editorSettings = AppSettingsManager.shared.editor - let preambleOffset = documentSync.schemaContext.preambleLineCount + let preambleOffset = documentSync.preambleBuilder.preambleLineCount let params = LSPInlineCompletionParams( textDocument: LSPVersionedTextDocumentIdentifier(uri: docInfo.uri, version: docInfo.version), position: LSPPosition(line: context.cursorLine + preambleOffset, character: context.cursorCharacter), diff --git a/TablePro/Core/AI/Copilot/CopilotSchemaContext.swift b/TablePro/Core/AI/Copilot/CopilotPreambleBuilder.swift similarity index 80% rename from TablePro/Core/AI/Copilot/CopilotSchemaContext.swift rename to TablePro/Core/AI/Copilot/CopilotPreambleBuilder.swift index f7cf9986b..37f74b0b6 100644 --- a/TablePro/Core/AI/Copilot/CopilotSchemaContext.swift +++ b/TablePro/Core/AI/Copilot/CopilotPreambleBuilder.swift @@ -1,5 +1,5 @@ // -// CopilotSchemaContext.swift +// CopilotPreambleBuilder.swift // TablePro // @@ -7,26 +7,20 @@ import Foundation import os import TableProPluginKit -/// Builds a schema preamble (SQL comments with table/column info) to prepend -/// to document text sent to the Copilot language server. Pure data, no LSP concerns. @MainActor -final class CopilotSchemaContext { - private static let logger = Logger(subsystem: "com.TablePro", category: "CopilotSchemaContext") +final class CopilotPreambleBuilder { + private static let logger = Logger(subsystem: "com.TablePro", category: "CopilotPreambleBuilder") - /// Directory for query document URIs static let contextDirectory: URL = { let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? FileManager.default.temporaryDirectory return appSupport.appendingPathComponent("TablePro/copilot-context", isDirectory: true) }() - /// The schema preamble text (SQL comments with table/column info) private(set) var preamble: String = "" - /// Number of newline characters in the preamble (for cursor offset adjustment) private(set) var preambleLineCount: Int = 0 - /// Build the preamble from cached schema data func buildPreamble( schemaProvider: SQLSchemaProvider, databaseName: String, @@ -75,7 +69,6 @@ final class CopilotSchemaContext { Self.logger.info("Copilot schema preamble: \(tables.count) tables, \(self.preambleLineCount) lines") } - /// Prepend the preamble to user text for sending to Copilot func prependToText(_ text: String) -> String { guard !preamble.isEmpty else { return text } return preamble + text diff --git a/TablePro/Core/AI/Copilot/CopilotService.swift b/TablePro/Core/AI/Copilot/CopilotService.swift index aa24d9b18..79f6e398e 100644 --- a/TablePro/Core/AI/Copilot/CopilotService.swift +++ b/TablePro/Core/AI/Copilot/CopilotService.swift @@ -126,7 +126,19 @@ final class CopilotService { serverGeneration += 1 if let client = lspClient { - try? await client.shutdown() + let shutdownCompleted = await withTaskGroup(of: Bool.self, returning: Bool.self) { group in + group.addTask { (try? await client.shutdown()) != nil } + group.addTask { + try? await Task.sleep(for: .seconds(10)) + return false + } + let first = await group.next() ?? false + group.cancelAll() + return first + } + if !shutdownCompleted { + Self.logger.warning("Copilot shutdown RPC timed out, forcing exit") + } await client.exit() } await transport?.stop() diff --git a/TablePro/Core/LSP/LSPTransport.swift b/TablePro/Core/LSP/LSPTransport.swift index 0cbabff4f..4548c8b47 100644 --- a/TablePro/Core/LSP/LSPTransport.swift +++ b/TablePro/Core/LSP/LSPTransport.swift @@ -45,7 +45,7 @@ actor LSPTransport { private var notificationHandlers: [String: @Sendable (Data) -> Void] = [:] private var requestHandlers: [String: @Sendable (Data) -> Any?] = [:] private var deferredRequestHandlers: [String: @Sendable (Data, Int) -> Void] = [:] - private var readerQueue: DispatchQueue? + private var readerTask: Task? // MARK: - Lifecycle @@ -91,32 +91,42 @@ actor LSPTransport { try proc.run() - let queue = DispatchQueue(label: "com.TablePro.LSPTransport.reader") - self.readerQueue = queue let handle = stdout.fileHandleForReading - queue.async { [weak self] in - self?.readLoopSync(handle: handle) + readerTask = Task { [weak self] in + await self?.runReadLoop(handle: handle) } Self.logger.info("LSP transport started: \(executablePath)") } - func stop() { + func stop() async { let pending = pendingRequests pendingRequests.removeAll() for (_, continuation) in pending { continuation.resume(throwing: LSPTransportError.requestCancelled) } + readerTask?.cancel() + readerTask = nil + + if let stdinHandle = stdinPipe?.fileHandleForWriting { + try? stdinHandle.close() + } + if let stdoutHandle = stdoutPipe?.fileHandleForReading { + try? stdoutHandle.close() + } + if let stderrHandle = stderrPipe?.fileHandleForReading { + stderrHandle.readabilityHandler = nil + try? stderrHandle.close() + } + if let process, process.isRunning { process.terminate() } process = nil stdinPipe = nil stdoutPipe = nil - stderrPipe?.fileHandleForReading.readabilityHandler = nil stderrPipe = nil - readerQueue = nil Self.logger.info("LSP transport stopped") } @@ -214,18 +224,19 @@ actor LSPTransport { handle.write(data) } - /// Blocking read loop that runs on a dedicated DispatchQueue to avoid blocking the actor executor. - nonisolated private func readLoopSync(handle: FileHandle) { + private func runReadLoop(handle: FileHandle) async { var buffer = Data() - - while true { - let chunk = handle.availableData - guard !chunk.isEmpty else { break } // EOF - buffer.append(chunk) - - while let (messageData, _) = Self.parseMessageFromBuffer(&buffer) { - let data = messageData - Task { [weak self] in await self?.dispatchMessage(data) } + do { + for try await byte in handle.bytes { + if Task.isCancelled { return } + buffer.append(byte) + while let (messageData, _) = Self.parseMessageFromBuffer(&buffer) { + dispatchMessage(messageData) + } + } + } catch { + if !Task.isCancelled { + Self.logger.debug("LSP read loop ended: \(error.localizedDescription)") } } } diff --git a/TablePro/Views/Editor/SQLEditorCoordinator.swift b/TablePro/Views/Editor/SQLEditorCoordinator.swift index 3214a2dd0..99b5c0652 100644 --- a/TablePro/Views/Editor/SQLEditorCoordinator.swift +++ b/TablePro/Views/Editor/SQLEditorCoordinator.swift @@ -325,7 +325,7 @@ final class SQLEditorCoordinator: TextViewCoordinator, TextViewDelegate { Task { if let provider = capturedSchemaProvider, let dbType = capturedDBType { - await sync.schemaContext.buildPreamble( + await sync.preambleBuilder.buildPreamble( schemaProvider: provider, databaseName: dbName, databaseType: dbType From 6e8ea2d971b3a0de20ef2e258cf0984f5d98bdb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 11:55:09 +0700 Subject: [PATCH 13/16] test(ai-chat): update tests for new ChatTurn block-based API and ViewModel split --- .../Core/MongoDB/MongoDBSrvHostTests.swift | 4 +++ .../DatabaseConnectionAIRulesTests.swift | 1 + .../Models/RightPanelStateTests.swift | 11 ++++--- .../AIChatViewModelActionTests.swift | 31 +++++++++---------- .../AIChatViewModelMentionsTests.swift | 4 +-- 5 files changed, 28 insertions(+), 23 deletions(-) diff --git a/TableProTests/Core/MongoDB/MongoDBSrvHostTests.swift b/TableProTests/Core/MongoDB/MongoDBSrvHostTests.swift index 10fde6687..832c65ddc 100644 --- a/TableProTests/Core/MongoDB/MongoDBSrvHostTests.swift +++ b/TableProTests/Core/MongoDB/MongoDBSrvHostTests.swift @@ -3,6 +3,8 @@ // TableProTests // +#if canImport(CLibMongoc) + import Foundation @testable import TablePro import Testing @@ -38,3 +40,5 @@ struct MongoDBSrvHostTests { #expect(MongoDBConnection.stripPort(fromSrvHost: "") == "") } } + +#endif diff --git a/TableProTests/Models/DatabaseConnectionAIRulesTests.swift b/TableProTests/Models/DatabaseConnectionAIRulesTests.swift index cff9cf799..b03c05fe6 100644 --- a/TableProTests/Models/DatabaseConnectionAIRulesTests.swift +++ b/TableProTests/Models/DatabaseConnectionAIRulesTests.swift @@ -4,6 +4,7 @@ // import Foundation +import TableProPluginKit import Testing @testable import TablePro diff --git a/TableProTests/Models/RightPanelStateTests.swift b/TableProTests/Models/RightPanelStateTests.swift index 92e98336a..ccfafe56f 100644 --- a/TableProTests/Models/RightPanelStateTests.swift +++ b/TableProTests/Models/RightPanelStateTests.swift @@ -19,16 +19,17 @@ struct RightPanelStateTests { state.teardown() } - @Test("teardown nils schemaProvider on aiViewModel") + @Test("teardown clears aiViewModel session data") @MainActor - func teardown_nilsSchemaProvider() { + func teardown_clearsAIViewModelSession() { let state = RightPanelState() - state.aiViewModel.schemaProvider = SQLSchemaProvider() - #expect(state.aiViewModel.schemaProvider != nil) + state.aiViewModel.connection = TestFixtures.makeConnection(type: .mysql) + #expect(state.aiViewModel.connection != nil) state.teardown() - #expect(state.aiViewModel.schemaProvider == nil) + #expect(state.aiViewModel.connection == nil) + #expect(state.aiViewModel.messages.isEmpty) } @Test("teardown nils onSave closure") diff --git a/TableProTests/ViewModels/AIChatViewModelActionTests.swift b/TableProTests/ViewModels/AIChatViewModelActionTests.swift index d2705c208..621e187be 100644 --- a/TableProTests/ViewModels/AIChatViewModelActionTests.swift +++ b/TableProTests/ViewModels/AIChatViewModelActionTests.swift @@ -13,7 +13,6 @@ import Testing @Suite("AIChatViewModel Action Dispatch") @MainActor struct AIChatViewModelActionTests { - // MARK: - handleFixError @Test("handleFixError with default connection uses SQL query language") @@ -26,8 +25,8 @@ struct AIChatViewModelActionTests { #expect(vm.messages.count >= 1) let userMessage = vm.messages.first { $0.role == .user } #expect(userMessage != nil) - #expect(userMessage?.content.contains("SQL query") == true) - #expect(userMessage?.content.contains("```sql") == true) + #expect(userMessage?.plainText.contains("SQL query") == true) + #expect(userMessage?.plainText.contains("```sql") == true) } @Test("handleFixError with MongoDB connection uses JavaScript language") @@ -39,8 +38,8 @@ struct AIChatViewModelActionTests { let userMessage = vm.messages.first { $0.role == .user } #expect(userMessage != nil) - #expect(userMessage?.content.contains("MongoDB query") == true) - #expect(userMessage?.content.contains("```javascript") == true) + #expect(userMessage?.plainText.contains("MongoDB query") == true) + #expect(userMessage?.plainText.contains("```javascript") == true) } @Test("handleFixError with Redis connection uses bash language") @@ -52,8 +51,8 @@ struct AIChatViewModelActionTests { let userMessage = vm.messages.first { $0.role == .user } #expect(userMessage != nil) - #expect(userMessage?.content.contains("Redis command") == true) - #expect(userMessage?.content.contains("```bash") == true) + #expect(userMessage?.plainText.contains("Redis command") == true) + #expect(userMessage?.plainText.contains("```bash") == true) } @Test("handleFixError includes query and error text verbatim") @@ -67,8 +66,8 @@ struct AIChatViewModelActionTests { vm.handleFixError(query: query, error: error) let userMessage = vm.messages.first { $0.role == .user } - #expect(userMessage?.content.contains(query) == true) - #expect(userMessage?.content.contains(error) == true) + #expect(userMessage?.plainText.contains(query) == true) + #expect(userMessage?.plainText.contains(error) == true) } // MARK: - handleExplainSelection @@ -84,9 +83,9 @@ struct AIChatViewModelActionTests { let userMessage = vm.messages.first { $0.role == .user } #expect(userMessage != nil) - #expect(userMessage?.content.contains("Explain this SQL query") == true) - #expect(userMessage?.content.contains(selectedText) == true) - #expect(userMessage?.content.contains("```sql") == true) + #expect(userMessage?.plainText.contains("Explain this SQL query") == true) + #expect(userMessage?.plainText.contains(selectedText) == true) + #expect(userMessage?.plainText.contains("```sql") == true) } @Test("handleExplainSelection with empty text is a no-op") @@ -115,9 +114,9 @@ struct AIChatViewModelActionTests { let userMessage = vm.messages.first { $0.role == .user } #expect(userMessage != nil) - #expect(userMessage?.content.contains("Optimize this SQL query") == true) - #expect(userMessage?.content.contains(selectedText) == true) - #expect(userMessage?.content.contains("```sql") == true) + #expect(userMessage?.plainText.contains("Optimize this SQL query") == true) + #expect(userMessage?.plainText.contains(selectedText) == true) + #expect(userMessage?.plainText.contains("```sql") == true) } @Test("handleOptimizeSelection with empty text is a no-op") @@ -152,6 +151,6 @@ struct AIChatViewModelActionTests { // There may also be assistant/error messages from startStreaming. let userMessages = vm.messages.filter { $0.role == .user } #expect(userMessages.count == 1) - #expect(userMessages.first?.content.contains("SELECT 2") == true) + #expect(userMessages.first?.plainText.contains("SELECT 2") == true) } } diff --git a/TableProTests/ViewModels/AIChatViewModelMentionsTests.swift b/TableProTests/ViewModels/AIChatViewModelMentionsTests.swift index e39c6cd0d..4469eb186 100644 --- a/TableProTests/ViewModels/AIChatViewModelMentionsTests.swift +++ b/TableProTests/ViewModels/AIChatViewModelMentionsTests.swift @@ -104,7 +104,7 @@ struct AIChatViewModelMentionsTests { } @Test("resolveTurnForWire expands attachments into the text block") - func resolveTurnForWireExpands() { + func resolveTurnForWireExpands() async { let vm = AIChatViewModel() vm.connection = TestFixtures.makeConnection(type: .mysql) let raw = ChatTurn(role: .user, blocks: [ @@ -112,7 +112,7 @@ struct AIChatViewModelMentionsTests { .attachment(.currentQuery(text: "SELECT * FROM Customer")) ]) - let wire = vm.resolveTurnForWire(raw) + let wire = await vm.resolveTurnForWire(raw) #expect(wire.id == raw.id) #expect(wire.plainText.contains("Explain")) From 9d6c657f6e0fc91d9f5e8aaf621767c0f3736089 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 12:01:10 +0700 Subject: [PATCH 14/16] docs(changelog): adopt Keep a Changelog 1.1.0 format link --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d50ba6e6..13cc36f5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ All notable changes to TablePro will be documented in this file. -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +The format is based on [Keep a Changelog 1.1.0](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] From 59e7dad0c034ed4d546b390c68ee84f94efd4131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 12:02:59 +0700 Subject: [PATCH 15/16] docs(claude-md): document Conventional Commits scopes and atomic API change rule --- CLAUDE.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index bfd15a66a..9a2d81a53 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -198,7 +198,7 @@ When approaching limits: extract into `TypeName+Category.swift` extension files These are **non-negotiable** — never skip them: -1. **CHANGELOG.md**: Update under `[Unreleased]` section (Added/Fixed/Changed) for new features and notable changes. Do **not** add a "Fixed" entry for fixing something that is itself still unreleased. "Fixed" entries are only for bugs in already-released features. Documentation-only changes (`docs/`) do **not** need a CHANGELOG entry. Each entry is one line, user-facing, with no file paths, class names, or method signatures; reference IDs go in parens at the end: `(#1234)`. +1. **CHANGELOG.md**: Follow [Keep a Changelog 1.1.0](https://keepachangelog.com/en/1.1.0/). Update under `[Unreleased]` using the canonical sections: `Added`, `Changed`, `Deprecated`, `Removed`, `Fixed`, `Security`. Do **not** add a "Fixed" entry for fixing something that is itself still unreleased; fold the fix into the Added or Changed entry instead. Documentation-only changes (`docs/`, `CLAUDE.md`, `CHANGELOG.md` formatting) do **not** need a CHANGELOG entry. Each entry is one line, user-facing, with no file paths, class names, or method signatures; reference IDs go in parens at the end: `(#1234)`. 2. **Localization**: Use `String(localized:)` for new user-facing strings in computed properties, AppKit code, alerts, and error descriptions. SwiftUI view literals (`Text("literal")`, `Button("literal")`) auto-localize. Do NOT localize technical terms (font names, database types, SQL keywords, encoding names). Never use `String(localized:)` with string interpolation — `String(localized: "Preview \(name)")` creates a dynamic key that never matches the strings catalog. Use `String(format: String(localized: "Preview %@"), name)`. @@ -212,7 +212,20 @@ These are **non-negotiable** — never skip them: 5. **Lint after changes**: Run `swiftlint lint --strict` to verify compliance. -6. **Commit messages**: Follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). Single line only, no description body. Examples: `docs: fix installation instructions for unsigned app`, `fix: prevent crash on empty query result`, `feat: add CSV export`. +6. **Commit messages**: Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/). Single line only, no description body. Format: `(): `. Scope is optional but preferred when the change has a clear domain. Use `!` after type or scope for breaking changes (e.g. `refactor(ai-providers)!: drop OpenAI legacy completion endpoint`). + + **Types**: `feat`, `fix`, `refactor`, `perf`, `test`, `docs`, `build`, `ci`, `chore`, `style`, `revert`. + + **Canonical scopes** (reuse these instead of inventing new ones): + - AI: `ai-chat`, `ai-providers`, `mcp`, `copilot`, `inline-suggest` + - App UI: `editor`, `datagrid`, `tabs`, `coordinator`, `sidebar`, `connections`, `connection-form`, `welcome`, `settings`, `toolbar`, `hig` + - Infra: `ssh`, `ios`, `windows`, `perf`, `launch`, `plugins` + - Plugins: `plugin-` (e.g. `plugin-mongodb`, `plugin-redis`, `plugin-clickhouse`) + - Docs and release: `changelog`, `claude-md`, `docs`, `ci`, `release` + + **Examples**: `feat(ai-chat): add /refactor slash command`, `fix(editor): prevent crash on empty query result`, `refactor(mcp): migrate pairing store to actor`, `docs(changelog): adopt Keep a Changelog 1.1.0`. + +7. **Atomic API changes**: When you rename, remove, or change a public type, property, or function signature, update every caller AND every test in the same commit. Do not split a rename from "fix tests for rename" into separate commits; the in-between commit is broken, fails CI, and pollutes `git bisect`. If a refactor crosses too many files for one reviewable commit, narrow the change first or stage it behind a typealias the renaming commit removes. ## Performance Pitfalls From 0d78c126c94ba3e7dd34a731c4d14ae8864cb9ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Fri, 8 May 2026 12:25:01 +0700 Subject: [PATCH 16/16] revert(ai-chat): restore IntelligenceFocusBorder for Apple Intelligence parity --- CHANGELOG.md | 1 - TablePro/Views/AIChat/ChatComposerView.swift | 70 ++++++++++++++++++-- 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13cc36f5f..654ed4377 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- AI Chat: composer focus ring uses the standard macOS accent stroke instead of a colored gradient - AI inline suggestions: debounce now uses structured Swift concurrency, and the delay is configurable via the `inlineSuggestionDebounceMs` setting (default 500ms) - Copilot LSP shutdown caps at 10 seconds, closes pipes explicitly, and strips the quarantine attribute from the downloaded binary - AI Chat: streaming view model split into focused extensions backed by a single `streamingState` enum diff --git a/TablePro/Views/AIChat/ChatComposerView.swift b/TablePro/Views/AIChat/ChatComposerView.swift index a191065a4..f06568f57 100644 --- a/TablePro/Views/AIChat/ChatComposerView.swift +++ b/TablePro/Views/AIChat/ChatComposerView.swift @@ -55,12 +55,15 @@ struct ChatComposerView: View { return shape .fill(Color(nsColor: .textBackgroundColor)) .overlay { - shape.stroke( - isFocused ? Color.accentColor : Color(nsColor: .separatorColor), - lineWidth: isFocused ? 1 : 0.5 - ) + if isFocused { + IntelligenceFocusBorder(shape: shape) + .transition(.opacity) + } else { + shape.stroke(Color(nsColor: .separatorColor), lineWidth: 0.5) + .transition(.opacity) + } } - .animation(.default, value: isFocused) + .animation(.easeOut(duration: 0.25), value: isFocused) } private var popoverBinding: Binding { @@ -108,3 +111,60 @@ struct ChatComposerView: View { mentionState.reset() } } + +private enum IntelligenceShimmer { + static let palette: [Color] = [ + Color(red: 1.0, green: 0.404, blue: 0.471), + Color(red: 1.0, green: 0.553, blue: 0.443), + Color(red: 1.0, green: 0.729, blue: 0.443), + Color(red: 0.961, green: 0.725, blue: 0.918), + Color(red: 0.776, green: 0.525, blue: 1.0), + Color(red: 0.737, green: 0.510, blue: 0.953), + Color(red: 0.553, green: 0.624, blue: 1.0) + ] + + struct Layer: Identifiable { + let id: Int + let lineWidth: CGFloat + let blur: CGFloat + let opacity: Double + } + + static let layers: [Layer] = [ + Layer(id: 0, lineWidth: 1.5, blur: 2, opacity: 1.0), + Layer(id: 1, lineWidth: 5, blur: 4, opacity: 0.75), + Layer(id: 2, lineWidth: 9, blur: 10, opacity: 0.5), + Layer(id: 3, lineWidth: 14, blur: 16, opacity: 0.35) + ] + + static func generateStops() -> [Gradient.Stop] { + let count = palette.count + var stops = palette.enumerated().map { index, color in + Gradient.Stop(color: color, location: Double(index) / Double(count)) + } + if let first = palette.first { + stops.append(Gradient.Stop(color: first, location: 1.0)) + } + return stops + } +} + +private struct IntelligenceFocusBorder: View { + let shape: S + + @State private var stops: [Gradient.Stop] = IntelligenceShimmer.generateStops() + + var body: some View { + ZStack { + ForEach(IntelligenceShimmer.layers) { layer in + shape + .stroke( + AngularGradient(gradient: Gradient(stops: stops), center: .center), + lineWidth: layer.lineWidth + ) + .blur(radius: layer.blur) + .opacity(layer.opacity) + } + } + } +}