Skip to content

Commit 334e206

Browse files
committed
Decouple Server from TransportInterface for reusability
1 parent 470237a commit 334e206

File tree

8 files changed

+76
-91
lines changed

8 files changed

+76
-91
lines changed

src/Server/Protocol.php

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
/**
3333
* @final
3434
*
35-
* @phpstan-import-type McpFiber from \Mcp\Server\Transport\TransportInterface
36-
* @phpstan-import-type FiberSuspend from \Mcp\Server\Transport\TransportInterface
35+
* @phpstan-import-type McpFiber from TransportInterface
36+
* @phpstan-import-type FiberSuspend from TransportInterface
3737
*
3838
* @author Christopher Hertel <[email protected]>
3939
* @author Kyrian Obikwelu <[email protected]>
@@ -55,9 +55,6 @@ class Protocol
5555
/** Session key for active request meta */
5656
public const SESSION_ACTIVE_REQUEST_META = '_mcp.active_request_meta';
5757

58-
/** @var TransportInterface<mixed>|null */
59-
private ?TransportInterface $transport = null;
60-
6158
/**
6259
* @param array<int, RequestHandlerInterface<ResultInterface|array<string, mixed>>> $requestHandlers
6360
* @param array<int, NotificationHandlerInterface> $notificationHandlers
@@ -73,39 +70,25 @@ public function __construct(
7370
}
7471

7572
/**
76-
* @return TransportInterface<mixed>
77-
*/
78-
public function getTransport(): TransportInterface
79-
{
80-
return $this->transport;
81-
}
82-
83-
/**
84-
* Connect this protocol to a transport.
73+
* Connect this protocol to transport.
8574
*
8675
* The protocol takes ownership of the transport and sets up all callbacks.
8776
*
8877
* @param TransportInterface<mixed> $transport
8978
*/
9079
public function connect(TransportInterface $transport): void
9180
{
92-
if ($this->transport) {
93-
throw new \RuntimeException('Protocol already connected to a transport');
94-
}
95-
96-
$this->transport = $transport;
81+
$transport->onMessage($this->processInput(...));
9782

98-
$this->transport->onMessage([$this, 'processInput']);
83+
$transport->onSessionEnd($this->destroySession(...));
9984

100-
$this->transport->onSessionEnd([$this, 'destroySession']);
85+
$transport->setOutgoingMessagesProvider($this->consumeOutgoingMessages(...));
10186

102-
$this->transport->setOutgoingMessagesProvider([$this, 'consumeOutgoingMessages']);
87+
$transport->setPendingRequestsProvider($this->getPendingRequests(...));
10388

104-
$this->transport->setPendingRequestsProvider([$this, 'getPendingRequests']);
89+
$transport->setResponseFinder($this->checkResponse(...));
10590

106-
$this->transport->setResponseFinder([$this, 'checkResponse']);
107-
108-
$this->transport->setFiberYieldHandler([$this, 'handleFiberYield']);
91+
$transport->setFiberYieldHandler($this->handleFiberYield(...));
10992

11093
$this->logger->info('Protocol connected to transport', ['transport' => $transport::class]);
11194
}
@@ -114,8 +97,10 @@ public function connect(TransportInterface $transport): void
11497
* Handle an incoming message from the transport.
11598
*
11699
* This is called by the transport whenever ANY message arrives.
100+
*
101+
* @param TransportInterface<mixed> $transport
117102
*/
118-
public function processInput(string $input, ?Uuid $sessionId): void
103+
public function processInput(TransportInterface $transport, string $input, ?Uuid $sessionId): void
119104
{
120105
$this->logger->info('Received message to process.', ['message' => $input]);
121106

@@ -126,21 +111,21 @@ public function processInput(string $input, ?Uuid $sessionId): void
126111
} catch (\JsonException $e) {
127112
$this->logger->warning('Failed to decode json message.', ['exception' => $e]);
128113
$error = Error::forParseError($e->getMessage());
129-
$this->sendResponse($error, null);
114+
$this->sendResponse($transport, $error, null);
130115

131116
return;
132117
}
133118

134-
$session = $this->resolveSession($sessionId, $messages);
119+
$session = $this->resolveSession($transport, $sessionId, $messages);
135120
if (null === $session) {
136121
return;
137122
}
138123

