diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml index 560671613..c382cc0fe 100644 --- a/core/deployment/pom.xml +++ b/core/deployment/pom.xml @@ -121,6 +121,14 @@ quarkus-test-vertx test + + dev.langchain4j + langchain4j-core + ${langchain4j.version} + tests + test-jar + test + diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 44242d73d..ee6def554 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -45,6 +45,7 @@ import jakarta.enterprise.inject.spi.DeploymentException; import jakarta.enterprise.util.AnnotationLiteral; import jakarta.inject.Inject; +import jakarta.inject.Singleton; import jakarta.interceptor.InterceptorBinding; import org.eclipse.microprofile.config.ConfigProvider; @@ -77,6 +78,10 @@ import dev.langchain4j.service.memory.ChatMemoryAccess; import dev.langchain4j.service.output.JsonSchemas; import dev.langchain4j.service.output.ServiceOutputParser; +import dev.langchain4j.service.tool.ToolArgumentsErrorHandler; +import dev.langchain4j.service.tool.ToolErrorContext; +import dev.langchain4j.service.tool.ToolErrorHandlerResult; +import dev.langchain4j.service.tool.ToolExecutionErrorHandler; import dev.langchain4j.spi.classloading.ClassInstanceFactory; import dev.langchain4j.spi.classloading.ClassMetadataProviderFactory; import io.quarkiverse.langchain4j.ModelName; @@ -341,6 +346,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, BuildProducer declarativeAiServiceProducer, BuildProducer toolProviderProducer, BuildProducer reflectiveClassProducer, + BuildProducer generatedBeanProducer, BuildProducer generatedClassProducer) { IndexView index = indexBuildItem.getIndex(); @@ -496,6 +502,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, toolHallucinationStrategy(instance), classInputGuardrails(declarativeAiServiceClassInfo, index), classOutputGuardrails(declarativeAiServiceClassInfo, index), + toolArgumentsErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer), + toolExecutionErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer), maxSequentialToolInvocations, allowContinuousForcedToolCalling, // we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence @@ -647,6 +655,97 @@ private static OutputGuardrailsLiteral classOutputGuardrails(DeclarativeAiServic declarativeAiServiceBuildItem.getOutputGuardrails().maxRetries()); } + private DotName toolArgumentsErrorHandlerDotName(ClassInfo aiServiceClassInfo, + BuildProducer generatedBean) { + return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_ARGUMENT_ERROR, generatedBean, + ToolArgumentsErrorHandler.class); + } + + private DotName toolExecutionErrorHandlerDotName(ClassInfo aiServiceClassInfo, + BuildProducer generatedBean) { + return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_EXECUTION_ERROR, generatedBean, + ToolExecutionErrorHandler.class); + } + + private DotName toolErrorHandlerDotName(ClassInfo aiServiceClassInfo, DotName annotationName, + BuildProducer generatedBean, + Class interfaceType) { + List instances = aiServiceClassInfo.annotations(annotationName); + if (instances.isEmpty()) { + return null; + } + if (instances.size() > 1) { + throw new IllegalConfigurationException( + "`@%s` can be used only once in an AI Service. Offending class is '%s'".formatted(annotationName, + aiServiceClassInfo.name())); + } + AnnotationTarget target = instances.get(0).target(); + if (target.kind() != AnnotationTarget.Kind.METHOD) { + throw new IllegalConfigurationException( + "`@%s` can be used only methods. Offending class is '%s'".formatted(annotationName, + aiServiceClassInfo.name())); + } + MethodInfo targetMethod = target.asMethod(); + if (!Modifier.isStatic(targetMethod.flags())) { + throw new IllegalConfigurationException( + "`@%s` can be used only on static methods. Offending class is '%s'".formatted(annotationName, + aiServiceClassInfo.name())); + } + DotName returnType = targetMethod.returnType().name(); + if ((!returnType.equals(DotNames.STRING)) && !returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) { + throw new IllegalConfigurationException( + "`@%s` can be used only on static methods that return '%s' or '%s'. Offending class is '%s'" + .formatted(annotationName, + DotNames.STRING, LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT, aiServiceClassInfo.name())); + } + + ClassOutput output = new GeneratedBeanGizmoAdaptor(generatedBean); + String generatedClassName = aiServiceClassInfo.name().toString() + "$" + annotationName.withoutPackagePrefix(); + ClassCreator.Builder classCreatorBuilder = ClassCreator.builder() + .classOutput(output) + .interfaces(interfaceType) + .className(generatedClassName); + try (ClassCreator classCreator = classCreatorBuilder.build()) { + classCreator.addAnnotation(Singleton.class); + + MethodCreator handleMethod = classCreator.getMethodCreator(MethodDescriptor.ofMethod(generatedClassName, "handle", + ToolErrorHandlerResult.class, Throwable.class, ToolErrorContext.class)); + + List paramHandles = new ArrayList<>(); + for (MethodParameterInfo parameter : targetMethod.parameters()) { + DotName paramTypeDotName = parameter.type().name(); + if (paramTypeDotName.equals(DotNames.THROWABLE) || paramTypeDotName.equals(DotNames.EXCEPTION)) { + paramHandles.add(handleMethod.getMethodParam(0)); + } else if (paramTypeDotName.equals(LangChain4jDotNames.TOOL_ERROR_CONTEXT)) { + paramHandles.add(handleMethod.getMethodParam(1)); + } else { + throw new IllegalConfigurationException( + "`@%s` can be used only on static methods that use the parameters of type '%s' or '%s'. Offending class is '%s'" + .formatted(annotationName, + DotNames.THROWABLE, LangChain4jDotNames.TOOL_ERROR_CONTEXT, + aiServiceClassInfo.name())); + } + } + + ResultHandle result = handleMethod.invokeStaticInterfaceMethod(targetMethod, + paramHandles.toArray(new ResultHandle[0])); + if (returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) { + handleMethod.returnValue(result); + } else if (returnType.equals(DotNames.STRING)) { + ResultHandle toolErrorHandlerResultResult = handleMethod.invokeStaticMethod( + MethodDescriptor.ofMethod(ToolErrorHandlerResult.class, "text", ToolErrorHandlerResult.class, + String.class), + result); + handleMethod.returnValue(toolErrorHandlerResultResult); + } else { + throw new IllegalStateException("Unhandled result type: " + returnType); + } + + } + + return DotName.createSimple(generatedClassName); + } + private static List tools(AnnotationInstance instance, IndexView index) { AnnotationValue toolsInstance = instance.value("tools"); if (toolsInstance != null) { @@ -792,6 +891,14 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, allToolHallucinationStrategies.add(bi.getToolHallucinationStrategyClassDotName()); } + String toolArgumentsErrorHandlerDotName = (bi.getToolArgumentsErrorHandlerDotName() != null + ? bi.getToolArgumentsErrorHandlerDotName().toString() + : null); + + String toolExecutionErrorHandlerDotName = (bi.getToolExecutionErrorHandlerDotName() != null + ? bi.getToolExecutionErrorHandlerDotName().toString() + : null); + String chatMemoryProviderSupplierClassName = bi.getChatMemoryProviderSupplierClassDotName() != null ? bi.getChatMemoryProviderSupplierClassDotName().toString() : null; @@ -884,6 +991,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, injectModerationModelBean, injectImageModel, toolHallucinationStrategyClassName, + toolArgumentsErrorHandlerDotName, + toolExecutionErrorHandlerDotName, classInputGuardrails(bi), classOutputGuardrails(bi), maxSequentialToolInvocations, @@ -931,6 +1040,12 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if (bi.getToolHallucinationStrategyClassDotName() != null) { configurator.addInjectionPoint(ClassType.create(bi.getToolHallucinationStrategyClassDotName())); } + if (bi.getToolArgumentsErrorHandlerDotName() != null) { + configurator.addInjectionPoint(ClassType.create(bi.getToolArgumentsErrorHandlerDotName())); + } + if (bi.getToolExecutionErrorHandlerDotName() != null) { + configurator.addInjectionPoint(ClassType.create(bi.getToolExecutionErrorHandlerDotName())); + } if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) { configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MEMORY_PROVIDER)); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index a64741a49..c01fa49de 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -33,6 +33,8 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final Optional beanName; private final DeclarativeAiServiceInputGuardrails inputGuardrails; private final DeclarativeAiServiceOutputGuardrails outputGuardrails; + private final DotName toolArgumentsErrorHandlerDotName; + private final DotName toolExecutionErrorHandlerDotName; private final Integer maxSequentialToolInvocations; private final boolean allowContinuousForcedToolCalling; private final boolean makeDefaultBean; @@ -57,6 +59,8 @@ public DeclarativeAiServiceBuildItem( DotName toolHallucinationStrategyClassDotName, DeclarativeAiServiceInputGuardrails inputGuardrails, DeclarativeAiServiceOutputGuardrails outputGuardrails, + DotName toolArgumentsErrorHandlerDotName, + DotName toolExecutionErrorHandlerDotName, Integer maxSequentialToolInvocations, boolean allowContinuousForcedToolCalling, boolean makeDefaultBean) { @@ -79,6 +83,8 @@ public DeclarativeAiServiceBuildItem( this.toolHallucinationStrategyClassDotName = toolHallucinationStrategyClassDotName; this.inputGuardrails = inputGuardrails; this.outputGuardrails = outputGuardrails; + this.toolArgumentsErrorHandlerDotName = toolArgumentsErrorHandlerDotName; + this.toolExecutionErrorHandlerDotName = toolExecutionErrorHandlerDotName; this.maxSequentialToolInvocations = maxSequentialToolInvocations; this.allowContinuousForcedToolCalling = allowContinuousForcedToolCalling; this.makeDefaultBean = makeDefaultBean; @@ -160,6 +166,14 @@ public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() { return outputGuardrails; } + public DotName getToolArgumentsErrorHandlerDotName() { + return toolArgumentsErrorHandlerDotName; + } + + public DotName getToolExecutionErrorHandlerDotName() { + return toolExecutionErrorHandlerDotName; + } + public boolean isMakeDefaultBean() { return makeDefaultBean; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java index 3b68a40c9..c635f5103 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java @@ -81,6 +81,9 @@ public class DotNames { public static final DotName EXECUTOR = DotName.createSimple(Executor.class); + public static final DotName THROWABLE = DotName.createSimple(Throwable.class); + public static final DotName EXCEPTION = DotName.createSimple(Exception.class); + public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class); public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class); public static final DotName TOOL = DotName.createSimple(Tool.class); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java index da61ad410..3c6d026a7 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java @@ -6,6 +6,8 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.pdf.PdfFile; +import dev.langchain4j.exception.ToolArgumentsException; +import dev.langchain4j.exception.ToolExecutionException; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -29,11 +31,15 @@ import dev.langchain4j.service.UserName; import dev.langchain4j.service.guardrail.InputGuardrails; import dev.langchain4j.service.guardrail.OutputGuardrails; +import dev.langchain4j.service.tool.ToolErrorContext; +import dev.langchain4j.service.tool.ToolErrorHandlerResult; import dev.langchain4j.service.tool.ToolProvider; import dev.langchain4j.web.search.WebSearchEngine; import dev.langchain4j.web.search.WebSearchTool; import io.quarkiverse.langchain4j.AudioUrl; import io.quarkiverse.langchain4j.CreatedAware; +import io.quarkiverse.langchain4j.HandleToolArgumentError; +import io.quarkiverse.langchain4j.HandleToolExecutionError; import io.quarkiverse.langchain4j.ImageUrl; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.PdfUrl; @@ -124,6 +130,9 @@ public class LangChain4jDotNames { static final DotName SEED_MEMORY = DotName.createSimple(SeedMemory.class); + static final DotName HANDLE_TOOL_ARGUMENT_ERROR = DotName.createSimple(HandleToolArgumentError.class); + static final DotName HANDLE_TOOL_EXECUTION_ERROR = DotName.createSimple(HandleToolExecutionError.class); + static final DotName WEB_SEARCH_TOOL = DotName.createSimple(WebSearchTool.class); static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class); static final DotName IMAGE = DotName.createSimple(Image.class); @@ -136,4 +145,8 @@ public class LangChain4jDotNames { public static final DotName MCP_TOOLBOX = DotName.createSimple("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox"); public static final DotName CHAT_EVENT = DotName.createSimple(ChatEvent.class); public static final DotName CHAT_MEMORY = DotName.createSimple(ChatMemory.class); + public static final DotName TOOL_ERROR_HANDLER_RESULT = DotName.createSimple(ToolErrorHandlerResult.class); + public static final DotName TOOL_ARGUMENTS_EXCEPTION = DotName.createSimple(ToolArgumentsException.class); + public static final DotName TOOL_EXECUTION_EXCEPTION = DotName.createSimple(ToolExecutionException.class); + public static final DotName TOOL_ERROR_CONTEXT = DotName.createSimple(ToolErrorContext.class); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java index 9e529687c..e183e3df1 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java @@ -19,6 +19,7 @@ import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.exception.ToolArgumentsException; import dev.langchain4j.service.tool.ToolExecutor; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor; @@ -272,7 +273,7 @@ private void executeAndExpectFailure(String arguments, String methodName) { ToolExecutor toolExecutor = getToolExecutor(methodName); assertThatThrownBy(() -> toolExecutor.execute(request, null)) - .isExactlyInstanceOf(IllegalArgumentException.class); + .isExactlyInstanceOf(ToolArgumentsException.class); } private ToolExecutor getToolExecutor(String methodName) { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailEventLoopBlockingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailEventLoopBlockingTest.java index 52bc85e5a..60ac575c7 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailEventLoopBlockingTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailEventLoopBlockingTest.java @@ -41,6 +41,7 @@ import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrailRequest; import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrails; +import io.quarkiverse.langchain4j.runtime.BlockingToolNotAllowedException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkiverse.langchain4j.test.Lists; import io.quarkus.arc.Arc; @@ -130,7 +131,7 @@ void toolWithInputGuardrail_throwsException_whenCalledFromEventLoop() throws Exc Throwable exception = exceptionRef.get(); assertThat(exception) .isNotNull() - .isInstanceOf(ToolExecutionException.class); + .isInstanceOf(BlockingToolNotAllowedException.class); // Our ToolGuardrailsWrapper should catch the event loop and throw a clear error assertThat(exception.getMessage()) @@ -169,7 +170,7 @@ void toolWithOutputGuardrail_throwsException_whenCalledFromEventLoop() throws Ex Throwable exception = exceptionRef.get(); assertThat(exception) .isNotNull() - .isInstanceOf(ToolExecutionException.class); + .isInstanceOf(BlockingToolNotAllowedException.class); assertThat(exception.getMessage()) .contains("Cannot execute guardrails") diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolArgumentsErrorHandlerTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolArgumentsErrorHandlerTest.java new file mode 100644 index 000000000..6b9786253 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolArgumentsErrorHandlerTest.java @@ -0,0 +1,137 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.assertj.core.api.Assertions.*; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.mock.ChatModelMock; +import dev.langchain4j.service.tool.ToolErrorContext; +import dev.langchain4j.service.tool.ToolErrorHandlerResult; +import io.quarkiverse.langchain4j.HandleToolArgumentError; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; +import io.quarkus.test.QuarkusUnitTest; + +public class ToolArgumentsErrorHandlerTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Inject + Assistant assistant; + + @Inject + Assistant2 assistant2; + + @Test + @ActivateRequestContext + void assistant() { + assertThat(Assistant.ARGUMENT_ERROR_HANDLER_CALLED).isFalse(); + int initialToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + + String hello = assistant.chat("hello"); + assertThat(hello).isNotNull(); + + int latestToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + assertThat(latestToolInvocationCount).isEqualTo(initialToolInvocationCount + 1); + + assertThat(Assistant.ARGUMENT_ERROR_HANDLER_CALLED).isTrue(); + } + + @Test + @ActivateRequestContext + void assistant2() { + assertThat(Assistant2.ARGUMENT_ERROR_HANDLER_CALLED).isFalse(); + int initialToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + + String hello = assistant2.chat("hello"); + assertThat(hello).isNotNull(); + + int latestToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + assertThat(latestToolInvocationCount).isEqualTo(initialToolInvocationCount + 1); + + assertThat(Assistant2.ARGUMENT_ERROR_HANDLER_CALLED).isTrue(); + } + + @RegisterAiService(chatLanguageModelSupplier = TestChatModelSupplier.class) + interface Assistant { + + AtomicBoolean ARGUMENT_ERROR_HANDLER_CALLED = new AtomicBoolean(false); + + @ToolBox(Tools.class) + String chat(String userMessage); + + @HandleToolArgumentError + static ToolErrorHandlerResult handle(ToolErrorContext c, Exception e) { + ARGUMENT_ERROR_HANDLER_CALLED.set(true); + return ToolErrorHandlerResult.text(e.getMessage() + c.invocationContext().toString()); + } + + } + + @RegisterAiService(chatLanguageModelSupplier = TestChatModelSupplier.class) + interface Assistant2 { + + AtomicBoolean ARGUMENT_ERROR_HANDLER_CALLED = new AtomicBoolean(false); + + @ToolBox(Tools.class) + String chat(String userMessage); + + @HandleToolArgumentError + static String handle() { + ARGUMENT_ERROR_HANDLER_CALLED.set(true); + return "boom"; + } + + } + + @ApplicationScoped + public static final class Tools { + + public static final AtomicInteger INVOCATION_COUNTER = new AtomicInteger(0); + + @Tool + String getWeather(String ignored) { + INVOCATION_COUNTER.incrementAndGet(); + return "Sunny"; + } + } + + @ApplicationScoped + public static class TestChatModelSupplier implements Supplier { + @Override + public ChatModel get() { + // given + ToolExecutionRequest toolExecutionRequest1 = ToolExecutionRequest.builder() + .name("getWeather") + .arguments("{ invalid json }") + .build(); + + ToolExecutionRequest toolExecutionRequest2 = ToolExecutionRequest.builder() + .name("getWeather") + .arguments("{\"arg0\":\"Munich\"}") + .build(); + + return ChatModelMock.thatAlwaysResponds( + AiMessage.from(toolExecutionRequest1), + AiMessage.from(toolExecutionRequest2), + AiMessage.from("sunny")); + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolExecutionErrorHandlerTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolExecutionErrorHandlerTest.java new file mode 100644 index 000000000..355fc3340 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolExecutionErrorHandlerTest.java @@ -0,0 +1,135 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.mock.ChatModelMock; +import dev.langchain4j.service.tool.ToolErrorContext; +import dev.langchain4j.service.tool.ToolErrorHandlerResult; +import io.quarkiverse.langchain4j.HandleToolExecutionError; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; +import io.quarkus.test.QuarkusUnitTest; + +public class ToolExecutionErrorHandlerTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Inject + Assistant assistant; + + @Inject + Assistant2 assistant2; + + @Test + @ActivateRequestContext + void assistant() { + assertThat(Assistant.EXECUTION_ERROR_HANDLER_CALLED).isFalse(); + int initialToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + + String hello = assistant.chat("hello"); + assertThat(hello).isNotNull(); + + int latestToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + assertThat(latestToolInvocationCount).isEqualTo(initialToolInvocationCount + 1); + + assertThat(Assistant.EXECUTION_ERROR_HANDLER_CALLED).isTrue(); + } + + @Test + @ActivateRequestContext + void assistant2() { + assertThat(Assistant2.EXECUTION_ERROR_HANDLER_CALLED).isFalse(); + int initialToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + + String hello = assistant2.chat("hello"); + assertThat(hello).isNotNull(); + + int latestToolInvocationCount = Tools.INVOCATION_COUNTER.get(); + assertThat(latestToolInvocationCount).isEqualTo(initialToolInvocationCount + 1); + + assertThat(Assistant2.EXECUTION_ERROR_HANDLER_CALLED).isTrue(); + } + + @RegisterAiService(chatLanguageModelSupplier = TestChatModelSupplier.class) + interface Assistant { + + AtomicBoolean EXECUTION_ERROR_HANDLER_CALLED = new AtomicBoolean(false); + + @ToolBox(Tools.class) + String chat(String userMessage); + + @HandleToolExecutionError + static ToolErrorHandlerResult handle(ToolErrorContext c, Exception e) { + assertThat(e).isInstanceOf(DummyException.class); + EXECUTION_ERROR_HANDLER_CALLED.set(true); + return ToolErrorHandlerResult.text(e.getMessage() + c.invocationContext().toString()); + } + + } + + @RegisterAiService(chatLanguageModelSupplier = TestChatModelSupplier.class) + interface Assistant2 { + + AtomicBoolean EXECUTION_ERROR_HANDLER_CALLED = new AtomicBoolean(false); + + @ToolBox(Tools.class) + String chat(String userMessage); + + @HandleToolExecutionError + static String handle() { + EXECUTION_ERROR_HANDLER_CALLED.set(true); + return "boom"; + } + + } + + @ApplicationScoped + public static final class Tools { + + public static final AtomicInteger INVOCATION_COUNTER = new AtomicInteger(0); + + @Tool + String getWeather(String ignored) { + INVOCATION_COUNTER.incrementAndGet(); + throw new DummyException(); + } + } + + public static class DummyException extends RuntimeException { + + } + + @ApplicationScoped + public static class TestChatModelSupplier implements Supplier { + @Override + public ChatModel get() { + ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .name("getWeather") + .arguments("{\"arg0\":\"Munich\"}") + .build(); + + return ChatModelMock.thatAlwaysResponds( + AiMessage.from(toolExecutionRequest), + AiMessage.from("I was not able to get the weather")); + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java index 19c8de5b8..d0fd44e02 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java @@ -41,6 +41,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.ToolBox; +import io.quarkiverse.langchain4j.runtime.BlockingToolNotAllowedException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkiverse.langchain4j.test.Lists; import io.quarkus.arc.Arc; @@ -80,7 +81,7 @@ void testBlockingToolInvocationFromEventLoop() { try { Arc.container().requestContext().activate(); aiService.hello("abc", "hi - " + uuid); - } catch (IllegalStateException e) { + } catch (BlockingToolNotAllowedException e) { failure.set(e); } finally { Arc.container().requestContext().deactivate(); @@ -249,7 +250,7 @@ void testToolInvocationOnVirtualThreadFromEventLoop() { try { Arc.container().requestContext().activate(); aiService.hello("abc", "hiVirtualThread - " + uuid); - } catch (IllegalStateException e) { + } catch (BlockingToolNotAllowedException e) { failure.set(e); } finally { Arc.container().requestContext().deactivate(); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/HandleToolArgumentError.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/HandleToolArgumentError.java new file mode 100644 index 000000000..914c2cbe2 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/HandleToolArgumentError.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import dev.langchain4j.exception.ToolArgumentsException; +import dev.langchain4j.service.tool.ToolArgumentsErrorHandler; +import dev.langchain4j.service.tool.ToolErrorContext; +import dev.langchain4j.service.tool.ToolErrorHandlerResult; + +/** + * Can be used on a static method of an AI Service interface registered with {@link RegisterAiService} + * to handle {@link ToolArgumentsException}. + *

