Commit fac4df6

Add support for sampling (#119)

* Add UnitInterval type
* Add test coverage for UnitInterval
* Add support for sampling
* Add integration tests for sampling

1 parent ddde66d commit fac4df6

7 files changed: +1584 −0 lines changed

README.md

Lines changed: 126 additions & 0 deletions
@@ -189,6 +189,89 @@ for message in messages {
}
```

### Sampling

Sampling allows servers to request LLM completions through the client,
enabling agentic behaviors while maintaining human-in-the-loop control.
Clients register a handler to process incoming sampling requests from servers.

> [!TIP]
> Sampling requests flow from **server to client**,
> not client to server.
> This enables servers to request AI assistance
> while clients maintain control over model access and user approval.

```swift
// Register a sampling handler in the client
await client.withSamplingHandler { parameters in
    // Review the sampling request (human-in-the-loop step 1)
    print("Server requests completion for: \(parameters.messages)")

    // Optionally modify the request based on user input
    var messages = parameters.messages
    if let systemPrompt = parameters.systemPrompt {
        print("System prompt: \(systemPrompt)")
    }

    // Sample from your LLM (this is where you'd call your AI service)
    let completion = try await callYourLLMService(
        messages: messages,
        maxTokens: parameters.maxTokens,
        temperature: parameters.temperature
    )

    // Review the completion (human-in-the-loop step 2)
    print("LLM generated: \(completion)")
    // User can approve, modify, or reject the completion here

    // Return the result to the server
    return CreateSamplingMessage.Result(
        model: "your-model-name",
        stopReason: .endTurn,
        role: .assistant,
        content: .text(completion)
    )
}
```

The sampling flow follows these steps:

```mermaid
sequenceDiagram
    participant S as MCP Server
    participant C as MCP Client
    participant U as User/Human
    participant L as LLM Service

    Note over S,L: Server-initiated sampling request
    S->>C: sampling/createMessage request
    Note right of S: Server needs AI assistance<br/>for decision or content

    Note over C,U: Human-in-the-loop review #1
    C->>U: Show sampling request
    U->>U: Review & optionally modify<br/>messages, system prompt
    U->>C: Approve request

    Note over C,L: Client handles LLM interaction
    C->>L: Send messages to LLM
    L->>C: Return completion

    Note over C,U: Human-in-the-loop review #2
    C->>U: Show LLM completion
    U->>U: Review & optionally modify<br/>or reject completion
    U->>C: Approve completion

    Note over C,S: Return result to server
    C->>S: sampling/createMessage response
    Note left of C: Contains model used,<br/>stop reason, final content

    Note over S: Server continues with<br/>AI-assisted result
```

This human-in-the-loop design ensures that users
maintain control over what the LLM sees and generates,
even when servers initiate the requests.

### Error Handling

Handle common client errors:
@@ -504,6 +587,49 @@ server.withMethodHandler(GetPrompt.self) { params in
}
```

### Sampling

Servers can request LLM completions from clients through sampling. This enables agentic behaviors where servers can ask for AI assistance while maintaining human oversight.

> [!NOTE]
> The current implementation provides the correct API design for sampling, but requires bidirectional communication support in the transport layer. This feature will be fully functional when bidirectional transport support is added.

```swift
// Enable sampling capability in server
let server = Server(
    name: "MyModelServer",
    version: "1.0.0",
    capabilities: .init(
        sampling: .init(), // Enable sampling capability
        tools: .init(listChanged: true)
    )
)

// Request sampling from the client (conceptual - requires bidirectional transport)
do {
    let result = try await server.requestSampling(
        messages: [
            Sampling.Message(role: .user, content: .text("Analyze this data and suggest next steps"))
        ],
        systemPrompt: "You are a helpful data analyst",
        maxTokens: 150,
        temperature: 0.7
    )

    // Use the LLM completion in your server logic
    print("LLM suggested: \(result.content)")

} catch {
    print("Sampling request failed: \(error)")
}
```

Sampling enables powerful agentic workflows, one of which is sketched after this list:
- **Decision-making**: Ask the LLM to choose between options
- **Content generation**: Request drafts for user approval
- **Data analysis**: Get AI insights on complex data
- **Multi-step reasoning**: Chain AI completions with tool calls
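
For example, a server might delegate a small decision to the client's LLM. The following is an illustrative sketch only: it reuses the conceptual `requestSampling` API and `Sampling.Message` type from the example above, carries the same bidirectional-transport caveat, and assumes `result.content` exposes the `.text` case used in the client handler example.

```swift
// Hedged sketch: decision-making via sampling (conceptual, pending bidirectional transport support)
func chooseDeploymentStrategy(using server: Server) async {
    do {
        let result = try await server.requestSampling(
            messages: [
                Sampling.Message(
                    role: .user,
                    content: .text("Choose exactly one deployment strategy: blue-green, rolling, or canary. Reply with a single word.")
                )
            ],
            systemPrompt: "You are a cautious release engineer.",
            maxTokens: 10,
            temperature: 0.0
        )

        // The client-side handler (and its human reviewer) approved this completion before returning it
        if case .text(let choice) = result.content {
            print("Chosen strategy: \(choice)")
        }
    } catch {
        print("Decision-making sampling failed: \(error)")
    }
}
```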

#### Initialize Hook

Control client connections with an initialize hook:

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
/// A value constrained to the range 0.0 to 1.0, inclusive.
///
/// `UnitInterval` represents a normalized value that is guaranteed to be within
/// the unit interval [0, 1]. This type is commonly used for representing
/// priorities in sampling request model preferences.
///
/// The type provides safe initialization that returns `nil` for values outside
/// the valid range, ensuring that all instances contain valid unit interval values.
///
/// - Example:
/// ```swift
/// let zero: UnitInterval = 0       // 0.0
/// let half = UnitInterval(0.5)!    // 0.5
/// let one: UnitInterval = 1.0      // 1.0
/// let invalid = UnitInterval(1.5)  // nil
/// ```
public struct UnitInterval: Hashable, Sendable {
    private let value: Double

    /// Creates a unit interval value from a `Double`.
    ///
    /// - Parameter value: A double value that must be in the range 0.0...1.0
    /// - Returns: A `UnitInterval` instance if the value is valid, `nil` otherwise
    ///
    /// - Example:
    /// ```swift
    /// let valid = UnitInterval(0.75)    // Optional(0.75)
    /// let invalid = UnitInterval(-0.1)  // nil
    /// let boundary = UnitInterval(1.0)  // Optional(1.0)
    /// ```
    public init?(_ value: Double) {
        guard (0...1).contains(value) else { return nil }
        self.value = value
    }

    /// The underlying double value.
    ///
    /// This property provides access to the raw double value that is guaranteed
    /// to be within the range [0, 1].
    ///
    /// - Returns: The double value between 0.0 and 1.0, inclusive
    public var doubleValue: Double { value }
}

// MARK: - Comparable

extension UnitInterval: Comparable {
    public static func < (lhs: UnitInterval, rhs: UnitInterval) -> Bool {
        lhs.value < rhs.value
    }
}

// MARK: - CustomStringConvertible

extension UnitInterval: CustomStringConvertible {
    public var description: String { "\(value)" }
}

// MARK: - Codable

extension UnitInterval: Codable {
    public init(from decoder: Decoder) throws {
        let container = try decoder.singleValueContainer()
        let doubleValue = try container.decode(Double.self)
        guard let interval = UnitInterval(doubleValue) else {
            throw DecodingError.dataCorrupted(
                DecodingError.Context(
                    codingPath: decoder.codingPath,
                    debugDescription: "Value \(doubleValue) is not in range 0...1")
            )
        }
        self = interval
    }

    public func encode(to encoder: Encoder) throws {
        var container = encoder.singleValueContainer()
        try container.encode(value)
    }
}

// MARK: - ExpressibleByFloatLiteral

extension UnitInterval: ExpressibleByFloatLiteral {
    /// Creates a unit interval from a floating-point literal.
    ///
    /// This initializer allows you to create `UnitInterval` instances using
    /// floating-point literals. The literal value must be in the range [0, 1]
    /// or a runtime error will occur.
    ///
    /// - Parameter value: A floating-point literal between 0.0 and 1.0
    ///
    /// - Warning: This initializer will crash if the literal is outside the valid range.
    ///   Use the failable initializer `init(_:)` for runtime validation.
    ///
    /// - Example:
    /// ```swift
    /// let quarter: UnitInterval = 0.25
    /// let half: UnitInterval = 0.5
    /// ```
    public init(floatLiteral value: Double) {
        self.init(value)!
    }
}

// MARK: - ExpressibleByIntegerLiteral

extension UnitInterval: ExpressibleByIntegerLiteral {
    /// Creates a unit interval from an integer literal.
    ///
    /// This initializer allows you to create `UnitInterval` instances using
    /// integer literals. Only the values 0 and 1 are valid.
    ///
    /// - Parameter value: An integer literal, either 0 or 1
    ///
    /// - Warning: This initializer will crash if the literal is outside the valid range.
    ///   Use the failable initializer `init(_:)` for runtime validation.
    ///
    /// - Example:
    /// ```swift
    /// let zero: UnitInterval = 0
    /// let one: UnitInterval = 1
    /// ```
    public init(integerLiteral value: Int) {
        self.init(Double(value))!
    }
}
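
Taken together, the conformances above give `UnitInterval` ergonomic construction and safe decoding. The following usage sketch is not part of the commit: the `Preferences` wrapper type is hypothetical and exists only to exercise the `Codable` path, and it assumes `UnitInterval` is importable from the `MCP` module.

```swift
import Foundation
import MCP  // assumption: the SDK module that exposes UnitInterval

// Hypothetical wrapper, used only to demonstrate UnitInterval's Codable conformance
struct Preferences: Codable {
    let speedPriority: UnitInterval
}

// Failable initializer: out-of-range values are rejected
assert(UnitInterval(1.5) == nil)
assert(UnitInterval(0.75) != nil)

// Literal conformances trap on invalid literals, so use them only with known-good constants
let half: UnitInterval = 0.5
let one: UnitInterval = 1
assert(half < one)  // Comparable

// Codable round trip; out-of-range JSON values throw DecodingError.dataCorrupted
do {
    let data = #"{"speedPriority": 0.8}"#.data(using: .utf8)!
    let decoded = try JSONDecoder().decode(Preferences.self, from: data)
    print(decoded.speedPriority)  // 0.8
} catch {
    print("Decoding failed: \(error)")
}
```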

Sources/MCP/Client/Client.swift

Lines changed: 39 additions & 0 deletions
@@ -583,6 +583,45 @@ public actor Client {
        return (content: result.content, isError: result.isError)
    }

    // MARK: - Sampling

    /// Register a handler for sampling requests from servers
    ///
    /// Sampling allows servers to request LLM completions through the client,
    /// enabling sophisticated agentic behaviors while maintaining human-in-the-loop control.
    ///
    /// The sampling flow follows these steps:
    /// 1. Server sends a `sampling/createMessage` request to the client
    /// 2. Client reviews the request and can modify it (via this handler)
    /// 3. Client samples from an LLM (via this handler)
    /// 4. Client reviews the completion (via this handler)
    /// 5. Client returns the result to the server
    ///
    /// - Parameter handler: A closure that processes sampling requests and returns completions
    /// - Returns: Self for method chaining
    /// - SeeAlso: https://modelcontextprotocol.io/docs/concepts/sampling#how-sampling-works
    @discardableResult
    public func withSamplingHandler(
        _ handler: @escaping @Sendable (CreateSamplingMessage.Parameters) async throws ->
            CreateSamplingMessage.Result
    ) -> Self {
        // Note: This would require extending the client architecture to handle incoming requests from servers.
        // The current MCP Swift SDK architecture assumes clients only send requests to servers,
        // but sampling requires bidirectional communication where servers can send requests to clients.
        //
        // A full implementation would need:
        // 1. Request handlers in the client (similar to how servers handle requests)
        // 2. Bidirectional transport support
        // 3. Request/response correlation for server-to-client requests
        //
        // For now, this serves as the correct API design for when bidirectional support is added.

        // This would register the handler similar to how servers register method handlers:
        // methodHandlers[CreateSamplingMessage.name] = TypedRequestHandler(handler)

        return self
    }

    // MARK: -

    private func handleResponse(_ response: Response<AnyMethod>) async {
