Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@
<artifactId>quarkus-test-vertx</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${langchain4j.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.enterprise.util.AnnotationLiteral;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.interceptor.InterceptorBinding;

import org.eclipse.microprofile.config.ConfigProvider;
Expand Down Expand Up @@ -77,6 +78,10 @@
import dev.langchain4j.service.memory.ChatMemoryAccess;
import dev.langchain4j.service.output.JsonSchemas;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolArgumentsErrorHandler;
import dev.langchain4j.service.tool.ToolErrorContext;
import dev.langchain4j.service.tool.ToolErrorHandlerResult;
import dev.langchain4j.service.tool.ToolExecutionErrorHandler;
import dev.langchain4j.spi.classloading.ClassInstanceFactory;
import dev.langchain4j.spi.classloading.ClassMetadataProviderFactory;
import io.quarkiverse.langchain4j.ModelName;
Expand Down Expand Up @@ -341,6 +346,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
BuildProducer<ToolProviderMetaBuildItem> toolProviderProducer,
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
BuildProducer<GeneratedBeanBuildItem> generatedBeanProducer,
BuildProducer<GeneratedClassBuildItem> generatedClassProducer) {
IndexView index = indexBuildItem.getIndex();

Expand Down Expand Up @@ -496,6 +502,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
toolHallucinationStrategy(instance),
classInputGuardrails(declarativeAiServiceClassInfo, index),
classOutputGuardrails(declarativeAiServiceClassInfo, index),
toolArgumentsErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer),
toolExecutionErrorHandlerDotName(declarativeAiServiceClassInfo, generatedBeanProducer),
maxSequentialToolInvocations,
allowContinuousForcedToolCalling,
// we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence
Expand Down Expand Up @@ -647,6 +655,97 @@ private static OutputGuardrailsLiteral classOutputGuardrails(DeclarativeAiServic
declarativeAiServiceBuildItem.getOutputGuardrails().maxRetries());
}

private DotName toolArgumentsErrorHandlerDotName(ClassInfo aiServiceClassInfo,
BuildProducer<GeneratedBeanBuildItem> generatedBean) {
return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_ARGUMENT_ERROR, generatedBean,
ToolArgumentsErrorHandler.class);
}

private DotName toolExecutionErrorHandlerDotName(ClassInfo aiServiceClassInfo,
BuildProducer<GeneratedBeanBuildItem> generatedBean) {
return toolErrorHandlerDotName(aiServiceClassInfo, LangChain4jDotNames.HANDLE_TOOL_EXECUTION_ERROR, generatedBean,
ToolExecutionErrorHandler.class);
}

private DotName toolErrorHandlerDotName(ClassInfo aiServiceClassInfo, DotName annotationName,
BuildProducer<GeneratedBeanBuildItem> generatedBean,
Class<?> interfaceType) {
List<AnnotationInstance> instances = aiServiceClassInfo.annotations(annotationName);
if (instances.isEmpty()) {
return null;
}
if (instances.size() > 1) {
throw new IllegalConfigurationException(
"`@%s` can be used only once in an AI Service. Offending class is '%s'".formatted(annotationName,
aiServiceClassInfo.name()));
}
AnnotationTarget target = instances.get(0).target();
if (target.kind() != AnnotationTarget.Kind.METHOD) {
throw new IllegalConfigurationException(
"`@%s` can be used only methods. Offending class is '%s'".formatted(annotationName,
aiServiceClassInfo.name()));
}
MethodInfo targetMethod = target.asMethod();
if (!Modifier.isStatic(targetMethod.flags())) {
throw new IllegalConfigurationException(
"`@%s` can be used only on static methods. Offending class is '%s'".formatted(annotationName,
aiServiceClassInfo.name()));
}
DotName returnType = targetMethod.returnType().name();
if ((!returnType.equals(DotNames.STRING)) && !returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) {
throw new IllegalConfigurationException(
"`@%s` can be used only on static methods that return '%s' or '%s'. Offending class is '%s'"
.formatted(annotationName,
DotNames.STRING, LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT, aiServiceClassInfo.name()));
}

ClassOutput output = new GeneratedBeanGizmoAdaptor(generatedBean);
String generatedClassName = aiServiceClassInfo.name().toString() + "$" + annotationName.withoutPackagePrefix();
ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(output)
.interfaces(interfaceType)
.className(generatedClassName);
try (ClassCreator classCreator = classCreatorBuilder.build()) {
classCreator.addAnnotation(Singleton.class);

MethodCreator handleMethod = classCreator.getMethodCreator(MethodDescriptor.ofMethod(generatedClassName, "handle",
ToolErrorHandlerResult.class, Throwable.class, ToolErrorContext.class));

List<ResultHandle> paramHandles = new ArrayList<>();
for (MethodParameterInfo parameter : targetMethod.parameters()) {
DotName paramTypeDotName = parameter.type().name();
if (paramTypeDotName.equals(DotNames.THROWABLE) || paramTypeDotName.equals(DotNames.EXCEPTION)) {
paramHandles.add(handleMethod.getMethodParam(0));
} else if (paramTypeDotName.equals(LangChain4jDotNames.TOOL_ERROR_CONTEXT)) {
paramHandles.add(handleMethod.getMethodParam(1));
} else {
throw new IllegalConfigurationException(
"`@%s` can be used only on static methods that use the parameters of type '%s' or '%s'. Offending class is '%s'"
.formatted(annotationName,
DotNames.THROWABLE, LangChain4jDotNames.TOOL_ERROR_CONTEXT,
aiServiceClassInfo.name()));
}
}

ResultHandle result = handleMethod.invokeStaticInterfaceMethod(targetMethod,
paramHandles.toArray(new ResultHandle[0]));
if (returnType.equals(LangChain4jDotNames.TOOL_ERROR_HANDLER_RESULT)) {
handleMethod.returnValue(result);
} else if (returnType.equals(DotNames.STRING)) {
ResultHandle toolErrorHandlerResultResult = handleMethod.invokeStaticMethod(
MethodDescriptor.ofMethod(ToolErrorHandlerResult.class, "text", ToolErrorHandlerResult.class,
String.class),
result);
handleMethod.returnValue(toolErrorHandlerResultResult);
} else {
throw new IllegalStateException("Unhandled result type: " + returnType);
}

}

return DotName.createSimple(generatedClassName);
}

