Skip to content

Commit 248c7c5

Browse files
authored
fix: Support non-blocking sendMessage (#327)
This is a backport of the following PRs from the a2a-python repo: 349, 294, 449, 440 and 472.
These include:
- Non-blocking sendMessage should invoke push notification
- Not possible to cancel Task in final state
- Handle concurrent task completion during cancellation
- Persist task state after client disconnect

Also add DefaultRequestHandlerTest, and backport the parts that make sense.
1 parent b976bcf commit 248c7c5

File tree

8 files changed

+188
-28
lines changed

8 files changed

+188
-28
lines changed
1.41 MB
Binary file not shown.

server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java

Lines changed: 104 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import java.util.concurrent.ConcurrentMap;
1515
import java.util.concurrent.Executor;
1616
import java.util.concurrent.Flow;
17+
import java.util.concurrent.atomic.AtomicBoolean;
1718
import java.util.concurrent.atomic.AtomicReference;
1819
import java.util.function.Supplier;
1920

@@ -48,6 +49,8 @@
4849
import io.a2a.spec.StreamingEventKind;
4950
import io.a2a.spec.Task;
5051
import io.a2a.spec.TaskIdParams;
52+
import io.a2a.spec.TaskNotCancelableError;
53+
import io.a2a.spec.TaskState;
5154
import io.a2a.spec.TaskNotFoundError;
5255
import io.a2a.spec.TaskPushNotificationConfig;
5356
import io.a2a.spec.TaskQueryParams;
@@ -122,6 +125,13 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws
122125
if (task == null) {
123126
throw new TaskNotFoundError();
124127
}
128+
129+
// Check if task is in a non-cancelable state (completed, canceled, failed, rejected)
130+
if (task.getStatus().state().isFinal()) {
131+
throw new TaskNotCancelableError(
132+
"Task cannot be canceled - current state: " + task.getStatus().state().asString());
133+
}
134+
125135
TaskManager taskManager = new TaskManager(
126136
task.getId(),
127137
task.getContextId(),
@@ -148,11 +158,17 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws
148158

149159
EventConsumer consumer = new EventConsumer(queue);
150160
EventKind type = resultAggregator.consumeAll(consumer);
151-
if (type instanceof Task tempTask) {
152-
return tempTask;
161+
if (!(type instanceof Task tempTask)) {
162+
throw new InternalError("Agent did not return valid response for cancel");
153163
}
154164

155-
throw new InternalError("Agent did not return a valid response");
165+
// Verify task was actually canceled (not completed concurrently)
166+
if (tempTask.getStatus().state() != TaskState.CANCELED) {
167+
throw new TaskNotCancelableError(
168+
"Task cannot be canceled - current state: " + tempTask.getStatus().state().asString());
169+
}
170+
171+
return tempTask;
156172
}
157173

158174
@Override
@@ -166,32 +182,42 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte
166182
EventQueue queue = queueManager.createOrTap(taskId);
167183
ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null);
168184

169-
boolean interrupted = false;
185+
boolean blocking = true; // Default to blocking behavior
186+
if (params.configuration() != null && Boolean.FALSE.equals(params.configuration().blocking())) {
187+
blocking = false;
188+
}
189+
190+
boolean interruptedOrNonBlocking = false;
170191

171192
EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId, mss.requestContext, queue);
172193
ResultAggregator.EventTypeAndInterrupt etai = null;
173194
try {
195+
// Create callback for push notifications during background event processing
196+
Runnable pushNotificationCallback = () -> sendPushNotification(taskId, resultAggregator);
197+
174198
EventConsumer consumer = new EventConsumer(queue);
175199

176200
// This callback must be added before we start consuming. Otherwise,
177201
// any errors thrown by the producerRunnable are not picked up by the consumer
178202
producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback());
179-
etai = resultAggregator.consumeAndBreakOnInterrupt(consumer);
203+
etai = resultAggregator.consumeAndBreakOnInterrupt(consumer, blocking, pushNotificationCallback);
180204

181205
if (etai == null) {
182206
LOGGER.debug("No result, throwing InternalError");
183207
throw new InternalError("No result");
184208
}
185-
interrupted = etai.interrupted();
186-
LOGGER.debug("Was interrupted: {}", interrupted);
209+
interruptedOrNonBlocking = etai.interrupted();
210+
LOGGER.debug("Was interrupted or non-blocking: {}", interruptedOrNonBlocking);
187211

