|
45 | 45 | import jakarta.enterprise.inject.spi.DeploymentException; |
46 | 46 | import jakarta.enterprise.util.AnnotationLiteral; |
47 | 47 | import jakarta.inject.Inject; |
| 48 | +import jakarta.inject.Singleton; |
48 | 49 | import jakarta.interceptor.InterceptorBinding; |
49 | 50 |
|
50 | 51 | import org.eclipse.microprofile.config.ConfigProvider; |
|
77 | 78 | import dev.langchain4j.service.memory.ChatMemoryAccess; |
78 | 79 | import dev.langchain4j.service.output.JsonSchemas; |
79 | 80 | import dev.langchain4j.service.output.ServiceOutputParser; |
| 81 | +import dev.langchain4j.service.tool.ToolArgumentsErrorHandler; |
| 82 | +import dev.langchain4j.service.tool.ToolErrorContext; |
| 83 | +import dev.langchain4j.service.tool.ToolErrorHandlerResult; |
| 84 | +import dev.langchain4j.service.tool.ToolExecutionErrorHandler; |
80 | 85 | import dev.langchain4j.spi.classloading.ClassInstanceFactory; |
81 | 86 | import dev.langchain4j.spi.classloading.ClassMetadataProviderFactory; |
82 | 87 | import io.quarkiverse.langchain4j.ModelName; |
@@ -341,6 +346,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, |
341 | 346 | BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer, |
342 | 347 | BuildProducer<ToolProviderMetaBuildItem> toolProviderProducer, |
343 | 348 | BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer, |
| 349 | + BuildProducer<GeneratedBeanBuildItem> generatedBeanProducer, |
344 | 350 | BuildProducer<GeneratedClassBuildItem> generatedClassProducer) { |
345 | 351 | IndexView index = indexBuildItem.getIndex(); |
346 | 352 |
|
@@ -496,6 +502,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, |
496 | 502 | toolHallucinationStrategy(instance), |
497 | 503 | classInputGuardrails(declarativeAiServiceClassInfo, index), |
498 | 504 | classOutputGuardrails(declarativeAiServiceClassInfo, index), |
| 505 | + toolArgumentsErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer), |
| 506 | + toolExecutionErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer), |
499 | 507 | maxSequentialToolInvocations, |
500 | 508 | allowContinuousForcedToolCalling, |
501 | 509 | // we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence |
@@ -647,6 +655,97 @@ private static OutputGuardrailsLiteral classOutputGuardrails(DeclarativeAiServic |
647 | 655 | declarativeAiServiceBuildItem.getOutputGuardrails().maxRetries()); |
648 | 656 | } |
649 | 657 |
|
| 658 | + private DotName toolArgumentsErrorHandlerDotName(ClassInfo aiServiceClassInfo, |
| 659 | + BuildProducer<GeneratedBeanBuildItem> generatedBean) { |
| 660 | + return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_ARGUMENT_ERROR, generatedBean, |
| 661 | + ToolArgumentsErrorHandler.class); |
| 662 | + } |
| 663 | + |
| 664 | + private DotName toolExecutionErrorHandlerDotName(ClassInfo aiServiceClassInfo, |
| 665 | + BuildProducer<GeneratedBeanBuildItem> generatedBean) { |
| 666 | + return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_EXECUTION_ERROR, generatedBean, |
| 667 | + ToolExecutionErrorHandler.class); |
| 668 | + } |
| 669 | + |
| 670 | + private DotName toolErrorHandlerDotName(ClassInfo aiServiceClassInfo, DotName annotationName, |
| 671 | + BuildProducer<GeneratedBeanBuildItem> generatedBean, |
| 672 | + Class<?> interfaceType) { |
| 673 | + List<AnnotationInstance> instances = aiServiceClassInfo.annotations(annotationName); |
| 674 | + if (instances.isEmpty()) { |
| 675 | + return null; |
| 676 | + } |
| 677 | + if (instances.size() > 1) { |
| 678 | + throw new IllegalConfigurationException( |
| 679 | + "`@%s` can be used only once in an AI Service. Offending class is '%s'".formatted(annotationName, |
| 680 | + aiServiceClassInfo.name())); |
| 681 | + } |
| 682 | + AnnotationTarget target = instances.get(0).target(); |
| 683 | + if (target.kind() != AnnotationTarget.Kind.METHOD) { |
| 684 | + throw new IllegalConfigurationException( |
| 685 | + "`@%s` can be used only methods. Offending class is '%s'".formatted(annotationName, |
| 686 | + aiServiceClassInfo.name())); |
| 687 | + } |
| 688 | + MethodInfo targetMethod = target.asMethod(); |
| 689 | + if (!Modifier.isStatic(targetMethod.flags())) { |
| 690 | + throw new IllegalConfigurationException( |
| 691 | + "`@%s` can be used only on static methods. Offending class is '%s'".formatted(annotationName, |
| 692 | + aiServiceClassInfo.name())); |
| 693 | + } |
| 694 | + DotName returnType = targetMethod.returnType().name(); |
| 695 | + if ((!returnType.equals(DotNames.STRING)) && !returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) { |
| 696 | + throw new IllegalConfigurationException( |
| 697 | + "`@%s` can be used only on static methods that return '%s' or '%s'. Offending class is '%s'" |
| 698 | + .formatted(annotationName, |
| 699 | + DotNames.STRING, LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT, aiServiceClassInfo.name())); |
| 700 | + } |
| 701 | + |
| 702 | + ClassOutput output = new GeneratedBeanGizmoAdaptor(generatedBean); |
| 703 | + String generatedClassName = aiServiceClassInfo.name().toString() + "$" + annotationName.withoutPackagePrefix(); |
| 704 | + ClassCreator.Builder classCreatorBuilder = ClassCreator.builder() |
| 705 | + .classOutput(output) |
| 706 | + .interfaces(interfaceType) |
| 707 | + .className(generatedClassName); |
| 708 | + try (ClassCreator classCreator = classCreatorBuilder.build()) { |
| 709 | + classCreator.addAnnotation(Singleton.class); |
| 710 | + |
| 711 | + MethodCreator handleMethod = classCreator.getMethodCreator(MethodDescriptor.ofMethod(generatedClassName, "handle", |
| 712 | + ToolErrorHandlerResult.class, Throwable.class, ToolErrorContext.class)); |
| 713 | + |
| 714 | + List<ResultHandle> paramHandles = new ArrayList<>(); |
| 715 | + for (MethodParameterInfo parameter : targetMethod.parameters()) { |
| 716 | + DotName paramTypeDotName = parameter.type().name(); |
| 717 | + if (paramTypeDotName.equals(DotNames.THROWABLE) || paramTypeDotName.equals(DotNames.EXCEPTION)) { |
| 718 | + paramHandles.add(handleMethod.getMethodParam(0)); |
| 719 | + } else if (paramTypeDotName.equals(LangChain4jDotNames.TOOL_ERROR_CONTEXT)) { |
| 720 | + paramHandles.add(handleMethod.getMethodParam(1)); |
| 721 | + } else { |
| 722 | + throw new IllegalConfigurationException( |
| 723 | + "`@%s` can be used only on static methods that use the parameters of type '%s' or '%s'. Offending class is '%s'" |
| 724 | + .formatted(annotationName, |
| 725 | + DotNames.THROWABLE, LangChain4jDotNames.TOOL_ERROR_CONTEXT, |
| 726 | + aiServiceClassInfo.name())); |
| 727 | + } |
| 728 | + } |
| 729 | + |
| 730 | + ResultHandle result = handleMethod.invokeStaticInterfaceMethod(targetMethod, |
| 731 | + paramHandles.toArray(new ResultHandle[0])); |
| 732 | + if (returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) { |
| 733 | + handleMethod.returnValue(result); |
| 734 | + } else if (returnType.equals(DotNames.STRING)) { |
| 735 | + ResultHandle toolErrorHandlerResultResult = handleMethod.invokeStaticMethod( |
| 736 | + MethodDescriptor.ofMethod(ToolErrorHandlerResult.class, "text", ToolErrorHandlerResult.class, |
| 737 | + String.class), |
| 738 | + result); |
| 739 | + handleMethod.returnValue(toolErrorHandlerResultResult); |
| 740 | + } else { |
| 741 | + throw new IllegalStateException("Unhandled result type: " + returnType); |
| 742 | + } |
| 743 | + |
| 744 | + } |
| 745 | + |
| 746 | + return DotName.createSimple(generatedClassName); |
| 747 | + } |
| 748 | + |
650 | 749 | private static List<ClassInfo> tools(AnnotationInstance instance, IndexView index) { |
651 | 750 | AnnotationValue toolsInstance = instance.value("tools"); |
652 | 751 | if (toolsInstance != null) { |
@@ -792,6 +891,14 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, |
792 | 891 | allToolHallucinationStrategies.add(bi.getToolHallucinationStrategyClassDotName()); |
793 | 892 | } |
794 | 893 |
|
| 894 | + String toolArgumentsErrorHandlerDotName = (bi.getToolArgumentsErrorHandlerDotName() != null |
| 895 | + ? bi.getToolArgumentsErrorHandlerDotName().toString() |
| 896 | + : null); |
| 897 | + |
| 898 | + String toolExecutionErrorHandlerDotName = (bi.getToolExecutionErrorHandlerDotName() != null |
| 899 | + ? bi.getToolExecutionErrorHandlerDotName().toString() |
| 900 | + : null); |
| 901 | + |
795 | 902 | String chatMemoryProviderSupplierClassName = bi.getChatMemoryProviderSupplierClassDotName() != null |
796 | 903 | ? bi.getChatMemoryProviderSupplierClassDotName().toString() |
797 | 904 | : null; |
@@ -884,6 +991,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, |
884 | 991 | injectModerationModelBean, |
885 | 992 | injectImageModel, |
886 | 993 | toolHallucinationStrategyClassName, |
| 994 | + toolArgumentsErrorHandlerDotName, |
| 995 | + toolExecutionErrorHandlerDotName, |
887 | 996 | classInputGuardrails(bi), |
888 | 997 | classOutputGuardrails(bi), |
889 | 998 | maxSequentialToolInvocations, |
@@ -931,6 +1040,12 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, |
931 | 1040 | if (bi.getToolHallucinationStrategyClassDotName() != null) { |
932 | 1041 | configurator.addInjectionPoint(ClassType.create(bi.getToolHallucinationStrategyClassDotName())); |
933 | 1042 | } |
| 1043 | + if (bi.getToolArgumentsErrorHandlerDotName() != null) { |
| 1044 | + configurator.addInjectionPoint(ClassType.create(bi.getToolArgumentsErrorHandlerDotName())); |
| 1045 | + } |
| 1046 | + if (bi.getToolExecutionErrorHandlerDotName() != null) { |
| 1047 | + configurator.addInjectionPoint(ClassType.create(bi.getToolExecutionErrorHandlerDotName())); |
| 1048 | + } |
934 | 1049 |
|
935 | 1050 | if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) { |
936 | 1051 | configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MEMORY_PROVIDER)); |
|
0 commit comments