|
25 | 25 | import java.util.Optional; |
26 | 26 | import java.util.UUID; |
27 | 27 | import java.util.concurrent.Callable; |
28 | | -import java.util.concurrent.CompletableFuture; |
| 28 | +import java.util.concurrent.Executor; |
29 | 29 | import java.util.concurrent.ExecutorService; |
30 | | -import java.util.concurrent.Flow; |
31 | 30 | import java.util.concurrent.Future; |
32 | 31 | import java.util.function.Function; |
33 | | -import java.util.function.Supplier; |
34 | 32 |
|
35 | 33 | import org.eclipse.microprofile.config.ConfigProvider; |
| 34 | +import org.eclipse.microprofile.context.ManagedExecutor; |
36 | 35 | import org.jboss.logging.Logger; |
37 | 36 |
|
38 | 37 | import dev.langchain4j.agent.tool.ReturnBehavior; |
|
107 | 106 | import io.quarkiverse.langchain4j.runtime.types.TypeSignatureParser; |
108 | 107 | import io.quarkiverse.langchain4j.runtime.types.TypeUtil; |
109 | 108 | import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; |
| 109 | +import io.quarkus.arc.Arc; |
| 110 | +import io.quarkus.arc.InstanceHandle; |
| 111 | +import io.quarkus.runtime.BlockingOperationControl; |
110 | 112 | import io.smallrye.mutiny.Multi; |
111 | 113 | import io.smallrye.mutiny.infrastructure.Infrastructure; |
112 | 114 | import io.vertx.core.Context; |
@@ -184,6 +186,19 @@ public Object implement(Input input) { |
184 | 186 |
|
// Entry point that decides on which thread the actual work in doImplement0 is performed.
// When the AI service method returns a Multi (per TypeUtil.isMulti) and BlockingOperationControl
// reports that blocking is not allowed, the call to doImplement0 is wrapped in a deferred Multi
// whose subscription runs on the executor returned by createExecutor(); doImplement0 is then
// only invoked at subscription time. In every other case doImplement0 is called directly and
// its result is returned as-is.
185 | 187 | private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, InvocationContext invocationContext, |
186 | 188 | QuarkusAiServiceContext context) { |
| 189 | + if (TypeUtil.isMulti(methodCreateInfo.getReturnType()) && !BlockingOperationControl.isBlockingAllowed()) { |
| 190 | + // this is a special case where we can't block, so we need to delegate the work to a worker pool, |
| 191 | + // as so many of the things done in LangChain4j are blocking |
// NOTE(review): the cast below assumes doImplement0 returns a Multi whenever the declared
// return type is a Multi - confirm against doImplement0.
| 192 | + return Multi.createFrom().deferred( |
| 193 | + () -> ((Multi<?>) doImplement0(methodCreateInfo, invocationContext, context))) |
| 194 | + .runSubscriptionOn(createExecutor()); |
| 195 | + } else { |
| 196 | + return doImplement0(methodCreateInfo, invocationContext, context); |
| 197 | + } |
| 198 | + } |
| 199 | + |
| 200 | + private static Object doImplement0(AiServiceMethodCreateInfo methodCreateInfo, InvocationContext invocationContext, |
| 201 | + QuarkusAiServiceContext context) { |
187 | 202 | boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread(); |
188 | 203 | Object[] methodArgs = invocationContext.methodArguments().toArray(Object[]::new); |
189 | 204 | Object memoryId = invocationContext.chatMemoryId(); |
@@ -249,74 +264,8 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, In |
249 | 264 | .build(); |
250 | 265 | AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); |
251 | 266 |
|
252 | | - if (!isMulti) { |
253 | | - augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); |
254 | | - userMessage = (UserMessage) augmentationResult.chatMessage(); |
255 | | - } else { |
256 | | - // TODO duplicated context propagation. |
257 | | - // this a special case where we can't block, so we need to delegate the |
258 | | - // retrieval augmentation to a worker pool |
259 | | - CompletableFuture<AugmentationResult> augmentationResultCF = CompletableFuture |
260 | | - .supplyAsync(new Supplier<>() { |
261 | | - @Override |
262 | | - public AugmentationResult get() { |
263 | | - return context.retrievalAugmentor.augment(augmentationRequest); |
264 | | - } |
265 | | - }, Infrastructure.getDefaultWorkerPool()); |
266 | | - |
267 | | - return Multi.createFrom().completionStage(augmentationResultCF).flatMap( |
268 | | - new Function<>() { |
269 | | - @Override |
270 | | - public Flow.Publisher<?> apply(AugmentationResult ar) { |
271 | | - ChatMessage augmentedUserMessage = ar.chatMessage(); |
272 | | - |
273 | | - ChatMemory memory = context.chatMemoryService.getChatMemory(memoryId); |
274 | | - var guardrailRequestParams = GuardrailRequestParams.builder() |
275 | | - .chatMemory(memory) |
276 | | - .augmentationResult(ar) |
277 | | - .userMessageTemplate(methodCreateInfo.getUserMessageTemplate()) |
278 | | - .variables(templateVariables) |
279 | | - .invocationContext(invocationContext) |
280 | | - .aiServiceListenerRegistrar(context.eventListenerRegistrar) |
281 | | - .build(); |
282 | | - |
283 | | - UserMessage guardrailsMessage = GuardrailsSupport.executeInputGuardrails( |
284 | | - context.guardrailService(), |
285 | | - (UserMessage) augmentedUserMessage, |
286 | | - methodCreateInfo, guardrailRequestParams); |
287 | | - List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed); |
288 | | - var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, |
289 | | - finalToolExecutors, ar.contents(), context, invocationContext, memoryId, |
290 | | - methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), |
291 | | - isRunningOnWorkerThread, methodCreateInfo, methodArgs); |
292 | | - |
293 | | - return stream |
294 | | - .filter(event -> !isStringMulti |
295 | | - || event instanceof ChatEvent.PartialResponseEvent) |
296 | | - .map(event -> { |
297 | | - if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) { |
298 | | - return ((ChatEvent.PartialResponseEvent) event).getChunk(); |
299 | | - } |
300 | | - return event; |
301 | | - }) |
302 | | - .plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, |
303 | | - new ResponseAugmenterParams((UserMessage) augmentedUserMessage, memory, |
304 | | - ar, |
305 | | - methodCreateInfo.getUserMessageTemplate(), templateVariables))); |
306 | | - } |
307 | | - |
308 | | - private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage, |
309 | | - boolean needsMemorySeed) { |
310 | | - return context.hasChatMemory() |
311 | | - ? createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage, |
312 | | - context.chatMemoryService.getChatMemory(memoryId), needsMemorySeed, |
313 | | - context, |
314 | | - methodCreateInfo) |
315 | | - : createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage, |
316 | | - needsMemorySeed, context, methodCreateInfo); |
317 | | - } |
318 | | - }); |
319 | | - } |
| 267 | + augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); |
| 268 | + userMessage = (UserMessage) augmentationResult.chatMessage(); |
320 | 269 | } |
321 | 270 |
|
322 | 271 | var guardrailService = context.guardrailService(); |
@@ -1173,6 +1122,11 @@ private static int getMaxSequentialToolExecutions() { |
1173 | 1122 | DEFAULT_MAX_SEQUENTIAL_TOOL_EXECUTIONS); |
1174 | 1123 | } |
1175 | 1124 |
|
// Resolves the executor used to run Multi subscriptions off the calling thread.
// Looks up a ManagedExecutor through the Arc container and returns it when the instance handle
// reports it as available; otherwise falls back to Infrastructure.getDefaultExecutor().
// The lookup is performed on every call; nothing is cached here.
| 1125 | + private static Executor createExecutor() { |
| 1126 | + InstanceHandle<ManagedExecutor> executor = Arc.container().instance(ManagedExecutor.class); |
| 1127 | + return executor.isAvailable() ? executor.get() : Infrastructure.getDefaultExecutor(); |
| 1128 | + } |
| 1129 | + |
1176 | 1130 | public static class Input { |
1177 | 1131 | final QuarkusAiServiceContext context; |
1178 | 1132 | final AiServiceMethodCreateInfo createInfo; |
|
0 commit comments