Skip to content

Commit dc40f6f

Browse files
authored
Merge pull request #2013 from quarkiverse/#2009
Introduce `allowContinuousForcedToolCalling` into @RegisterAiService
2 parents 0446320 + 6519ffb commit dc40f6f

File tree

8 files changed

+54
-10
lines changed

8 files changed

+54
-10
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
470470
Integer maxSequentialToolInvocations = instance.value("maxSequentialToolInvocations") != null
471471
? instance.value("maxSequentialToolInvocations").asInt()
472472
: 0;
473+
474+
boolean allowContinuousForcedToolCalling = instance.value("allowContinuousForcedToolCalling") != null
475+
? instance.value("allowContinuousForcedToolCalling").asBoolean()
476+
: false;
477+
473478
declarativeAiServiceProducer.produce(
474479
new DeclarativeAiServiceBuildItem(
475480
declarativeAiServiceClassInfo,
@@ -492,6 +497,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
492497
classInputGuardrails(declarativeAiServiceClassInfo, index),
493498
classOutputGuardrails(declarativeAiServiceClassInfo, index),
494499
maxSequentialToolInvocations,
500+
allowContinuousForcedToolCalling,
495501
// we need to make these @DefaultBean because there could be other CDI beans of the same type that need to take precedence
496502
impliedRegisterAiServiceTarget.contains(declarativeAiServiceClassInfo.name())));
497503

@@ -752,6 +758,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
752758
ClassInfo declarativeAiServiceClassInfo = bi.getServiceClassInfo();
753759
String serviceClassName = declarativeAiServiceClassInfo.name().toString();
754760
Integer maxSequentialToolInvocations = bi.getMaxSequentialToolInvocations();
761+
boolean allowContinuousForcedToolCalling = bi.isAllowContinuousForcedToolCalling();
755762

756763
String chatLanguageModelSupplierClassName = (bi.getChatLanguageModelSupplierClassDotName() != null
757764
? bi.getChatLanguageModelSupplierClassDotName().toString()
@@ -879,7 +886,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
879886
toolHallucinationStrategyClassName,
880887
classInputGuardrails(bi),
881888
classOutputGuardrails(bi),
882-
maxSequentialToolInvocations)))
889+
maxSequentialToolInvocations,
890+
allowContinuousForcedToolCalling)))
883891
.setRuntimeInit()
884892
.addQualifier()
885893
.annotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName)

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
3434
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
3535
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
3636
private final Integer maxSequentialToolInvocations;
37+
private final boolean allowContinuousForcedToolCalling;
3738
private final boolean makeDefaultBean;
3839