188212
EventKind kind = etai.eventType();
189213
if (kind instanceof Task taskResult && !taskId.equals(taskResult.getId())) {
190214
throw new InternalError("Task ID mismatch in agent response");
191215
}
192216

217+
// Send push notification after initial return (for both blocking and non-blocking)
218+
pushNotificationCallback.run();
193219
} finally {
194-
if (interrupted) {
220+
if (interruptedOrNonBlocking) {
195221
CompletableFuture<Void> cleanupTask = CompletableFuture.runAsync(() -> cleanupProducer(taskId), executor);
196222
trackBackgroundTask(cleanupTask);
197223
} else {
@@ -214,13 +240,15 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(
214240
ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null);
215241

216242
EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId.get(), mss.requestContext, queue);
243+
244+
// Move consumer creation and callback registration outside try block
245+
// so consumer is available for background consumption on client disconnect
217246
EventConsumer consumer = new EventConsumer(queue);
247+
producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback());
218248

219-
try {
249+
AtomicBoolean backgroundConsumeStarted = new AtomicBoolean(false);
220250

221-
// This callback must be added before we start consuming. Otherwise,
222-
// any errors thrown by the producerRunnable are not picked up by the consumer
223-
producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback());
251+
try {
224252
Flow.Publisher<Event> results = resultAggregator.consumeAndEmit(consumer);
225253

226254
Flow.Publisher<Event> eventPublisher =
@@ -258,7 +286,61 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(
258286
return true;
259287
}));
260288

261-
return convertingProcessor(eventPublisher, event -> (StreamingEventKind) event);
289+
Flow.Publisher<StreamingEventKind> finalPublisher = convertingProcessor(eventPublisher, event -> (StreamingEventKind) event);
290+
291+
// Wrap publisher to detect client disconnect and continue background consumption
292+
return subscriber -> finalPublisher.subscribe(new Flow.Subscriber<StreamingEventKind>() {
293+
private Flow.Subscription subscription;
294+
295+
@Override
296+
public void onSubscribe(Flow.Subscription subscription) {
297+
this.subscription = subscription;
298+
// Wrap subscription to detect cancellation
299+
subscriber.onSubscribe(new Flow.Subscription() {
300+
@Override
301+
public void request(long n) {
302+
subscription.request(n);
303+
}
304+
305+
@Override
306+
public void cancel() {
307+
LOGGER.debug("Client cancelled subscription for task {}, starting background consumption", taskId.get());
308+
startBackgroundConsumption();
309+
subscription.cancel();
310+
}
311+
});
312+
}
313+
314+
@Override
315+
public void onNext(StreamingEventKind item) {
316+
subscriber.onNext(item);
317+
}
318+
319+
@Override
320+
public void onError(Throwable throwable) {
321+
subscriber.onError(throwable);
322+
}
323+
324+
@Override
325+
public void onComplete() {
326+
subscriber.onComplete();
327+
}
328+
329+
private void startBackgroundConsumption() {
330+
if (backgroundConsumeStarted.compareAndSet(false, true)) {
331+
// Client disconnected: continue consuming and persisting events in background
332+
CompletableFuture<Void> bgTask = CompletableFuture.runAsync(() -> {
333+
try {
334+
resultAggregator.consumeAll(consumer);
335+
LOGGER.debug("Background consumption completed for task {}", taskId.get());
336+
} catch (Exception e) {
337+
LOGGER.error("Error during background consumption for task {}", taskId.get(), e);
338+
}
339+
}, executor);
340+
trackBackgroundTask(bgTask);
341+
}
342+
}
343+
});
262344
} finally {
263345
CompletableFuture<Void> cleanupTask = CompletableFuture.runAsync(() -> cleanupProducer(taskId.get()), executor);
264346
trackBackgroundTask(cleanupTask);
@@ -454,5 +536,14 @@ private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallCon
454536
return new MessageSendSetup(taskManager, task, requestContext);
455537
}
456538

