Skip to content

Commit 977dfee

Browse files
committed
Ensure that when returning a Multi, everything is run on a blocking thread
Closes: #2010
1 parent f7c8e31 commit 977dfee

File tree

5 files changed

+62
-73
lines changed

5 files changed

+62
-73
lines changed

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ void testNonBlockingToolInvocationFromEventLoop() {
155155
});
156156

157157
Awaitility.await().until(() -> result.get() != null);
158-
assertThat(result.get()).contains(value.get(), caller.get());
158+
assertThat(result.get()).contains(value.get());
159159
}
160160

161161
@Test

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ void testNonBlockingToolInvocationFromEventLoop() {
148148
});
149149

150150
Awaitility.await().until(() -> result.get() != null);
151-
assertThat(result.get()).contains(uuid, caller.get());
151+
assertThat(result.get()).contains(uuid);
152152
}
153153

154154
@Test

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 25 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,13 @@
2525
import java.util.Optional;
2626
import java.util.UUID;
2727
import java.util.concurrent.Callable;
28-
import java.util.concurrent.CompletableFuture;
28+
import java.util.concurrent.Executor;
2929
import java.util.concurrent.ExecutorService;
30-
import java.util.concurrent.Flow;
3130
import java.util.concurrent.Future;
3231
import java.util.function.Function;
33-
import java.util.function.Supplier;
3432

3533
import org.eclipse.microprofile.config.ConfigProvider;
34+
import org.eclipse.microprofile.context.ManagedExecutor;
3635
import org.jboss.logging.Logger;
3736

3837
import dev.langchain4j.agent.tool.ReturnBehavior;
@@ -107,6 +106,9 @@
107106
import io.quarkiverse.langchain4j.runtime.types.TypeSignatureParser;
108107
import io.quarkiverse.langchain4j.runtime.types.TypeUtil;
109108
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
109+
import io.quarkus.arc.Arc;
110+
import io.quarkus.arc.InstanceHandle;
111+
import io.quarkus.runtime.BlockingOperationControl;
110112
import io.smallrye.mutiny.Multi;
111113
import io.smallrye.mutiny.infrastructure.Infrastructure;
112114
import io.vertx.core.Context;
@@ -184,6 +186,19 @@ public Object implement(Input input) {
184186

185187
private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, InvocationContext invocationContext,
186188
QuarkusAiServiceContext context) {
189+
if (TypeUtil.isMulti(methodCreateInfo.getReturnType()) && !BlockingOperationControl.isBlockingAllowed()) {
190+
// this is a special case where we can't block, so we need to delegate the work to a worker pool
191+
// as so many of the things done in LangChain4j are blocking
192+
return Multi.createFrom().deferred(
193+
() -> ((Multi<?>) doImplement0(methodCreateInfo, invocationContext, context)))
194+
.runSubscriptionOn(createExecutor());
195+
} else {
196+
return doImplement0(methodCreateInfo, invocationContext, context);
197+
}
198+
}
199+
200+
private static Object doImplement0(AiServiceMethodCreateInfo methodCreateInfo, InvocationContext invocationContext,
201+
QuarkusAiServiceContext context) {
187202
boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread();
188203
Object[] methodArgs = invocationContext.methodArguments().toArray(Object[]::new);
189204
Object memoryId = invocationContext.chatMemoryId();
@@ -249,74 +264,8 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, In
249264
.build();
250265
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
251266

252-
if (!isMulti) {
253-
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
254-
userMessage = (UserMessage) augmentationResult.chatMessage();
255-
} else {
256-
// TODO duplicated context propagation.
257-
// this a special case where we can't block, so we need to delegate the
258-
// retrieval augmentation to a worker pool
259-
CompletableFuture<AugmentationResult> augmentationResultCF = CompletableFuture
260-
.supplyAsync(new Supplier<>() {
261-
@Override
262-
public AugmentationResult get() {
263-
return context.retrievalAugmentor.augment(augmentationRequest);
264-
}
265-
}, Infrastructure.getDefaultWorkerPool());
266-
267-
return Multi.createFrom().completionStage(augmentationResultCF).flatMap(
268-
new Function<>() {
269-
@Override
270-
public Flow.Publisher<?> apply(AugmentationResult ar) {
271-
ChatMessage augmentedUserMessage = ar.chatMessage();
272-
273-
ChatMemory memory = context.chatMemoryService.getChatMemory(memoryId);
274-
var guardrailRequestParams = GuardrailRequestParams.builder()
275-
.chatMemory(memory)
276-
.augmentationResult(ar)
277-
.userMessageTemplate(methodCreateInfo.getUserMessageTemplate())
278-
.variables(templateVariables)
279-
.invocationContext(invocationContext)
280-
.aiServiceListenerRegistrar(context.eventListenerRegistrar)
281-
.build();
282-
283-
UserMessage guardrailsMessage = GuardrailsSupport.executeInputGuardrails(
284-
context.guardrailService(),
285-
(UserMessage) augmentedUserMessage,
286-
methodCreateInfo, guardrailRequestParams);
287-
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
288-
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
289-
finalToolExecutors, ar.contents(), context, invocationContext, memoryId,
290-
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(),
291-
isRunningOnWorkerThread, methodCreateInfo, methodArgs);
292-
293-
return stream
294-
.filter(event -> !isStringMulti
295-
|| event instanceof ChatEvent.PartialResponseEvent)
296-
.map(event -> {
297-
if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) {
298-
return ((ChatEvent.PartialResponseEvent) event).getChunk();
299-
}
300-
return event;
301-
})
302-
.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
303-
new ResponseAugmenterParams((UserMessage) augmentedUserMessage, memory,
304-
ar,
305-
methodCreateInfo.getUserMessageTemplate(), templateVariables)));
306-
}
307-
308-
private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
309-
boolean needsMemorySeed) {
310-
return context.hasChatMemory()
311-
? createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage,
312-
context.chatMemoryService.getChatMemory(memoryId), needsMemorySeed,
313-
context,
314-
methodCreateInfo)
315-
: createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage,
316-
needsMemorySeed, context, methodCreateInfo);
317-
}
318-
});
319-
}
267+
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
268+
userMessage = (UserMessage) augmentationResult.chatMessage();
320269
}
321270