139124
foreach ($messages as $message) {
140125
if ($message instanceof InvalidInputMessageException) {
141-
$this->handleInvalidMessage($message, $session);
126+
$this->handleInvalidMessage($transport, $message, $session);
142127
} elseif ($message instanceof Request) {
143-
$this->handleRequest($message, $session);
128+
$this->handleRequest($transport, $message, $session);
144129
} elseif ($message instanceof Response || $message instanceof Error) {
145130
$this->handleResponse($message, $session);
146131
} elseif ($message instanceof Notification) {
@@ -151,15 +136,25 @@ public function processInput(string $input, ?Uuid $sessionId): void
151136
$session->save();
152137
}
153138

154-
private function handleInvalidMessage(InvalidInputMessageException $exception, SessionInterface $session): void
139+
/**
140+
* Handle an invalid message from the transport.
141+
*
142+
* @param TransportInterface<mixed> $transport
143+
*/
144+
private function handleInvalidMessage(TransportInterface $transport, InvalidInputMessageException $exception, SessionInterface $session): void
155145
{
156146
$this->logger->warning('Failed to create message.', ['exception' => $exception]);
157147

158148
$error = Error::forInvalidRequest($exception->getMessage());
159-
$this->sendResponse($error, $session);
149+
$this->sendResponse($transport, $error, $session);
160150
}
161151

162-
private function handleRequest(Request $request, SessionInterface $session): void
152+
/**
153+
* Handle a request from the transport.
154+
*
155+
* @param TransportInterface<mixed> $transport
156+
*/
157+
private function handleRequest(TransportInterface $transport, Request $request, SessionInterface $session): void
163158
{
164159
$this->logger->info('Handling request.', ['request' => $request]);
165160

@@ -192,32 +187,32 @@ private function handleRequest(Request $request, SessionInterface $session): voi
192187
}
193188
}
194189

195-
$this->transport->attachFiberToSession($fiber, $session->getId());
190+
$transport->attachFiberToSession($fiber, $session->getId());
196191

197192
return;
198193
} else {
199194
$finalResult = $fiber->getReturn();
200195

201-
$this->sendResponse($finalResult, $session);
196+
$this->sendResponse($transport, $finalResult, $session);
202197
}
203198
} catch (\InvalidArgumentException $e) {
204199
$this->logger->warning(\sprintf('Invalid argument: %s', $e->getMessage()), ['exception' => $e]);
205200

206201
$error = Error::forInvalidParams($e->getMessage(), $request->getId());
207-
$this->sendResponse($error, $session);
202+
$this->sendResponse($transport, $error, $session);
208203
} catch (\Throwable $e) {
209204
$this->logger->error(\sprintf('Uncaught exception: %s', $e->getMessage()), ['exception' => $e]);
210205

211206
$error = Error::forInternalError($e->getMessage(), $request->getId());
212-
$this->sendResponse($error, $session);
207+
$this->sendResponse($transport, $error, $session);
213208
}
214209

215210
break;
216211
}
217212

218213
if (!$handlerFound) {
219214
$error = Error::forMethodNotFound(\sprintf('No handler found for method "%s".', $request::getMethod()), $request->getId());
220-
$this->sendResponse($error, $session);
215+
$this->sendResponse($transport, $error, $session);
221216
}
222217
}
223218

