Skip to content

Commit 15139cc

Browse files
authored
Add tools support to sampling (#976)
1 parent ad86b53 commit 15139cc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1623
-425
lines changed

samples/ChatWithTools/Program.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.Extensions.AI;
22
using Microsoft.Extensions.Logging;
3+
using ModelContextProtocol;
34
using ModelContextProtocol.Client;
45
using OpenAI;
56
using OpenTelemetry;

samples/EverythingServer/Tools/SampleLlmTool.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static async Task<string> SampleLLM(
1717
var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens);
1818
var sampleResult = await server.SampleAsync(samplingParams, cancellationToken);
1919

20-
return $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}";
20+
return $"LLM sampling result: {sampleResult.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text}";
2121
}
2222

2323
private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100)
@@ -27,7 +27,7 @@ private static CreateMessageRequestParams CreateRequestSamplingParams(string con
2727
Messages = [new SamplingMessage
2828
{
2929
Role = Role.User,
30-
Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" },
30+
Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }],
3131
}],
3232
SystemPrompt = "You are a helpful test server.",
3333
MaxTokens = maxTokens,

samples/TestServerWithHosting/Tools/SampleLlmTool.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static async Task<string> SampleLLM(
2020
var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens);
2121
var sampleResult = await thisServer.SampleAsync(samplingParams, cancellationToken);
2222

23-
return $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}";
23+
return $"LLM sampling result: {sampleResult.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text}";
2424
}
2525

2626
private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100)
@@ -30,7 +30,7 @@ private static CreateMessageRequestParams CreateRequestSamplingParams(string con
3030
Messages = [new SamplingMessage
3131
{
3232
Role = Role.User,
33-
Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" },
33+
Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }],
3434
}],
3535
SystemPrompt = "You are a helpful test server.",
3636
MaxTokens = maxTokens,

src/ModelContextProtocol.Core/AIContentExtensions.cs

Lines changed: 208 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
using Microsoft.Extensions.AI;
2+
using ModelContextProtocol.Client;
23
using ModelContextProtocol.Protocol;
34
#if !NET
45
using System.Runtime.InteropServices;
56
#endif
67
using System.Text.Json;
8+
using System.Text.Json.Nodes;
79

810
namespace ModelContextProtocol;
911

