Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@ import class Foundation.JSONEncoder

/// Model Context Protocol client
public actor Client {
/// The client configuration
public struct Configuration: Hashable, Codable, Sendable {
/// The default configuration.
public static let `default` = Configuration(strict: false)

/// The strict configuration.
public static let strict = Configuration(strict: true)

/// When strict mode is enabled, the client:
/// - Requires server capabilities to be initialized before making requests
/// - Rejects all requests that require capabilities before initialization
///
/// While the MCP specification requires servers to respond to initialize requests
/// with their capabilities, some implementations may not follow this.
/// Disabling strict mode allows the client to be more lenient with non-compliant
/// servers, though this may lead to undefined behavior.
public var strict: Bool

public init(strict: Bool = false) {
self.strict = strict
}
}

/// Implementation information
public struct Info: Hashable, Codable, Sendable {
/// The client name
Expand Down Expand Up @@ -73,6 +96,8 @@ public actor Client {

/// The client capabilities
public var capabilities: Client.Capabilities
/// The client configuration
public var configuration: Configuration

/// The server capabilities
private var serverCapabilities: Server.Capabilities?
Expand Down Expand Up @@ -131,10 +156,12 @@ public actor Client {

public init(
name: String,
version: String
version: String,
configuration: Configuration = .default
) {
self.clientInfo = Client.Info(name: name, version: version)
self.capabilities = Capabilities()
self.configuration = configuration
}

/// Connect to the server using the given transport
Expand Down Expand Up @@ -294,7 +321,7 @@ public actor Client {
public func getPrompt(name: String, arguments: [String: Value]? = nil) async throws
-> (description: String?, messages: [Prompt.Message])
{
_ = try checkCapability(\.prompts, "Prompts")
try validateServerCapability(\.prompts, "Prompts")
let request = GetPrompt.request(.init(name: name, arguments: arguments))
let result = try await send(request)
return (description: result.description, messages: result.messages)
Expand All @@ -303,7 +330,7 @@ public actor Client {
public func listPrompts(cursor: String? = nil) async throws
-> (prompts: [Prompt], nextCursor: String?)
{
_ = try checkCapability(\.prompts, "Prompts")
try validateServerCapability(\.prompts, "Prompts")
let request: Request<ListPrompts>
if let cursor = cursor {
request = ListPrompts.request(.init(cursor: cursor))
Expand All @@ -317,7 +344,7 @@ public actor Client {
// MARK: - Resources

public func readResource(uri: String) async throws -> [Resource.Content] {
_ = try checkCapability(\.resources, "Resources")
try validateServerCapability(\.resources, "Resources")
let request = ReadResource.request(.init(uri: uri))
let result = try await send(request)
return result.contents
Expand All @@ -326,7 +353,7 @@ public actor Client {
public func listResources(cursor: String? = nil) async throws -> (
resources: [Resource], nextCursor: String?
) {
_ = try checkCapability(\.resources, "Resources")
try validateServerCapability(\.resources, "Resources")
let request: Request<ListResources>
if let cursor = cursor {
request = ListResources.request(.init(cursor: cursor))
Expand All @@ -338,15 +365,15 @@ public actor Client {
}

public func subscribeToResource(uri: String) async throws {
_ = try checkCapability(\.resources?.subscribe, "Resource subscription")
try validateServerCapability(\.resources?.subscribe, "Resource subscription")
let request = ResourceSubscribe.request(.init(uri: uri))
_ = try await send(request)
}

// MARK: - Tools

public func listTools(cursor: String? = nil) async throws -> [Tool] {
_ = try checkCapability(\.tools, "Tools")
try validateServerCapability(\.tools, "Tools")
let request: Request<ListTools>
if let cursor = cursor {
request = ListTools.request(.init(cursor: cursor))
Expand All @@ -360,7 +387,7 @@ public actor Client {
public func callTool(name: String, arguments: [String: Value]? = nil) async throws -> (
content: [Tool.Content], isError: Bool?
) {
_ = try checkCapability(\.tools, "Tools")
try validateServerCapability(\.tools, "Tools")
let request = CallTool.request(.init(name: name, arguments: arguments))
let result = try await send(request)
return (content: result.content, isError: result.isError)
Expand Down Expand Up @@ -410,15 +437,21 @@ public actor Client {

// MARK: -

private func checkCapability<T>(_ keyPath: KeyPath<Server.Capabilities, T?>, _ name: String)
throws -> T
/// Validate the server capabilities.
/// Throws an error if the client is configured to be strict and the capability is not supported.
private func validateServerCapability<T>(
_ keyPath: KeyPath<Server.Capabilities, T?>,
_ name: String
)
throws
{
guard let capabilities = serverCapabilities else {
throw Error.methodNotFound("Server capabilities not initialized")
}
guard let value = capabilities[keyPath: keyPath] else {
throw Error.methodNotFound("\(name) is not supported by the server")
if configuration.strict {
guard let capabilities = serverCapabilities else {
throw Error.methodNotFound("Server capabilities not initialized")
}
guard capabilities[keyPath: keyPath] != nil else {
throw Error.methodNotFound("\(name) is not supported by the server")
}
}
return value
}
}
105 changes: 105 additions & 0 deletions Tests/MCPTests/ClientTests.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Foundation
import Testing

@testable import MCP
Expand Down Expand Up @@ -121,4 +122,108 @@ struct ClientTests {
#expect(Bool(false), "Expected MCP.Error")
}
}

@Test("Strict configuration - capabilities check")
func testStrictConfiguration() async throws {
let transport = MockTransport()
let config = Client.Configuration.strict
let client = Client(name: "TestClient", version: "1.0", configuration: config)

try await client.connect(transport: transport)

// Create a task for listPrompts
let promptsTask = Task<Void, Swift.Error> {
do {
_ = try await client.listPrompts()
#expect(Bool(false), "Expected listPrompts to fail in strict mode")
} catch let error as Error {
if case Error.methodNotFound = error {
#expect(Bool(true))
} else {
#expect(Bool(false), "Expected methodNotFound error, got \(error)")
}
} catch {
#expect(Bool(false), "Expected MCP.Error")
}
}

// Give it a short time to execute the task
try await Task.sleep(for: .milliseconds(50))

// Cancel the task if it's still running
promptsTask.cancel()

// Disconnect client
await client.disconnect()
try await Task.sleep(for: .milliseconds(50))
}

@Test("Non-strict configuration - capabilities check")
func testNonStrictConfiguration() async throws {
let transport = MockTransport()
let config = Client.Configuration.default
let client = Client(name: "TestClient", version: "1.0", configuration: config)

try await client.connect(transport: transport)

// Wait a bit for any setup to complete
try await Task.sleep(for: .milliseconds(10))

// Send the listPrompts request and immediately provide an error response
let promptsTask = Task {
do {
// Start the request
try await Task.sleep(for: .seconds(1))

// Get the last sent message and extract the request ID
if let lastMessage = await transport.sentMessages.last,
let data = lastMessage.data(using: .utf8),
let decodedRequest = try? JSONDecoder().decode(
Request<ListPrompts>.self, from: data)
{

// Create an error response with the same ID
let errorResponse = Response<ListPrompts>(
id: decodedRequest.id,
error: Error.methodNotFound("Test: Prompts capability not available")
)
try await transport.queueResponse(errorResponse)

// Try the request now that we have a response queued
do {
_ = try await client.listPrompts()
#expect(Bool(false), "Expected listPrompts to fail in non-strict mode")
} catch let error as Error {
if case Error.methodNotFound = error {
#expect(Bool(true))
} else {
#expect(Bool(false), "Expected methodNotFound error, got \(error)")
}
} catch {
#expect(Bool(false), "Expected MCP.Error")
}
}
} catch {
// Ignore task cancellation
if !(error is CancellationError) {
throw error
}
}
}

// Wait for the task to complete or timeout
let timeoutTask = Task {
try await Task.sleep(for: .milliseconds(500))
promptsTask.cancel()
}

// Wait for the task to complete
_ = await promptsTask.result

// Cancel the timeout task
timeoutTask.cancel()

// Disconnect client
await client.disconnect()
}
}