Skip to content

Commit fdd35d9

Browse files
committed
Align with upstream LangChai4j error handling
This is done by introducing support for `@HandleToolExecutionError` and `@HandleToolArgumentError` Closes: #2008
1 parent dc40f6f commit fdd35d9

20 files changed

+609
-48
lines changed

core/deployment/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@
121121
<artifactId>quarkus-test-vertx</artifactId>
122122
<scope>test</scope>
123123
</dependency>
124+
<dependency>
125+
<groupId>dev.langchain4j</groupId>
126+
<artifactId>langchain4j-core</artifactId>
127+
<version>${langchain4j.version}</version>
128+
<classifier>tests</classifier>
129+
<type>test-jar</type>
130+
<scope>test</scope>
131+
</dependency>
124132
</dependencies>
125133
<build>
126134
<plugins>

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import jakarta.enterprise.inject.spi.DeploymentException;
4646
import jakarta.enterprise.util.AnnotationLiteral;
4747
import jakarta.inject.Inject;
48+
import jakarta.inject.Singleton;
4849
import jakarta.interceptor.InterceptorBinding;
4950

5051
import org.eclipse.microprofile.config.ConfigProvider;
@@ -77,6 +78,10 @@
7778
import dev.langchain4j.service.memory.ChatMemoryAccess;
7879
import dev.langchain4j.service.output.JsonSchemas;
7980
import dev.langchain4j.service.output.ServiceOutputParser;
81+
import dev.langchain4j.service.tool.ToolArgumentsErrorHandler;
82+
import dev.langchain4j.service.tool.ToolErrorContext;
83+
import dev.langchain4j.service.tool.ToolErrorHandlerResult;
84+
import dev.langchain4j.service.tool.ToolExecutionErrorHandler;
8085
import dev.langchain4j.spi.classloading.ClassInstanceFactory;
8186
import dev.langchain4j.spi.classloading.ClassMetadataProviderFactory;
8287
import io.quarkiverse.langchain4j.ModelName;
@@ -341,6 +346,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
341346
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
342347
BuildProducer<ToolProviderMetaBuildItem> toolProviderProducer,
343348
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
349+
BuildProducer<GeneratedBeanBuildItem> generatedBeanProducer,
344350
BuildProducer<GeneratedClassBuildItem> generatedClassProducer) {
345351
IndexView index = indexBuildItem.getIndex();
346352

@@ -496,6 +502,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
496502
toolHallucinationStrategy(instance),
497503
classInputGuardrails(declarativeAiServiceClassInfo, index),
498504
classOutputGuardrails(declarativeAiServiceClassInfo, index),
505+
toolArgumentsErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer),
506+
toolExecutionErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer),
499507
maxSequentialToolInvocations,
500508
allowContinuousForcedToolCalling,
501509
// we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence
@@ -647,6 +655,97 @@ private static OutputGuardrailsLiteral classOutputGuardrails(DeclarativeAiServic
647655
declarativeAiServiceBuildItem.getOutputGuardrails().maxRetries());
648656
}
649657