@@ -16,6 +18,140 @@ namespace ModelContextProtocol;
1618
/// </remarks>
1719
public static class AIContentExtensions
1820
{
21+
/// <summary>
22+
/// Creates a sampling handler for use with <see cref="McpClientHandlers.SamplingHandler"/> that will
23+
/// satisfy sampling requests using the specified <see cref="IChatClient"/>.
24+
/// </summary>
25+
/// <param name="chatClient">The <see cref="IChatClient"/> with which to satisfy sampling requests.</param>
26+
/// <returns>The created handler delegate that can be assigned to <see cref="McpClientHandlers.SamplingHandler"/>.</returns>
27+
/// <remarks>
28+
/// <para>
29+
/// This method creates a function that converts MCP message requests into chat client calls, enabling
30+
/// an MCP client to generate text or other content using an actual AI model via the provided chat client.
31+
/// </para>
32+
/// <para>
33+
/// The handler can process text messages, image messages, resource messages, and tool use/results as defined in the
34+
/// Model Context Protocol.
35+
/// </para>
36+
/// </remarks>
37+
/// <exception cref="ArgumentNullException"><paramref name="chatClient"/> is <see langword="null"/>.</exception>
38+
public static Func<CreateMessageRequestParams?, IProgress<ProgressNotificationValue>, CancellationToken, ValueTask<CreateMessageResult>> CreateSamplingHandler(
39+
this IChatClient chatClient)
40+
{
41+
Throw.IfNull(chatClient);
42+
43+
return async (requestParams, progress, cancellationToken) =>
44+
{
45+
Throw.IfNull(requestParams);
46+
47+
var (messages, options) = ToChatClientArguments(requestParams);
48+
var progressToken = requestParams.ProgressToken;
49+
50+
List<ChatResponseUpdate> updates = [];
51+
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
52+
{
53+
updates.Add(update);
54+
55+
if (progressToken is not null)
56+
{
57+
progress.Report(new() { Progress = updates.Count });
58+
}
59+
}
60+
61+
ChatResponse? chatResponse = updates.ToChatResponse();
62+
ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault();
63+
64+
IList<ContentBlock>? contents = lastMessage?.Contents.Select(c => c.ToContentBlock()).ToList();
65+
if (contents is not { Count: > 0 })
66+
{
67+
(contents ??= []).Add(new TextContentBlock() { Text = "" });
68+
}
69+
70+
return new()
71+
{
72+
Model = chatResponse.ModelId ?? "",
73+
StopReason =
74+
chatResponse.FinishReason == ChatFinishReason.Stop ? CreateMessageResult.StopReasonEndTurn :
75+
chatResponse.FinishReason == ChatFinishReason.Length ? CreateMessageResult.StopReasonMaxTokens :
76+
chatResponse.FinishReason == ChatFinishReason.ToolCalls ? CreateMessageResult.StopReasonToolUse :
77+
chatResponse.FinishReason.ToString(),
78+
Meta = chatResponse.AdditionalProperties?.ToJsonObject(),
79+
Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant,
80+
Content = contents,
81+
};
82+
83+
static (IList<ChatMessage> Messages, ChatOptions? Options) ToChatClientArguments(CreateMessageRequestParams requestParams)
84+
{
85+
ChatOptions? options = null;
86+
87+
if (requestParams.MaxTokens is int maxTokens)
88+
{
89+
(options ??= new()).MaxOutputTokens = maxTokens;
90+
}
91+
92+
if (requestParams.Temperature is float temperature)
93+
{
94+
(options ??= new()).Temperature = temperature;
95+
}
96+
97+
if (requestParams.StopSequences is { } stopSequences)
98+
{
99+
(options ??= new()).StopSequences = stopSequences.ToArray();
100+
}
101+
102+
if (requestParams.SystemPrompt is { } systemPrompt)
103+
{
104+
(options ??= new()).Instructions = systemPrompt;
105+
}
106+
107+
if (requestParams.Tools is { } tools)
108+
{
109+
foreach (var tool in tools)
110+
{
111+
((options ??= new()).Tools ??= []).Add(new ToolAIFunctionDeclaration(tool));
112+
}
113+
114+
if (options.Tools is { Count: > 0 } && requestParams.ToolChoice is { } toolChoice)
115+
{
116+
options.ToolMode = toolChoice.Mode switch
117+
{
118+
ToolChoice.ModeAuto => ChatToolMode.Auto,
119+
ToolChoice.ModeRequired => ChatToolMode.RequireAny,
120+
ToolChoice.ModeNone => ChatToolMode.None,
121+
_ => null,
122+
};
123+
}
124+
}
125+
126+
List<ChatMessage> messages = [];
127+
foreach (var sm in requestParams.Messages)
128+
{
129+
if (sm.Content?.Select(b => b.ToAIContent()).OfType<AIContent>().ToList() is { Count: > 0 } aiContents)
130+
{
131+
messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents));
132+
}
133+
}
134+
135+
return (messages, options);
136+
}
137+
};
138+
}
139+
140+
/// <summary>Converts the specified dictionary to a <see cref="JsonObject"/>.</summary>
141+
internal static JsonObject? ToJsonObject(this IReadOnlyDictionary<string, object?> properties) =>
142+
JsonSerializer.SerializeToNode(properties, McpJsonUtilities.JsonContext.Default.IReadOnlyDictionaryStringObject) as JsonObject;
143+
144+
internal static AdditionalPropertiesDictionary ToAdditionalProperties(this JsonObject obj)
145+
{
146+
AdditionalPropertiesDictionary d = [];
147+
foreach (var kvp in obj)
148+
{
149+
d.Add(kvp.Key, kvp.Value);
150+
}
151+
152+
return d;
153+
}
154+
19155
/// <summary>
20156
/// Converts a <see cref="PromptMessage"/> to a <see cref="ChatMessage"/> object.
21157
/// </summary>
@@ -99,7 +235,7 @@ public static IList<PromptMessage> ToPromptMessages(this ChatMessage chatMessage
99235
{
100236
if (content is TextContent or DataContent)
101237
{
102-
messages.Add(new PromptMessage { Role = r, Content = content.ToContent() });
238+
messages.Add(new PromptMessage { Role = r, Content = content.ToContentBlock() });
103239
}
104240
}
105241

@@ -122,13 +258,31 @@ public static IList<PromptMessage> ToPromptMessages(this ChatMessage chatMessage
122258
AIContent? ac = content switch
123259
{
124260
TextContentBlock textContent => new TextContent(textContent.Text),
261+
125262
ImageContentBlock imageContent => new DataContent(Convert.FromBase64String(imageContent.Data), imageContent.MimeType),
263+
126264
AudioContentBlock audioContent => new DataContent(Convert.FromBase64String(audioContent.Data), audioContent.MimeType),
265+
127266
EmbeddedResourceBlock resourceContent => resourceContent.Resource.ToAIContent(),
267+
268+
ToolUseContentBlock toolUse => FunctionCallContent.CreateFromParsedArguments(toolUse.Input, toolUse.Id, toolUse.Name,
269+
static json => JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.IDictionaryStringObject)),
270+
271+
ToolResultContentBlock toolResult => new FunctionResultContent(
272+
toolResult.ToolUseId,
273+
toolResult.Content.Count == 1 ? toolResult.Content[0].ToAIContent() : toolResult.Content.Select(c => c.ToAIContent()).OfType<AIContent>().ToList())
274+
{
275+
Exception = toolResult.IsError is true ? new() : null,
276+
},
277+
128278
_ => null,
129279
};
130280

131-
ac?.RawRepresentation = content;
281+
if (ac is not null)
282+
{
283+
ac.RawRepresentation = content;
284+
ac.AdditionalProperties = content.Meta?.ToAdditionalProperties();
285+
}
132286

133287
return ac;
134288
}
@@ -200,8 +354,12 @@ public static IList<AIContent> ToAIContents(this IEnumerable<ResourceContents> c
200354
return [.. contents.Select(ToAIContent)];
201355
}
202356