3940
public DeclarativeAiServiceBuildItem(
@@ -56,7 +57,9 @@ public DeclarativeAiServiceBuildItem(
5657
DotName toolHallucinationStrategyClassDotName,
5758
DeclarativeAiServiceInputGuardrails inputGuardrails,
5859
DeclarativeAiServiceOutputGuardrails outputGuardrails,
59-
Integer maxSequentialToolInvocations, boolean makeDefaultBean) {
60+
Integer maxSequentialToolInvocations,
61+
boolean allowContinuousForcedToolCalling,
62+
boolean makeDefaultBean) {
6063
this.serviceClassInfo = serviceClassInfo;
6164
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
6265
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
@@ -77,6 +80,7 @@ public DeclarativeAiServiceBuildItem(
7780
this.inputGuardrails = inputGuardrails;
7881
this.outputGuardrails = outputGuardrails;
7982
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
83+
this.allowContinuousForcedToolCalling = allowContinuousForcedToolCalling;
8084
this.makeDefaultBean = makeDefaultBean;
8185
}
8286

@@ -184,4 +188,8 @@ public List<String> asClassNames() {
184188
public Integer getMaxSequentialToolInvocations() {
185189
return maxSequentialToolInvocations;
186190
}
191+
192+
public boolean isAllowContinuousForcedToolCalling() {
193+
return allowContinuousForcedToolCalling;
194+
}
187195
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServicesFactory.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ public AiServices<T> maxSequentialToolInvocations(Integer maxSequentialToolInvoc
6868
return this;
6969
}
7070

71+
public AiServices<T> allowContinuousForcedToolCalling(boolean allowContinuousForcedToolCalling) {
72+
quarkusAiServiceContext().allowContinuousForcedToolCalling = allowContinuousForcedToolCalling;
73+
return this;
74+
}
75+
7176
@SuppressWarnings("unchecked")
7277
@Override
7378
public T build() {

core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
1717
import dev.langchain4j.model.chat.ChatModel;
1818
import dev.langchain4j.model.chat.StreamingChatModel;
19+
import dev.langchain4j.model.chat.request.ToolChoice;
1920
import dev.langchain4j.model.image.ImageModel;
2021
import dev.langchain4j.model.moderation.ModerationModel;
2122
import dev.langchain4j.rag.RetrievalAugmentor;
@@ -31,7 +32,7 @@
3132
* Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by
3233
* using the class as a CDI bean.
3334
* Under the hood LangChain4j's {@link AiServices#builder(Class)} is called
34-
* while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional),
35+
* while also providing the builder with the proper {@link ChatModel} bean (mandatory), {@code tools} bean (optional),
3536
* {@link ChatMemoryProvider} and {@link ContentRetriever} beans (which by default are configured if such beans exist).
3637
* <p>
3738
* NOTE: The resulting CDI bean is {@link jakarta.enterprise.context.RequestScoped} by default. If you need to change the scope,
@@ -48,7 +49,7 @@
4849
public @interface RegisterAiService {
4950

5051
/**
51-
* Configures the way to obtain the {@link StreamingChatLanguageModel} to use.
52+
* Configures the way to obtain the {@link StreamingChatModel} to use.
5253
* If not configured, the default CDI bean implementing the model is looked up.
5354
* Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai},
5455
* {@code quarkus-langchain4j-azure-openai} or
@@ -57,7 +58,7 @@
5758
Class<? extends Supplier<StreamingChatModel>> streamingChatLanguageModelSupplier() default BeanStreamingChatLanguageModelSupplier.class;
5859

5960
/**
60-
* Configures the way to obtain the {@link ChatLanguageModel} to use.
61+
* Configures the way to obtain the {@link ChatModel} to use.
6162
* If not configured, the default CDI bean implementing the model is looked up.
6263
* Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai},
6364
* {@code quarkus-langchain4j-azure-openai} or
@@ -67,7 +68,7 @@
6768

6869
/**
6970
* When {@code chatLanguageModelSupplier} is set to {@code BeanChatLanguageModelSupplier.class} (which is the default)
70-
* this allows the selection of the {@link ChatLanguageModel} CDI bean to use.
71+
* this allows the selection of the {@link ChatModel} CDI bean to use.
7172
* <p>
7273
* If not set, the default model (i.e. the one configured without setting the model name) is used.
7374
* An example of the default model configuration is the following:
@@ -148,7 +149,20 @@
148149
Class<? extends Supplier<ToolProvider>> toolProviderSupplier() default BeanIfExistsToolProviderSupplier.class;
149150

150151
/**
151-
* Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by
152+
* By default, after first tool call execution, in subsequent prompts the {@code toolChoice} of
153+
* {@link dev.langchain4j.model.chat.request.ChatRequestParameters}
154+
* is set to {@link ToolChoice#AUTO}.
155+
* By enabling this option {@link ToolChoice#AUTO} will not be set and instead whatever value was used in the initial prompt
156+
* will
157+
* continue to be used.
158+
* <p>
159+
* BEWARE: This is dangerous as it can result in an infinite-loop when using the AiService in combination with the
160+
* {@code toolChoice} option set to {@link ToolChoice#REQUIRED}.
161+
*/
162+
boolean allowContinuousForcedToolCalling() default false;
163+
164+
/**
165+
* Marker that is used to tell Quarkus to use the {@link ChatModel} that has been configured as a CDI bean by
152166
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
153167
* {@code quarkus-langchain4j-hugging-face}).
154168
*/
@@ -161,7 +175,7 @@ public ChatModel get() {
161175
}
162176

163177
/**
164-
* Marker that is used to tell Quarkus to use the {@link StreamingChatLanguageModel} that has been configured as a CDI bean
178+
* Marker that is used to tell Quarkus to use the {@link StreamingChatModel} that has been configured as a CDI bean
165179
* by * any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
166180
* {@code quarkus-langchain4j-hugging-face}).
167181
*/

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ public QuarkusAiServiceContext apply(SyntheticCreationalContext<QuarkusAiService
312312
quarkusAiServices.maxSequentialToolInvocations(info.maxSequentialToolInvocations());
313313
}
314314

315+
if (info.maxSequentialToolInvocations() != null && info.maxSequentialToolInvocations() > 0) {
316+
quarkusAiServices.maxSequentialToolInvocations(info.maxSequentialToolInvocations());
317+
}
318+
319+
quarkusAiServices.allowContinuousForcedToolCalling(info.allowContinuousForcedToolCalling());
320+
315321
return aiServiceContext;
316322
} catch (ClassNotFoundException e) {
317323
throw new IllegalStateException(e);

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,8 @@ private static Object doImplement0(AiServiceMethodCreateInfo methodCreateInfo, I
503503

504504
if (nonNull(context.chatModel.defaultRequestParameters())) {
505505
var toolChoice = context.chatModel.defaultRequestParameters().toolChoice();
506-
if (nonNull(toolChoice) && toolChoice.equals(ToolChoice.REQUIRED)) {
506+
if (nonNull(toolChoice) && toolChoice.equals(ToolChoice.REQUIRED)
507+
&& !context.allowContinuousForcedToolCalling) {
507508
// This code is needed to avoid a infinite-loop when using the AiService
508509
// in combination with the tool-choice option set to REQUIRED.
509510
// If the tool-choice option is not set to AUTO after calling the tool,

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@ public record DeclarativeAiServiceCreateInfo(
2727
String toolHallucinationStrategyClassName,
2828
InputGuardrailsLiteral inputGuardrails,
2929
OutputGuardrailsLiteral outputGuardrails,
30-
Integer maxSequentialToolInvocations) {
30+
Integer maxSequentialToolInvocations,
31+
boolean allowContinuousForcedToolCalling) {
3132
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public class QuarkusAiServiceContext extends AiServiceContext {
2222
public ChatMemorySeeder chatMemorySeeder;
2323
public ImageModel imageModel;
2424
public Integer maxSequentialToolExecutions;
25+
public boolean allowContinuousForcedToolCalling;
2526

2627
// needed by Arc
2728
public QuarkusAiServiceContext() {

0 commit comments

Comments
 (0)