539+
private void sendPushNotification(String taskId, ResultAggregator resultAggregator) {
540+
if (pushSender != null && taskId != null) {
541+
EventKind latest = resultAggregator.getCurrentResult();
542+
if (latest instanceof Task latestTask) {
543+
pushSender.sendNotification(latestTask);
544+
}
545+
}
546+
}
547+
457548
private record MessageSendSetup(TaskManager taskManager, Task task, RequestContext requestContext) {}
458549
}

server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public Flow.Publisher<Event> consumeAndEmit(EventConsumer consumer) {
4545
}));
4646
}
4747

48-
public EventKind consumeAll(EventConsumer consumer) {
48+
public EventKind consumeAll(EventConsumer consumer) throws JSONRPCError {
4949
AtomicReference<EventKind> returnedEvent = new AtomicReference<>();
5050
Flow.Publisher<Event> all = consumer.consumeAll();
5151
AtomicReference<Throwable> error = new AtomicReference<>();
@@ -65,13 +65,22 @@ public EventKind consumeAll(EventConsumer consumer) {
6565
},
6666
error::set);
6767

68+
Throwable err = error.get();
69+
if (err != null) {
70+
Utils.rethrow(err);
71+
}
72+
6873
if (returnedEvent.get() != null) {
6974
return returnedEvent.get();
7075
}
7176
return taskManager.getTask();
7277
}
7378

74-
public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer) throws JSONRPCError {
79+
public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, boolean blocking) throws JSONRPCError {
80+
return consumeAndBreakOnInterrupt(consumer, blocking, null);
81+
}
82+
83+
public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, boolean blocking, Runnable eventCallback) throws JSONRPCError {
7584
Flow.Publisher<Event> all = consumer.consumeAll();
7685
AtomicReference<Message> message = new AtomicReference<>();
7786
AtomicBoolean interrupted = new AtomicBoolean(false);
@@ -92,16 +101,28 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer)
92101

93102
callTaskManagerProcess(event);
94103

95-
if ((event instanceof Task task && task.getStatus().state() == TaskState.AUTH_REQUIRED)
96-
|| (event instanceof TaskStatusUpdateEvent tsue && tsue.getStatus().state() == TaskState.AUTH_REQUIRED)) {
104+
boolean shouldInterrupt = false;
105+
boolean isAuthRequired = (event instanceof Task task && task.getStatus().state() == TaskState.AUTH_REQUIRED)
106+
|| (event instanceof TaskStatusUpdateEvent tsue && tsue.getStatus().state() == TaskState.AUTH_REQUIRED);
107+
108+
// Always interrupt on auth_required, as it needs external action.
109+
if (isAuthRequired) {
97110
// auth-required is a special state: the message should be
98111
// escalated back to the caller, but the agent is expected to
99112
// continue producing events once the authorization is received
100113
// out-of-band. This is in contrast to input-required, where a
101114
// new request is expected in order for the agent to make progress,
102115
// so the agent should exit.
116+
shouldInterrupt = true;
117+
}
118+
// For non-blocking calls, interrupt as soon as a task is available.
119+
else if (!blocking) {
120+
shouldInterrupt = true;
121+
}
103122

104-
CompletableFuture.runAsync(() -> continueConsuming(all));
123+
if (shouldInterrupt) {
124+
// Continue consuming the rest of the events in the background.
125+
CompletableFuture.runAsync(() -> continueConsuming(all, eventCallback));
105126
interrupted.set(true);
106127
return false;
107128
}
@@ -118,11 +139,14 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer)
118139
message.get() != null ? message.get() : taskManager.getTask(), interrupted.get());
119140
}
120141