203-
internal static ContentBlock ToContent(this AIContent content) =>
204-
content switch
357+
/// <summary>Creates a new <see cref="ContentBlock"/> from the content of an <see cref="AIContent"/>.</summary>
358+
/// <param name="content">The <see cref="AIContent"/> to convert.</param>
359+
/// <returns>The created <see cref="ContentBlock"/>.</returns>
360+
public static ContentBlock ToContentBlock(this AIContent content)
361+
{
362+
ContentBlock contentBlock = content switch
205363
{
206364
TextContent textContent => new TextContentBlock
207365
{
@@ -230,9 +388,55 @@ internal static ContentBlock ToContent(this AIContent content) =>
230388
}
231389
},
232390

391+
FunctionCallContent callContent => new ToolUseContentBlock()
392+
{
393+
Id = callContent.CallId,
394+
Name = callContent.Name,
395+
Input = JsonSerializer.SerializeToElement(callContent.Arguments, McpJsonUtilities.DefaultOptions.GetTypeInfo<IDictionary<string, object?>>()!),
396+
},
397+
398+
FunctionResultContent resultContent => new ToolResultContentBlock()
399+
{
400+
ToolUseId = resultContent.CallId,
401+
IsError = resultContent.Exception is not null,
402+
Content =
403+
resultContent.Result is AIContent c ? [c.ToContentBlock()] :
404+
resultContent.Result is IEnumerable<AIContent> ec ? [.. ec.Select(c => c.ToContentBlock())] :
405+
[new TextContentBlock { Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo<object>()) }],
406+
StructuredContent = resultContent.Result is JsonElement je ? je : null,
407+
},
408+
233409
_ => new TextContentBlock
234410
{
235411
Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))),
236412
}
237413
};
414+
415+
contentBlock.Meta = content.AdditionalProperties?.ToJsonObject();
416+
417+
return contentBlock;
418+
}
419+
420+
private sealed class ToolAIFunctionDeclaration(Tool tool) : AIFunctionDeclaration
421+
{
422+
public override string Name => tool.Name;
423+
424+
public override string Description => tool.Description ?? "";
425+
426+
public override IReadOnlyDictionary<string, object?> AdditionalProperties =>
427+
field ??= tool.Meta is { } meta ? meta.ToDictionary(p => p.Key, p => (object?)p.Value) : [];
428+
429+
public override JsonElement JsonSchema => tool.InputSchema;
430+
431+
public override JsonElement? ReturnJsonSchema => tool.OutputSchema;
432+
433+
public override object? GetService(Type serviceType, object? serviceKey = null)
434+
{
435+
Throw.IfNull(serviceType);
436+
437+
return
438+
serviceKey is null && serviceType.IsInstanceOfType(tool) ? tool :
439+
base.GetService(serviceType, serviceKey);
440+
}
441+
}
238442
}

0 commit comments

Comments
 (0)