private static List<ClassInfo> tools(AnnotationInstance instance, IndexView index) {
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
Expand Down Expand Up @@ -792,6 +891,14 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
allToolHallucinationStrategies.add(bi.getToolHallucinationStrategyClassDotName());
}

String toolArgumentsErrorHandlerDotName = (bi.getToolArgumentsErrorHandlerDotName() != null
? bi.getToolArgumentsErrorHandlerDotName().toString()
: null);

String toolExecutionErrorHandlerDotName = (bi.getToolExecutionErrorHandlerDotName() != null
? bi.getToolExecutionErrorHandlerDotName().toString()
: null);

String chatMemoryProviderSupplierClassName = bi.getChatMemoryProviderSupplierClassDotName() != null
? bi.getChatMemoryProviderSupplierClassDotName().toString()
: null;
Expand Down Expand Up @@ -884,6 +991,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
injectModerationModelBean,
injectImageModel,
toolHallucinationStrategyClassName,
toolArgumentsErrorHandlerDotName,
toolExecutionErrorHandlerDotName,
classInputGuardrails(bi),
classOutputGuardrails(bi),
maxSequentialToolInvocations,
Expand Down Expand Up @@ -931,6 +1040,12 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (bi.getToolHallucinationStrategyClassDotName() != null) {
configurator.addInjectionPoint(ClassType.create(bi.getToolHallucinationStrategyClassDotName()));
}
if (bi.getToolArgumentsErrorHandlerDotName() != null) {
configurator.addInjectionPoint(ClassType.create(bi.getToolArgumentsErrorHandlerDotName()));
}
if (bi.getToolExecutionErrorHandlerDotName() != null) {
configurator.addInjectionPoint(ClassType.create(bi.getToolExecutionErrorHandlerDotName()));
}