@@ -299,10 +294,11 @@ public function sendNotification(Notification $notification, SessionInterface $s
299294
/**
300295
* Sends a response either immediately or queued for later delivery.
301296
*
297+
* @param TransportInterface<mixed> $transport
302298
* @param Response<ResultInterface|array<string, mixed>>|Error $response
303299
* @param array<string, mixed> $context
304300
*/
305-
private function sendResponse(Response|Error $response, ?SessionInterface $session, array $context = []): void
301+
private function sendResponse(TransportInterface $transport, Response|Error $response, ?SessionInterface $session, array $context = []): void
306302
{
307303
if (null === $session) {
308304
$this->logger->info('Sending immediate response', [
@@ -327,7 +323,7 @@ private function sendResponse(Response|Error $response, ?SessionInterface $sessi
327323
}
328324

329325
$context['type'] = 'response';
330-
$this->transport->send($encoded, $context);
326+
$transport->send($encoded, $context);
331327
} else {
332328
$this->logger->info('Queueing server response', [
333329
'response_id' => $response->getId(),
@@ -519,24 +515,25 @@ private function hasInitializeRequest(array $messages): bool
519515
/**
520516
* Resolves and validates the session based on the request context.
521517
*
522-
* @param Uuid|null $sessionId The session ID from the transport
523-
* @param array<int,mixed> $messages The parsed messages
518+
* @param TransportInterface<mixed> $transport
519+
* @param Uuid|null $sessionId The session ID from the transport
520+
* @param array<int,mixed> $messages The parsed messages
524521
*/
525-
private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInterface
522+
private function resolveSession(TransportInterface $transport, ?Uuid $sessionId, array $messages): ?SessionInterface
526523
{
527524
if ($this->hasInitializeRequest($messages)) {
528525
// Spec: An initialize request must not be part of a batch.
529526
if (\count($messages) > 1) {
530527
$error = Error::forInvalidRequest('The "initialize" request MUST NOT be part of a batch.');
531-
$this->sendResponse($error, null);
528+
$this->sendResponse($transport, $error, null);
532529

533530
return null;
534531
}
535532

536533
// Spec: An initialize request must not have a session ID.
537534
if ($sessionId) {
538535
$error = Error::forInvalidRequest('A session ID MUST NOT be sent with an "initialize" request.');
539-
$this->sendResponse($error, null);
536+
$this->sendResponse($transport, $error, null);
540537

541538
return null;
542539
}
@@ -546,21 +543,21 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte
546543
'session_id' => $session->getId()->toRfc4122(),
547544
]);
548545

549-
$this->transport->setSessionId($session->getId());
546+
$transport->setSessionId($session->getId());
550547

551548
return $session;
552549
}
553550

554551
if (!$sessionId) {
555552
$error = Error::forInvalidRequest('A valid session id is REQUIRED for non-initialize requests.');
556-
$this->sendResponse($error, null, ['status_code' => 400]);
553+
$this->sendResponse($transport, $error, null, ['status_code' => 400]);
557554

558555
return null;
559556
}
560557

561558
if (!$this->sessionStore->exists($sessionId)) {
562559
$error = Error::forInvalidRequest('Session not found or has expired.');
563-
$this->sendResponse($error, null, ['status_code' => 404]);
560+
$this->sendResponse($transport, $error, null, ['status_code' => 404]);
564561

565562
return null;
566563
}

src/Server/Transport/BaseTransport.php

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
* @phpstan-import-type FiberSuspend from TransportInterface
2727
* @phpstan-import-type McpFiber from TransportInterface
2828
*
29+
* @template TResult
30+
* @implements TransportInterface<TResult>
31+
*
2932
* @author Kyrian Obikwelu <[email protected]>
3033
*/
31-
abstract class BaseTransport
34+
abstract class BaseTransport implements TransportInterface
3235
{
3336
use ManagesTransportCallbacks;
3437

@@ -126,7 +129,7 @@ protected function handleFiberYield(mixed $yielded, ?Uuid $sessionId): void
126129
protected function handleMessage(string $payload, ?Uuid $sessionId): void
127130
{
128131
if (\is_callable($this->messageListener)) {
129-
($this->messageListener)($payload, $sessionId);
132+
($this->messageListener)($this, $payload, $sessionId);
130133
}
131134
}
132135

src/Server/Transport/InMemoryTransport.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
use Symfony\Component\Uid\Uuid;
1616

1717
/**
18-
* @implements TransportInterface<null>
18+
* @extends BaseTransport<null>
1919
*
2020
* @author Tobias Nyholm <[email protected]>
2121
*/
22-
class InMemoryTransport extends BaseTransport implements TransportInterface
22+
class InMemoryTransport extends BaseTransport
2323
{
2424
use ManagesTransportCallbacks;
2525

src/Server/Transport/ManagesTransportCallbacks.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
* */
2727
trait ManagesTransportCallbacks
2828
{
29-
/** @var callable(string, ?Uuid): void */
29+
/** @var callable(TransportInterface<mixed>, string, ?Uuid): void */
3030
protected $messageListener;
3131

3232
/** @var callable(Uuid): void */

src/Server/Transport/StdioTransport.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
use Psr\Log\LoggerInterface;
1616

1717
/**
18-
* @implements TransportInterface<int>
18+
* @extends BaseTransport<int>
1919
*
2020
* @author Kyrian Obikwelu <[email protected]>
21-
* */
22-
class StdioTransport extends BaseTransport implements TransportInterface
21+
*/
22+
class StdioTransport extends BaseTransport
2323
{
2424
/**
2525
* @param resource $input

src/Server/Transport/StreamableHttpTransport.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
use Symfony\Component\Uid\Uuid;
2222

2323
/**
24-
* @implements TransportInterface<ResponseInterface>
24+
* @extends BaseTransport<ResponseInterface>
2525
*
2626
* @author Kyrian Obikwelu <[email protected]>
27-
* */
28-
class StreamableHttpTransport extends BaseTransport implements TransportInterface
27+
*/
28+
class StreamableHttpTransport extends BaseTransport
2929
{
3030
private ResponseFactoryInterface $responseFactory;
3131
private StreamFactoryInterface $streamFactory;

src/Server/Transport/TransportInterface.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public function close(): void;
7070
*
7171
* The transport calls this whenever ANY message arrives, regardless of source.
7272
*
73-
* @param callable(string $message, ?Uuid $sessionId): void $listener
73+
* @param callable(TransportInterface<TResult> $transport, string $message, ?Uuid $sessionId): void $listener
7474
*/
7575
public function onMessage(callable $listener): void;
7676

0 commit comments

Comments
 (0)