121-
private void continueConsuming(Flow.Publisher<Event> all) {
142+
private void continueConsuming(Flow.Publisher<Event> all, Runnable eventCallback) {
122143
consumer(createTubeConfig(),
123144
all,
124145
event -> {
125146
callTaskManagerProcess(event);
147+
if (eventCallback != null) {
148+
eventCallback.run();
149+
}
126150
return true;
127151
},
128152
t -> {});
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@
2828
import org.junit.jupiter.api.Timeout;
2929

3030
/**
31-
* Tests for DefaultRequestHandler background cleanup and task tracking functionality,
32-
* backported from Python PR #440 and #472.
31+
* Comprehensive tests for DefaultRequestHandler, backported from Python's
32+
* test_default_request_handler.py. These tests cover core functionality that
33+
* is transport-agnostic and should work across JSON-RPC, gRPC, and REST.
34+
*
35+
* Background cleanup and task tracking tests are from Python PR #440 and #472.
3336
*/
34-
public class DefaultRequestHandlerBackgroundTest {
37+
public class DefaultRequestHandlerTest {
3538

3639
private DefaultRequestHandler requestHandler;
3740
private InMemoryTaskStore taskStore;

server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static org.junit.jupiter.api.Assertions.assertEquals;
44
import static org.junit.jupiter.api.Assertions.assertNotNull;
5+
import static org.junit.jupiter.api.Assertions.assertTrue;
56
import static org.mockito.Mockito.reset;
67
import static org.mockito.Mockito.times;
78
import static org.mockito.Mockito.verify;
@@ -10,6 +11,9 @@
1011

1112
import java.util.Collections;
1213

14+
import io.a2a.server.events.EventConsumer;
15+
import io.a2a.server.events.EventQueue;
16+
import io.a2a.server.events.InMemoryQueueManager;
1317
import io.a2a.spec.EventKind;
1418
import io.a2a.spec.Message;
1519
import io.a2a.spec.Task;
@@ -183,4 +187,35 @@ void testGetCurrentResultWithMessageTakesPrecedence() {
183187
// Task manager should not be called when message is present
184188
verifyNoInteractions(mockTaskManager);
185189
}
190+
191+
@Test
192+
void testConsumeAndBreakNonBlocking() throws Exception {
193+
// Test that with blocking=false, the method returns after the first event
194+
Task firstEvent = createSampleTask("non_blocking_task", TaskState.WORKING, "ctx1");
195+
196+
// After processing firstEvent, the current result will be that task
197+
when(mockTaskManager.getTask()).thenReturn(firstEvent);
198+
199+
// Create an event queue using QueueManager (which has access to builder)
200+
InMemoryQueueManager queueManager =
201+
new InMemoryQueueManager();
202+
203+
EventQueue queue = queueManager.getEventQueueBuilder("test-task").build();
204+
queue.enqueueEvent(firstEvent);
205+
206+
// Create real EventConsumer with the queue
207+
EventConsumer eventConsumer =
208+
new EventConsumer(queue);
209+
210+
// Close queue after first event to simulate stream ending after processing
211+
queue.close();
212+
213+
ResultAggregator.EventTypeAndInterrupt result =
214+
aggregator.consumeAndBreakOnInterrupt(eventConsumer, false);
215+
216+
assertEquals(firstEvent, result.eventType());
217+
assertTrue(result.interrupted());
218+
verify(mockTaskManager).process(firstEvent);
219+
verify(mockTaskManager).getTask();
220+
}
186221
}

spec/src/main/java/io/a2a/spec/MessageSendConfiguration.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
@JsonInclude(JsonInclude.Include.NON_ABSENT)
1212
@JsonIgnoreProperties(ignoreUnknown = true)
1313
public record MessageSendConfiguration(List<String> acceptedOutputModes, Integer historyLength,
14-
PushNotificationConfig pushNotificationConfig, boolean blocking) {
14+
PushNotificationConfig pushNotificationConfig, Boolean blocking) {
1515

1616
public MessageSendConfiguration {
1717
if (historyLength != null && historyLength < 0) {
@@ -23,7 +23,7 @@ public static class Builder {
2323
List<String> acceptedOutputModes;
2424
Integer historyLength;
2525
PushNotificationConfig pushNotificationConfig;
26-
boolean blocking;
26+
Boolean blocking = true;
2727

2828
public Builder acceptedOutputModes(List<String> acceptedOutputModes) {
2929
this.acceptedOutputModes = acceptedOutputModes;
@@ -40,7 +40,7 @@ public Builder historyLength(Integer historyLength) {
4040
return this;
4141
}
4242

43-
public Builder blocking(boolean blocking) {
43+
public Builder blocking(Boolean blocking) {
4444
this.blocking = blocking;
4545
return this;
4646
}

spec/src/main/java/io/a2a/spec/TaskNotCancelableError.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,8 @@ public TaskNotCancelableError(
3131
data);
3232
}
3333

34+
public TaskNotCancelableError(@JsonProperty("message") String message) {
35+
this(null, message, null);
36+
}
37+
3438
}

0 commit comments

Comments
 (0)