322271
var guardrailService = context.guardrailService();
@@ -1173,6 +1122,11 @@ private static int getMaxSequentialToolExecutions() {
11731122
DEFAULT_MAX_SEQUENTIAL_TOOL_EXECUTIONS);
11741123
}
11751124

1125+
private static Executor createExecutor() {
1126+
InstanceHandle<ManagedExecutor> executor = Arc.container().instance(ManagedExecutor.class);
1127+
return executor.isAvailable() ? executor.get() : Infrastructure.getDefaultExecutor();
1128+
}
1129+
11761130
public static class Input {
11771131
final QuarkusAiServiceContext context;
11781132
final AiServiceMethodCreateInfo createInfo;

samples/chatbot/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@
129129
<artifactId>quarkus-langchain4j-redis</artifactId>
130130
<version>${quarkus-langchain4j.version}</version>
131131
</dependency>
132+
<dependency>
133+
<groupId>io.quarkiverse.langchain4j</groupId>
134+
<artifactId>quarkus-langchain4j-memory-store-redis</artifactId>
135+
<version>${quarkus-langchain4j.version}</version>
136+
</dependency>
132137
</dependencies>
133138
</profile>
134139
<profile>
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package io.quarkiverse.langchain4j.sample.chatbot;
2+
3+
import io.quarkiverse.langchain4j.RegisterAiService;
4+
import io.smallrye.mutiny.Multi;
5+
import jakarta.ws.rs.DefaultValue;
6+
import jakarta.ws.rs.GET;
7+
import jakarta.ws.rs.Path;
8+
import org.jboss.resteasy.reactive.RestQuery;
9+
10+
@Path("assistant")
11+
public class AssistantResource {
12+
13+
private final Assistant assistant;
14+
15+
public AssistantResource(Assistant assistant) {
16+
this.assistant = assistant;
17+
}
18+
19+
@GET
20+
public Multi<String> get(
21+
@DefaultValue("Write a short 1 paragraph funny poem about javascript frameworks") @RestQuery String message) {
22+
return assistant.chat(message);
23+
}
24+
25+
@RegisterAiService
26+
interface Assistant {
27+
28+
Multi<String> chat(String message);
29+
}
30+
}

0 commit comments

Comments
 (0)