+ * The method can specify {@link Throwable} and/or {@link ToolErrorContext} as parameters. + * The return type of the method must be either {@link String} or {@link ToolErrorHandlerResult} + *

+ * See also: {@link ToolArgumentsErrorHandler} + */ +@Retention(RUNTIME) +@Target({ METHOD }) +public @interface HandleToolArgumentError { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/HandleToolExecutionError.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/HandleToolExecutionError.java new file mode 100644 index 000000000..240387eb1 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/HandleToolExecutionError.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import dev.langchain4j.exception.ToolExecutionException; +import dev.langchain4j.service.tool.ToolErrorContext; +import dev.langchain4j.service.tool.ToolErrorHandlerResult; +import dev.langchain4j.service.tool.ToolExecutionErrorHandler; + +/** + * Can be used on a static method of an AI Service interface registered with {@link RegisterAiService} + * to handle {@link ToolExecutionException}. + *

+ * The method can specify {@link Throwable} and/or {@link ToolErrorContext} as parameters. + * The return type of the method must be either {@link String} or {@link ToolErrorHandlerResult} + *

+ * See also: {@link ToolExecutionErrorHandler} + */ +@Retention(RUNTIME) +@Target({ METHOD }) +public @interface HandleToolExecutionError { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ToolGuardrailException.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ToolGuardrailException.java index e12919844..8bbf39eee 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ToolGuardrailException.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ToolGuardrailException.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.guardrails; +import io.quarkiverse.langchain4j.runtime.PreventsErrorHandlerExecution; + /** * Exception thrown when tool guardrail validation fails critically. *

@@ -23,7 +25,7 @@ * @see ToolInputGuardrailResult * @see ToolOutputGuardrailResult */ -public class ToolGuardrailException extends RuntimeException { +public class ToolGuardrailException extends RuntimeException implements PreventsErrorHandlerExecution { private final boolean fatal; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index e0d7696e9..d388305ec 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -21,6 +21,8 @@ import dev.langchain4j.model.image.ImageModel; import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.service.tool.ToolArgumentsErrorHandler; +import dev.langchain4j.service.tool.ToolExecutionErrorHandler; import dev.langchain4j.service.tool.ToolProvider; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.RegisterAiService; @@ -117,8 +119,7 @@ public Function, QuarkusAiSe @Override public QuarkusAiServiceContext apply(SyntheticCreationalContext creationalContext) { try { - Class serviceClass = Thread.currentThread().getContextClassLoader() - .loadClass(info.serviceClassName()); + Class serviceClass = loadClass(info.serviceClassName()); QuarkusAiServiceContext aiServiceContext = new QuarkusAiServiceContext(serviceClass); // we don't really care about QuarkusAiServices here, all we care about is that it @@ -174,11 +175,11 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext toolProviderClass = Thread.currentThread().getContextClassLoader() - .loadClass(info.toolProviderSupplier()); + Class toolProviderClass = loadClass(info.toolProviderSupplier()); Supplier toolProvider = (Supplier) creationalContext .getInjectedReference(toolProviderClass); quarkusAiServices.toolProvider(toolProvider.get()); @@ -227,9 +240,8 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext supplier = (Supplier) Thread - .currentThread().getContextClassLoader() - .loadClass(info.chatMemoryProviderSupplierClassName()) + Supplier supplier = (Supplier) loadClass( + info.chatMemoryProviderSupplierClassName()) .getConstructor().newInstance(); quarkusAiServices.chatMemoryProvider(supplier.get()); } @@ -246,14 +258,12 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext instance = (Supplier) creationalContext - .getInjectedReference(Thread.currentThread().getContextClassLoader() - .loadClass(info.retrievalAugmentorSupplierClassName())); + .getInjectedReference(loadClass(info.retrievalAugmentorSupplierClassName())); quarkusAiServices.retrievalAugmentor(instance.get()); } catch (IllegalArgumentException e) { // the provided Supplier is not a CDI bean, build it manually - Supplier supplier = (Supplier) Thread - .currentThread().getContextClassLoader() - .loadClass(info.retrievalAugmentorSupplierClassName()) + Supplier supplier = (Supplier) loadClass( + info.retrievalAugmentorSupplierClassName()) .getConstructor().newInstance(); quarkusAiServices.retrievalAugmentor(supplier.get()); } @@ -273,9 +283,8 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext supplier = (Supplier) Thread - .currentThread().getContextClassLoader() - .loadClass(info.moderationModelSupplierClassName()) + Supplier supplier = (Supplier) loadClass( + info.moderationModelSupplierClassName()) .getConstructor().newInstance(); quarkusAiServices.moderationModel(supplier.get()); } @@ -294,18 +303,16 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext supplier = (Supplier) Thread - .currentThread().getContextClassLoader() - .loadClass(info.imageModelSupplierClassName()) + Supplier supplier = (Supplier) loadClass( + info.imageModelSupplierClassName()) .getConstructor().newInstance(); quarkusAiServices.imageModel(supplier.get()); } } if (info.chatMemorySeederClassName() != null) { - quarkusAiServices.chatMemorySeeder((ChatMemorySeeder) Thread - .currentThread().getContextClassLoader() - .loadClass(info.chatMemorySeederClassName()) + quarkusAiServices.chatMemorySeeder((ChatMemorySeeder) loadClass( + info.chatMemorySeederClassName()) .getConstructor().newInstance()); } if (info.maxSequentialToolInvocations() != null && info.maxSequentialToolInvocations() > 0) { @@ -326,6 +333,11 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext loadClass(String info) throws ClassNotFoundException { + return Thread.currentThread().getContextClassLoader() + .loadClass(info); + } }; } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/BlockingToolNotAllowedException.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/BlockingToolNotAllowedException.java new file mode 100644 index 000000000..579555456 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/BlockingToolNotAllowedException.java @@ -0,0 +1,8 @@ +package io.quarkiverse.langchain4j.runtime; + +public class BlockingToolNotAllowedException extends RuntimeException implements PreventsErrorHandlerExecution { + + public BlockingToolNotAllowedException(String message) { + super(message); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/PreventsErrorHandlerExecution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/PreventsErrorHandlerExecution.java new file mode 100644 index 000000000..c15c97cfd --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/PreventsErrorHandlerExecution.java @@ -0,0 +1,9 @@ +package io.quarkiverse.langchain4j.runtime; + +import dev.langchain4j.service.tool.ToolExecutionErrorHandler; + +/** + * Marker interface that prevents the {@link ToolExecutionErrorHandler} from being engaged + */ +public interface PreventsErrorHandlerExecution { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index d9b19b897..87bdd90c0 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -52,6 +52,7 @@ import dev.langchain4j.data.message.VideoContent; import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; +import dev.langchain4j.exception.ToolArgumentsException; import dev.langchain4j.guardrail.ChatExecutor; import dev.langchain4j.guardrail.GuardrailRequestParams; import dev.langchain4j.invocation.InvocationContext; @@ -86,8 +87,11 @@ import dev.langchain4j.service.IllegalConfigurationException; import dev.langchain4j.service.Result; import dev.langchain4j.service.output.ServiceOutputParser; +import dev.langchain4j.service.tool.ToolArgumentsErrorHandler; +import dev.langchain4j.service.tool.ToolErrorContext; import dev.langchain4j.service.tool.ToolErrorHandlerResult; import dev.langchain4j.service.tool.ToolExecution; +import dev.langchain4j.service.tool.ToolExecutionErrorHandler; import dev.langchain4j.service.tool.ToolExecutionResult; import dev.langchain4j.service.tool.ToolExecutor; import dev.langchain4j.service.tool.ToolProviderRequest; @@ -99,6 +103,7 @@ import io.quarkiverse.langchain4j.VideoUrl; import io.quarkiverse.langchain4j.response.ResponseAugmenterParams; import io.quarkiverse.langchain4j.runtime.ContextLocals; +import io.quarkiverse.langchain4j.runtime.PreventsErrorHandlerExecution; import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailsSupport.OutputGuardrailStreamingMapper; @@ -432,7 +437,8 @@ private static Object doImplement0(AiServiceMethodCreateInfo methodCreateInfo, I ToolExecutionResult toolExecutionResult = toolExecutor == null ? context.toolService.applyToolHallucinationStrategy(toolExecutionRequest) - : executeTool(toolExecutionRequest, toolExecutor, invocationContext); + : executeTool(toolExecutionRequest, toolExecutor, invocationContext, + context.toolService.argumentsErrorHandler(), context.toolService.executionErrorHandler()); // New firing context.eventListenerRegistrar.fireEvent( @@ -608,8 +614,49 @@ private static InvocationParameters findInvocationParams(Object[] args) { } private static ToolExecutionResult executeTool(ToolExecutionRequest toolExecutionRequest, ToolExecutor toolExecutor, - InvocationContext invocationContext) { - ToolExecutionResult toolExecutionResult = toolExecutor.executeWithContext(toolExecutionRequest, invocationContext); + InvocationContext invocationContext, + ToolArgumentsErrorHandler toolArgumentsErrorHandler, + ToolExecutionErrorHandler toolExecutionErrorHandler) { + ToolExecutionResult toolExecutionResult; + try { + toolExecutionResult = toolExecutor.executeWithContext(toolExecutionRequest, invocationContext); + } catch (ToolArgumentsException e) { + if (toolArgumentsErrorHandler != null) { + log.debugv(e, "Error occurred while executing tool arguments. Executing ", + toolArgumentsErrorHandler.getClass().getName() + "' to handle it"); + ToolErrorContext errorContext = ToolErrorContext.builder() + .toolExecutionRequest(toolExecutionRequest) + .invocationContext(invocationContext) + .build(); + ToolErrorHandlerResult toolErrorHandlerResult = toolArgumentsErrorHandler.handle(e, errorContext); + return ToolExecutionResult.builder() + .isError(true) + .resultText(toolErrorHandlerResult.text()) + .build(); + } else { + throw e; + } + } catch (Exception e) { + if (e instanceof PreventsErrorHandlerExecution) { + // preserve semantics for existing code + throw e; + } + if (toolExecutionErrorHandler != null) { + log.debugv(e, "Error occurred while executing tool. Executing '", + toolExecutionErrorHandler.getClass().getName() + "' to handle it"); + ToolErrorContext errorContext = ToolErrorContext.builder() + .toolExecutionRequest(toolExecutionRequest) + .invocationContext(invocationContext) + .build(); + ToolErrorHandlerResult toolErrorHandlerResult = toolExecutionErrorHandler.handle(e, errorContext); + return ToolExecutionResult.builder() + .isError(true) + .resultText(toolErrorHandlerResult.text()) + .build(); + } else { + throw e; + } + } log.debugv("Result of {0} is '{1}'", toolExecutionRequest, toolExecutionResult); return toolExecutionResult; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index f5f4cd112..907508aff 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -25,6 +25,8 @@ public record DeclarativeAiServiceCreateInfo( boolean needsModerationModel, boolean needsImageModel, String toolHallucinationStrategyClassName, + String toolArgumentsErrorHandlerClassName, + String toolExecutionErrorHandlerClassName, InputGuardrailsLiteral inputGuardrails, OutputGuardrailsLiteral outputGuardrails, Integer maxSequentialToolInvocations, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java index 916b6689c..35bd79aba 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java @@ -13,12 +13,13 @@ import dev.langchain4j.agent.tool.ReturnBehavior; import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.exception.ToolExecutionException; +import dev.langchain4j.exception.ToolArgumentsException; import dev.langchain4j.internal.Json; import dev.langchain4j.invocation.InvocationContext; import dev.langchain4j.service.tool.ToolExecutionResult; import dev.langchain4j.service.tool.ToolExecutor; import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; +import io.quarkiverse.langchain4j.runtime.BlockingToolNotAllowedException; import io.quarkiverse.langchain4j.runtime.prompt.Mappable; import io.quarkus.virtual.threads.VirtualThreadsRecorder; import io.smallrye.mutiny.Uni; @@ -79,14 +80,14 @@ public ToolExecutionResult executeWithContext(ToolExecutionRequest request, Invo switch (context.executionModel) { case BLOCKING: if (io.vertx.core.Context.isOnEventLoopThread()) { - throw new IllegalStateException("Cannot execute blocking tools on event loop thread"); + throw new BlockingToolNotAllowedException("Cannot execute blocking tools on event loop thread"); } return invoke(params, invokerInstance); case NON_BLOCKING: return invoke(params, invokerInstance); case VIRTUAL_THREAD: if (io.vertx.core.Context.isOnEventLoopThread()) { - throw new IllegalStateException("Cannot execute virtual thread tools on event loop thread"); + throw new BlockingToolNotAllowedException("Cannot execute virtual thread tools on event loop thread"); } try { return VirtualThreadsRecorder.getCurrent().submit(() -> invoke(params, invokerInstance)) @@ -118,7 +119,7 @@ private ToolExecutionResult invoke(Object[] params, ToolInvoker invokerInstance) String result; if (invocationResult instanceof Uni) { // TODO CS if (io.vertx.core.Context.isOnEventLoopThread()) { - throw new ToolExecutionException( + throw new BlockingToolNotAllowedException( "Cannot execute tools returning Uni on event loop thread due to a tool executor limitation"); } result = handleResult(invokerInstance, ((Uni) invocationResult).await().indefinitely()); @@ -127,17 +128,17 @@ private ToolExecutionResult invoke(Object[] params, ToolInvoker invokerInstance) } log.debugv("Tool execution result: {0}", result); return ToolExecutionResult.builder().result(invocationResult).resultText(result).build(); - } catch (ToolExecutionException e) { - throw e; } catch (Exception e) { - if (context.propagateToolExecutionExceptions) { - throw new ToolExecutionException(e); - } - log.error("Error while executing tool '" + context.tool.getClass() + "'", e); - return ToolExecutionResult.builder().isError(true).resultText(e.getMessage()).build(); + sneakyThrow(e); + // keep the compiler happy + return null; } } + private static void sneakyThrow(Throwable e) throws E { + throw (E) e; + } + private static String handleResult(ToolInvoker invokerInstance, Object invocationResult) { if (invokerInstance.methodMetadata().isReturnsVoid()) { return "Success"; @@ -221,7 +222,7 @@ private Class loadMapperClass() { } private void invalidMethodParams(String argumentsJsonStr) { - throw new IllegalArgumentException("params '" + argumentsJsonStr + throw new ToolArgumentsException("params '" + argumentsJsonStr + "' from request do not map onto the parameters needed by '" + context.tool.getClass().getName() + "#" + context.methodName + "'"); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/guardrails/ToolGuardrailsWrapper.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/guardrails/ToolGuardrailsWrapper.java index ee175c46b..f67c99ddb 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/guardrails/ToolGuardrailsWrapper.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/guardrails/ToolGuardrailsWrapper.java @@ -8,13 +8,13 @@ import org.jboss.logging.Logger; import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.exception.ToolExecutionException; import dev.langchain4j.invocation.InvocationContext; import dev.langchain4j.service.tool.ToolExecutionResult; import io.quarkiverse.langchain4j.guardrails.ToolGuardrailException; import io.quarkiverse.langchain4j.guardrails.ToolInputGuardrail; import io.quarkiverse.langchain4j.guardrails.ToolInvocationContext; import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrail; +import io.quarkiverse.langchain4j.runtime.BlockingToolNotAllowedException; import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor; import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; @@ -68,7 +68,7 @@ public ToolExecutionResult wrap( // Check if we're on the Vert.x event loop // If so, dispatch guardrail execution to worker thread to prevent blocking if (io.vertx.core.Context.isOnEventLoopThread()) { - throw new ToolExecutionException( + throw new BlockingToolNotAllowedException( "Cannot execute guardrails tools on the event loop thread. Make sure your tool function is marked or detected as blocking."); }