658+
private DotName toolArgumentsErrorHandlerDotName(ClassInfo aiServiceClassInfo,
659+
BuildProducer<GeneratedBeanBuildItem> generatedBean) {
660+
return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_ARGUMENT_ERROR, generatedBean,
661+
ToolArgumentsErrorHandler.class);
662+
}
663+
664+
private DotName toolExecutionErrorHandlerDotName(ClassInfo aiServiceClassInfo,
665+
BuildProducer<GeneratedBeanBuildItem> generatedBean) {
666+
return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_EXECUTION_ERROR, generatedBean,
667+
ToolExecutionErrorHandler.class);
668+
}
669+
670+
private DotName toolErrorHandlerDotName(ClassInfo aiServiceClassInfo, DotName annotationName,
671+
BuildProducer<GeneratedBeanBuildItem> generatedBean,
672+
Class<?> interfaceType) {
673+
List<AnnotationInstance> instances = aiServiceClassInfo.annotations(annotationName);
674+
if (instances.isEmpty()) {
675+
return null;
676+
}
677+
if (instances.size() > 1) {
678+
throw new IllegalConfigurationException(
679+
"`@%s` can be used only once in an AI Service. Offending class is '%s'".formatted(annotationName,
680+
aiServiceClassInfo.name()));
681+
}
682+
AnnotationTarget target = instances.get(0).target();
683+
if (target.kind() != AnnotationTarget.Kind.METHOD) {
684+
throw new IllegalConfigurationException(
685+
"`@%s` can be used only methods. Offending class is '%s'".formatted(annotationName,
686+
aiServiceClassInfo.name()));
687+
}
688+
MethodInfo targetMethod = target.asMethod();
689+
if (!Modifier.isStatic(targetMethod.flags())) {
690+
throw new IllegalConfigurationException(
691+
"`@%s` can be used only on static methods. Offending class is '%s'".formatted(annotationName,
692+
aiServiceClassInfo.name()));
693+
}
694+
DotName returnType = targetMethod.returnType().name();
695+
if ((!returnType.equals(DotNames.STRING)) && !returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) {
696+
throw new IllegalConfigurationException(
697+
"`@%s` can be used only on static methods that return '%s' or '%s'. Offending class is '%s'"
698+
.formatted(annotationName,
699+
DotNames.STRING, LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT, aiServiceClassInfo.name()));
700+
}
701+
702+
ClassOutput output = new GeneratedBeanGizmoAdaptor(generatedBean);
703+
String generatedClassName = aiServiceClassInfo.name().toString() + "$" + annotationName.withoutPackagePrefix();
704+
ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
705+
.classOutput(output)
706+
.interfaces(interfaceType)
707+
.className(generatedClassName);
708+
try (ClassCreator classCreator = classCreatorBuilder.build()) {
709+
classCreator.addAnnotation(Singleton.class);
710+
711+
MethodCreator handleMethod = classCreator.getMethodCreator(MethodDescriptor.ofMethod(generatedClassName, "handle",
712+
ToolErrorHandlerResult.class, Throwable.class, ToolErrorContext.class));
713+
714+
List<ResultHandle> paramHandles = new ArrayList<>();
715+
for (MethodParameterInfo parameter : targetMethod.parameters()) {
716+
DotName paramTypeDotName = parameter.type().name();
717+
if (paramTypeDotName.equals(DotNames.THROWABLE) || paramTypeDotName.equals(DotNames.EXCEPTION)) {
718+
paramHandles.add(handleMethod.getMethodParam(0));
719+
} else if (paramTypeDotName.equals(LangChain4jDotNames.TOOL_ERROR_CONTEXT)) {
720+
paramHandles.add(handleMethod.getMethodParam(1));
721+
} else {
722+
throw new IllegalConfigurationException(
723+
"`@%s` can be used only on static methods that use the parameters of type '%s' or '%s'. Offending class is '%s'"
724+
.formatted(annotationName,
725+
DotNames.THROWABLE, LangChain4jDotNames.TOOL_ERROR_CONTEXT,
726+
aiServiceClassInfo.name()));
727+
}
728+
}
729+
730+
ResultHandle result = handleMethod.invokeStaticInterfaceMethod(targetMethod,
731+
paramHandles.toArray(new ResultHandle[0]));
732+
if (returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) {
733+
handleMethod.returnValue(result);
734+
} else if (returnType.equals(DotNames.STRING)) {
735+
ResultHandle toolErrorHandlerResultResult = handleMethod.invokeStaticMethod(
736+
MethodDescriptor.ofMethod(ToolErrorHandlerResult.class, "text", ToolErrorHandlerResult.class,
737+
String.class),
738+
result);
739+
handleMethod.returnValue(toolErrorHandlerResultResult);
740+
} else {
741+
throw new IllegalStateException("Unhandled result type: " + returnType);
742+
}
743+
744+
}
745+
746+
return DotName.createSimple(generatedClassName);
747+
}
748+
650749
private static List<ClassInfo> tools(AnnotationInstance instance, IndexView index) {
651750
AnnotationValue toolsInstance = instance.value("tools");
652751
if (toolsInstance != null) {
@@ -792,6 +891,14 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
792891
allToolHallucinationStrategies.add(bi.getToolHallucinationStrategyClassDotName());
793892
}
794893

894+
String toolArgumentsErrorHandlerDotName = (bi.getToolArgumentsErrorHandlerDotName() != null
895+
? bi.getToolArgumentsErrorHandlerDotName().toString()
896+
: null);
897+
898+
String toolExecutionErrorHandlerDotName = (bi.getToolExecutionErrorHandlerDotName() != null
899+
? bi.getToolExecutionErrorHandlerDotName().toString()
900+
: null);
901+
795902
String chatMemoryProviderSupplierClassName = bi.getChatMemoryProviderSupplierClassDotName() != null
796903
? bi.getChatMemoryProviderSupplierClassDotName().toString()
797904
: null;
@@ -884,6 +991,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
884991
injectModerationModelBean,
885992
injectImageModel,
886993
toolHallucinationStrategyClassName,
994+
toolArgumentsErrorHandlerDotName,
995+
toolExecutionErrorHandlerDotName,
887996
classInputGuardrails(bi),
888997
classOutputGuardrails(bi),
889998
maxSequentialToolInvocations,
@@ -931,6 +1040,12 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
9311040
if (bi.getToolHallucinationStrategyClassDotName() != null) {
9321041
configurator.addInjectionPoint(ClassType.create(bi.getToolHallucinationStrategyClassDotName()));
9331042
}
1043+
if (bi.getToolArgumentsErrorHandlerDotName() != null) {
1044+
configurator.addInjectionPoint(ClassType.create(bi.getToolArgumentsErrorHandlerDotName()));
1045+
}
1046+
if (bi.getToolExecutionErrorHandlerDotName() != null) {
1047+
configurator.addInjectionPoint(ClassType.create(bi.getToolExecutionErrorHandlerDotName()));
1048+
}
9341049

9351050
if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
9361051
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MEMORY_PROVIDER));

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
3333
private final Optional<String> beanName;
3434
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
3535
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
36+
private final DotName toolArgumentsErrorHandlerDotName;
37+
private final DotName toolExecutionErrorHandlerDotName;
3638
private final Integer maxSequentialToolInvocations;
3739
private final boolean allowContinuousForcedToolCalling;
3840
private final boolean makeDefaultBean;
@@ -57,6 +59,8 @@ public DeclarativeAiServiceBuildItem(
5759
DotName toolHallucinationStrategyClassDotName,
5860
DeclarativeAiServiceInputGuardrails inputGuardrails,
5961
DeclarativeAiServiceOutputGuardrails outputGuardrails,
62+
DotName toolArgumentsErrorHandlerDotName,
63+
DotName toolExecutionErrorHandlerDotName,
6064
Integer maxSequentialToolInvocations,
6165
boolean allowContinuousForcedToolCalling,
6266
boolean makeDefaultBean) {
@@ -79,6 +83,8 @@ public DeclarativeAiServiceBuildItem(
7983
this.toolHallucinationStrategyClassDotName = toolHallucinationStrategyClassDotName;
8084
this.inputGuardrails = inputGuardrails;
8185
this.outputGuardrails = outputGuardrails;
86+
this.toolArgumentsErrorHandlerDotName = toolArgumentsErrorHandlerDotName;
87+
this.toolExecutionErrorHandlerDotName = toolExecutionErrorHandlerDotName;
8288
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
8389
this.allowContinuousForcedToolCalling = allowContinuousForcedToolCalling;
8490
this.makeDefaultBean = makeDefaultBean;
@@ -160,6 +166,14 @@ public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() {
160166
return outputGuardrails;
161167
}
162168

169+
public DotName getToolArgumentsErrorHandlerDotName() {
170+
return toolArgumentsErrorHandlerDotName;
171+
}
172+
173+
public DotName getToolExecutionErrorHandlerDotName() {
174+
return toolExecutionErrorHandlerDotName;
175+
}
176+
163177
public boolean isMakeDefaultBean() {
164178
return makeDefaultBean;
165179
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ public class DotNames {
8181

8282
public static final DotName EXECUTOR = DotName.createSimple(Executor.class);
8383

84+
public static final DotName THROWABLE = DotName.createSimple(Throwable.class);
85+
public static final DotName EXCEPTION = DotName.createSimple(Exception.class);
86+
8487
public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
8588
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
8689
public static final DotName TOOL = DotName.createSimple(Tool.class);

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import dev.langchain4j.data.message.AiMessage;
77
import dev.langchain4j.data.message.ChatMessage;
88
import dev.langchain4j.data.pdf.PdfFile;
9+
import dev.langchain4j.exception.ToolArgumentsException;
10+
import dev.langchain4j.exception.ToolExecutionException;
911
import dev.langchain4j.memory.ChatMemory;
1012
import dev.langchain4j.memory.chat.ChatMemoryProvider;
1113
import dev.langchain4j.model.chat.ChatModel;
@@ -29,11 +31,15 @@
2931
import dev.langchain4j.service.UserName;
3032
import dev.langchain4j.service.guardrail.InputGuardrails;
3133
import dev.langchain4j.service.guardrail.OutputGuardrails;
34+
import dev.langchain4j.service.tool.ToolErrorContext;
35+
import dev.langchain4j.service.tool.ToolErrorHandlerResult;
3236
import dev.langchain4j.service.tool.ToolProvider;
3337
import dev.langchain4j.web.search.WebSearchEngine;
3438
import dev.langchain4j.web.search.WebSearchTool;
3539
import io.quarkiverse.langchain4j.AudioUrl;
3640
import io.quarkiverse.langchain4j.CreatedAware;
41+
import io.quarkiverse.langchain4j.HandleToolArgumentError;
42+
import io.quarkiverse.langchain4j.HandleToolExecutionError;
3743
import io.quarkiverse.langchain4j.ImageUrl;
3844
import io.quarkiverse.langchain4j.ModelName;
3945
import io.quarkiverse.langchain4j.PdfUrl;
@@ -124,6 +130,9 @@ public class LangChain4jDotNames {
124130

125131
static final DotName SEED_MEMORY = DotName.createSimple(SeedMemory.class);
126132

133+
static final DotName HANDLE_TOOL_ARGUMENT_ERROR = DotName.createSimple(HandleToolArgumentError.class);
134+
static final DotName HANDLE_TOOL_EXECUTION_ERROR = DotName.createSimple(HandleToolExecutionError.class);
135+
127136
static final DotName WEB_SEARCH_TOOL = DotName.createSimple(WebSearchTool.class);
128137
static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class);
129138
static final DotName IMAGE = DotName.createSimple(Image.class);
@@ -136,4 +145,8 @@ public class LangChain4jDotNames {
136145
public static final DotName MCP_TOOLBOX = DotName.createSimple("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox");
137146
public static final DotName CHAT_EVENT = DotName.createSimple(ChatEvent.class);
138147
public static final DotName CHAT_MEMORY = DotName.createSimple(ChatMemory.class);
148+
public static final DotName TOOL_ERROR_HANDLER_RESULT = DotName.createSimple(ToolErrorHandlerResult.class);
149+
public static final DotName TOOL_ARGUMENTS_EXCEPTION = DotName.createSimple(ToolArgumentsException.class);
150+
public static final DotName TOOL_EXECUTION_EXCEPTION = DotName.createSimple(ToolExecutionException.class);
151+
public static final DotName TOOL_ERROR_CONTEXT = DotName.createSimple(ToolErrorContext.class);
139152
}

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dev.langchain4j.agent.tool.Tool;
2020
import dev.langchain4j.agent.tool.ToolExecutionRequest;
2121
import dev.langchain4j.agent.tool.ToolSpecification;
22+
import dev.langchain4j.exception.ToolArgumentsException;
2223
import dev.langchain4j.service.tool.ToolExecutor;
2324
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
2425
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
@@ -272,7 +273,7 @@ private void executeAndExpectFailure(String arguments, String methodName) {
272273
ToolExecutor toolExecutor = getToolExecutor(methodName);
273274

274275
assertThatThrownBy(() -> toolExecutor.execute(request, null))
275-
.isExactlyInstanceOf(IllegalArgumentException.class);
276+
.isExactlyInstanceOf(ToolArgumentsException.class);
276277
}
277278

278279
private ToolExecutor getToolExecutor(String methodName) {

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailEventLoopBlockingTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrailRequest;
4242
import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrailResult;
4343
import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrails;
44+
import io.quarkiverse.langchain4j.runtime.BlockingToolNotAllowedException;
4445
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
4546
import io.quarkiverse.langchain4j.test.Lists;
4647
import io.quarkus.arc.Arc;
@@ -130,7 +131,7 @@ void toolWithInputGuardrail_throwsException_whenCalledFromEventLoop() throws Exc
130131
Throwable exception = exceptionRef.get();
131132
assertThat(exception)
132133
.isNotNull()
133-
.isInstanceOf(ToolExecutionException.class);
134+
.isInstanceOf(BlockingToolNotAllowedException.class);
134135

135136
// Our ToolGuardrailsWrapper should catch the event loop and throw a clear error
136137
assertThat(exception.getMessage())
@@ -169,7 +170,7 @@ void toolWithOutputGuardrail_throwsException_whenCalledFromEventLoop() throws Ex
169170
Throwable exception = exceptionRef.get();
170171
assertThat(exception)
171172
.isNotNull()
172-
.isInstanceOf(ToolExecutionException.class);
173+
.isInstanceOf(BlockingToolNotAllowedException.class);
173174

174175
assertThat(exception.getMessage())
175176
.contains("Cannot execute guardrails")

0 commit comments

Comments
 (0)