diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index bd2d2dfb..d4d121b5 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -7,6 +7,29 @@ import class Foundation.JSONEncoder /// Model Context Protocol client public actor Client { + /// The client configuration + public struct Configuration: Hashable, Codable, Sendable { + /// The default configuration. + public static let `default` = Configuration(strict: false) + + /// The strict configuration. + public static let strict = Configuration(strict: true) + + /// When strict mode is enabled, the client: + /// - Requires server capabilities to be initialized before making requests + /// - Rejects all requests that require capabilities before initialization + /// + /// While the MCP specification requires servers to respond to initialize requests + /// with their capabilities, some implementations may not follow this. + /// Disabling strict mode allows the client to be more lenient with non-compliant + /// servers, though this may lead to undefined behavior. + public var strict: Bool + + public init(strict: Bool = false) { + self.strict = strict + } + } + /// Implementation information public struct Info: Hashable, Codable, Sendable { /// The client name @@ -73,6 +96,8 @@ public actor Client { /// The client capabilities public var capabilities: Client.Capabilities + /// The client configuration + public var configuration: Configuration /// The server capabilities private var serverCapabilities: Server.Capabilities? @@ -131,10 +156,12 @@ public actor Client { public init( name: String, - version: String + version: String, + configuration: Configuration = .default ) { self.clientInfo = Client.Info(name: name, version: version) self.capabilities = Capabilities() + self.configuration = configuration } /// Connect to the server using the given transport @@ -294,7 +321,7 @@ public actor Client { public func getPrompt(name: String, arguments: [String: Value]? = nil) async throws -> (description: String?, messages: [Prompt.Message]) { - _ = try checkCapability(\.prompts, "Prompts") + try validateServerCapability(\.prompts, "Prompts") let request = GetPrompt.request(.init(name: name, arguments: arguments)) let result = try await send(request) return (description: result.description, messages: result.messages) @@ -303,7 +330,7 @@ public actor Client { public func listPrompts(cursor: String? = nil) async throws -> (prompts: [Prompt], nextCursor: String?) { - _ = try checkCapability(\.prompts, "Prompts") + try validateServerCapability(\.prompts, "Prompts") let request: Request if let cursor = cursor { request = ListPrompts.request(.init(cursor: cursor)) @@ -317,7 +344,7 @@ public actor Client { // MARK: - Resources public func readResource(uri: String) async throws -> [Resource.Content] { - _ = try checkCapability(\.resources, "Resources") + try validateServerCapability(\.resources, "Resources") let request = ReadResource.request(.init(uri: uri)) let result = try await send(request) return result.contents @@ -326,7 +353,7 @@ public actor Client { public func listResources(cursor: String? = nil) async throws -> ( resources: [Resource], nextCursor: String? ) { - _ = try checkCapability(\.resources, "Resources") + try validateServerCapability(\.resources, "Resources") let request: Request if let cursor = cursor { request = ListResources.request(.init(cursor: cursor)) @@ -338,7 +365,7 @@ public actor Client { } public func subscribeToResource(uri: String) async throws { - _ = try checkCapability(\.resources?.subscribe, "Resource subscription") + try validateServerCapability(\.resources?.subscribe, "Resource subscription") let request = ResourceSubscribe.request(.init(uri: uri)) _ = try await send(request) } @@ -346,7 +373,7 @@ public actor Client { // MARK: - Tools public func listTools(cursor: String? = nil) async throws -> [Tool] { - _ = try checkCapability(\.tools, "Tools") + try validateServerCapability(\.tools, "Tools") let request: Request if let cursor = cursor { request = ListTools.request(.init(cursor: cursor)) @@ -360,7 +387,7 @@ public actor Client { public func callTool(name: String, arguments: [String: Value]? = nil) async throws -> ( content: [Tool.Content], isError: Bool? ) { - _ = try checkCapability(\.tools, "Tools") + try validateServerCapability(\.tools, "Tools") let request = CallTool.request(.init(name: name, arguments: arguments)) let result = try await send(request) return (content: result.content, isError: result.isError) @@ -410,15 +437,21 @@ public actor Client { // MARK: - - private func checkCapability(_ keyPath: KeyPath, _ name: String) - throws -> T + /// Validate the server capabilities. + /// Throws an error if the client is configured to be strict and the capability is not supported. + private func validateServerCapability( + _ keyPath: KeyPath, + _ name: String + ) + throws { - guard let capabilities = serverCapabilities else { - throw Error.methodNotFound("Server capabilities not initialized") - } - guard let value = capabilities[keyPath: keyPath] else { - throw Error.methodNotFound("\(name) is not supported by the server") + if configuration.strict { + guard let capabilities = serverCapabilities else { + throw Error.methodNotFound("Server capabilities not initialized") + } + guard capabilities[keyPath: keyPath] != nil else { + throw Error.methodNotFound("\(name) is not supported by the server") + } } - return value } } diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index a2cba6b4..d34c6018 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -1,3 +1,4 @@ +import Foundation import Testing @testable import MCP @@ -121,4 +122,108 @@ struct ClientTests { #expect(Bool(false), "Expected MCP.Error") } } + + @Test("Strict configuration - capabilities check") + func testStrictConfiguration() async throws { + let transport = MockTransport() + let config = Client.Configuration.strict + let client = Client(name: "TestClient", version: "1.0", configuration: config) + + try await client.connect(transport: transport) + + // Create a task for listPrompts + let promptsTask = Task { + do { + _ = try await client.listPrompts() + #expect(Bool(false), "Expected listPrompts to fail in strict mode") + } catch let error as Error { + if case Error.methodNotFound = error { + #expect(Bool(true)) + } else { + #expect(Bool(false), "Expected methodNotFound error, got \(error)") + } + } catch { + #expect(Bool(false), "Expected MCP.Error") + } + } + + // Give it a short time to execute the task + try await Task.sleep(for: .milliseconds(50)) + + // Cancel the task if it's still running + promptsTask.cancel() + + // Disconnect client + await client.disconnect() + try await Task.sleep(for: .milliseconds(50)) + } + + @Test("Non-strict configuration - capabilities check") + func testNonStrictConfiguration() async throws { + let transport = MockTransport() + let config = Client.Configuration.default + let client = Client(name: "TestClient", version: "1.0", configuration: config) + + try await client.connect(transport: transport) + + // Wait a bit for any setup to complete + try await Task.sleep(for: .milliseconds(10)) + + // Send the listPrompts request and immediately provide an error response + let promptsTask = Task { + do { + // Start the request + try await Task.sleep(for: .seconds(1)) + + // Get the last sent message and extract the request ID + if let lastMessage = await transport.sentMessages.last, + let data = lastMessage.data(using: .utf8), + let decodedRequest = try? JSONDecoder().decode( + Request.self, from: data) + { + + // Create an error response with the same ID + let errorResponse = Response( + id: decodedRequest.id, + error: Error.methodNotFound("Test: Prompts capability not available") + ) + try await transport.queueResponse(errorResponse) + + // Try the request now that we have a response queued + do { + _ = try await client.listPrompts() + #expect(Bool(false), "Expected listPrompts to fail in non-strict mode") + } catch let error as Error { + if case Error.methodNotFound = error { + #expect(Bool(true)) + } else { + #expect(Bool(false), "Expected methodNotFound error, got \(error)") + } + } catch { + #expect(Bool(false), "Expected MCP.Error") + } + } + } catch { + // Ignore task cancellation + if !(error is CancellationError) { + throw error + } + } + } + + // Wait for the task to complete or timeout + let timeoutTask = Task { + try await Task.sleep(for: .milliseconds(500)) + promptsTask.cancel() + } + + // Wait for the task to complete + _ = await promptsTask.result + + // Cancel the timeout task + timeoutTask.cancel() + + // Disconnect client + await client.disconnect() + } }