if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MEMORY_PROVIDER));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final Optional<String> beanName;
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
private final DotName toolArgumentsErrorHandlerDotName;
private final DotName toolExecutionErrorHandlerDotName;
private final Integer maxSequentialToolInvocations;
private final boolean allowContinuousForcedToolCalling;
private final boolean makeDefaultBean;
Expand All @@ -57,6 +59,8 @@ public DeclarativeAiServiceBuildItem(
DotName toolHallucinationStrategyClassDotName,
DeclarativeAiServiceInputGuardrails inputGuardrails,
DeclarativeAiServiceOutputGuardrails outputGuardrails,
DotName toolArgumentsErrorHandlerDotName,
DotName toolExecutionErrorHandlerDotName,
Integer maxSequentialToolInvocations,
boolean allowContinuousForcedToolCalling,
boolean makeDefaultBean) {
Expand All @@ -79,6 +83,8 @@ public DeclarativeAiServiceBuildItem(
this.toolHallucinationStrategyClassDotName = toolHallucinationStrategyClassDotName;
this.inputGuardrails = inputGuardrails;
this.outputGuardrails = outputGuardrails;
this.toolArgumentsErrorHandlerDotName = toolArgumentsErrorHandlerDotName;
this.toolExecutionErrorHandlerDotName = toolExecutionErrorHandlerDotName;
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
this.allowContinuousForcedToolCalling = allowContinuousForcedToolCalling;
this.makeDefaultBean = makeDefaultBean;
Expand Down Expand Up @@ -160,6 +166,14 @@ public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() {
return outputGuardrails;
}

public DotName getToolArgumentsErrorHandlerDotName() {
return toolArgumentsErrorHandlerDotName;
}

public DotName getToolExecutionErrorHandlerDotName() {
return toolExecutionErrorHandlerDotName;
}

public boolean isMakeDefaultBean() {
return makeDefaultBean;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ public class DotNames {

public static final DotName EXECUTOR = DotName.createSimple(Executor.class);

public static final DotName THROWABLE = DotName.createSimple(Throwable.class);
public static final DotName EXCEPTION = DotName.createSimple(Exception.class);

public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
public static final DotName TOOL = DotName.createSimple(Tool.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.pdf.PdfFile;
import dev.langchain4j.exception.ToolArgumentsException;
import dev.langchain4j.exception.ToolExecutionException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatModel;
Expand All @@ -29,11 +31,15 @@
import dev.langchain4j.service.UserName;
import dev.langchain4j.service.guardrail.InputGuardrails;
import dev.langchain4j.service.guardrail.OutputGuardrails;
import dev.langchain4j.service.tool.ToolErrorContext;
import dev.langchain4j.service.tool.ToolErrorHandlerResult;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.web.search.WebSearchEngine;
import dev.langchain4j.web.search.WebSearchTool;
import io.quarkiverse.langchain4j.AudioUrl;
import io.quarkiverse.langchain4j.CreatedAware;
import io.quarkiverse.langchain4j.HandleToolArgumentError;
import io.quarkiverse.langchain4j.HandleToolExecutionError;
import io.quarkiverse.langchain4j.ImageUrl;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.PdfUrl;
Expand Down Expand Up @@ -124,6 +130,9 @@ public class LangChain4jDotNames {

static final DotName SEED_MEMORY = DotName.createSimple(SeedMemory.class);

static final DotName HANDLE_TOOL_ARGUMENT_ERROR = DotName.createSimple(HandleToolArgumentError.class);
static final DotName HANDLE_TOOL_EXECUTION_ERROR = DotName.createSimple(HandleToolExecutionError.class);

static final DotName WEB_SEARCH_TOOL = DotName.createSimple(WebSearchTool.class);
static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class);
static final DotName IMAGE = DotName.createSimple(Image.class);
Expand All @@ -136,4 +145,8 @@ public class LangChain4jDotNames {
public static final DotName MCP_TOOLBOX = DotName.createSimple("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox");
public static final DotName CHAT_EVENT = DotName.createSimple(ChatEvent.class);
public static final DotName CHAT_MEMORY = DotName.createSimple(ChatMemory.class);
public static final DotName TOOL_ERROR_HANDLER_RESULT = DotName.createSimple(ToolErrorHandlerResult.class);
public static final DotName TOOL_ARGUMENTS_EXCEPTION = DotName.createSimple(ToolArgumentsException.class);
public static final DotName TOOL_EXECUTION_EXCEPTION = DotName.createSimple(ToolExecutionException.class);
public static final DotName TOOL_ERROR_CONTEXT = DotName.createSimple(ToolErrorContext.class);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.exception.ToolArgumentsException;
import dev.langchain4j.service.tool.ToolExecutor;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
Expand Down Expand Up @@ -272,7 +273,7 @@ private void executeAndExpectFailure(String arguments, String methodName) {
ToolExecutor toolExecutor = getToolExecutor(methodName);

assertThatThrownBy(() -> toolExecutor.execute(request, null))
.isExactlyInstanceOf(IllegalArgumentException.class);
.isExactlyInstanceOf(ToolArgumentsException.class);
}

private ToolExecutor getToolExecutor(String methodName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrailRequest;
import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.ToolOutputGuardrails;
import io.quarkiverse.langchain4j.runtime.BlockingToolNotAllowedException;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkiverse.langchain4j.test.Lists;
import io.quarkus.arc.Arc;
Expand Down Expand Up @@ -130,7 +131,7 @@ void toolWithInputGuardrail_throwsException_whenCalledFromEventLoop() throws Exc
Throwable exception = exceptionRef.get();
assertThat(exception)
.isNotNull()
.isInstanceOf(ToolExecutionException.class);
.isInstanceOf(BlockingToolNotAllowedException.class);

// Our ToolGuardrailsWrapper should catch the event loop and throw a clear error
assertThat(exception.getMessage())
Expand Down Expand Up @@ -169,7 +170,7 @@ void toolWithOutputGuardrail_throwsException_whenCalledFromEventLoop() throws Ex
Throwable exception = exceptionRef.get();
assertThat(exception)
.isNotNull()
.isInstanceOf(ToolExecutionException.class);
.isInstanceOf(BlockingToolNotAllowedException.class);

assertThat(exception.getMessage())
.contains("Cannot execute guardrails")
Expand Down
Loading
Loading