Skip to content

Commit f9d4050

Browse files
authored
Merge pull request #1993 from cescoffier/tool-guradrail-metrics
Add observability and metrics support for tool guardrails
2 parents 0d76e40 + 46b2c50 commit f9d4050

29 files changed

+1916
-69
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessor.java

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package io.quarkiverse.langchain4j.deployment;
22

3-
import static io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.*;
3+
import static io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.MICROMETER_COUNTED;
4+
import static io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.MICROMETER_TIMED;
5+
import static io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.TransformType;
6+
import static io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.WITH_SPAN;
7+
import static io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.shouldTransformMethod;
48

59
import java.util.Optional;
610
import java.util.function.Consumer;
@@ -16,6 +20,8 @@
1620

1721
import dev.langchain4j.observability.api.event.InputGuardrailExecutedEvent;
1822
import dev.langchain4j.observability.api.event.OutputGuardrailExecutedEvent;
23+
import io.quarkiverse.langchain4j.runtime.observability.ToolInputGuardrailExecutedEvent;
24+
import io.quarkiverse.langchain4j.runtime.observability.ToolOutputGuardrailExecutedEvent;
1925
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
2026
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
2127
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
@@ -81,20 +87,43 @@ void addMetricObserver(
8187
MethodCreator onInputGuardrailExecuted = classCreator.getMethodCreator("onInputGuardrailExecuted", "V",
8288
InputGuardrailExecutedEvent.class);
8389
onInputGuardrailExecuted.getParameterAnnotations(0).addAnnotation(Observes.class);
84-
var support1 = MethodDescriptor.ofMethod(
90+
var inputGuardrailObserverMethod = MethodDescriptor.ofMethod(
8591
GUARDRAIL_METRICS_OBSERVER_SUPPORT_CLASS,
8692
"onInputGuardrailExecuted", "V", InputGuardrailExecutedEvent.class);
87-
onInputGuardrailExecuted.invokeStaticMethod(support1, onInputGuardrailExecuted.getMethodParam(0));
93+
onInputGuardrailExecuted.invokeStaticMethod(inputGuardrailObserverMethod,
94+
onInputGuardrailExecuted.getMethodParam(0));
8895
onInputGuardrailExecuted.returnVoid();
8996

9097
MethodCreator onOutputGuardrailExecuted = classCreator.getMethodCreator("onOutputGuardrailExecuted", "V",
9198
OutputGuardrailExecutedEvent.class);
9299
onOutputGuardrailExecuted.getParameterAnnotations(0).addAnnotation(Observes.class);
93-
var support2 = MethodDescriptor.ofMethod(
100+
var outputGuardrailObserverMethod = MethodDescriptor.ofMethod(
94101
GUARDRAIL_METRICS_OBSERVER_SUPPORT_CLASS,
95102
"onOutputGuardrailExecuted", "V", OutputGuardrailExecutedEvent.class);
96-
onOutputGuardrailExecuted.invokeStaticMethod(support2, onOutputGuardrailExecuted.getMethodParam(0));
103+
onOutputGuardrailExecuted.invokeStaticMethod(outputGuardrailObserverMethod,
104+
onOutputGuardrailExecuted.getMethodParam(0));
97105
onOutputGuardrailExecuted.returnVoid();
106+
107+
MethodCreator onToolInputGuardrailExecuted = classCreator.getMethodCreator("onToolInputGuardrailExecuted", "V",
108+
ToolInputGuardrailExecutedEvent.class);
109+
onToolInputGuardrailExecuted.getParameterAnnotations(0).addAnnotation(Observes.class);
110+
var inputToolGuardrailObserverMethod = MethodDescriptor.ofMethod(
111+
GUARDRAIL_METRICS_OBSERVER_SUPPORT_CLASS,
112+
"onToolInputGuardrailExecuted", "V", ToolInputGuardrailExecutedEvent.class);
113+
onToolInputGuardrailExecuted.invokeStaticMethod(inputToolGuardrailObserverMethod,
114+
onToolInputGuardrailExecuted.getMethodParam(0));
115+
onToolInputGuardrailExecuted.returnVoid();
116+
117+
MethodCreator onToolOutputGuardrailExecuted = classCreator.getMethodCreator("onToolOutputGuardrailExecuted",
118+
"V",
119+
ToolOutputGuardrailExecutedEvent.class);
120+
onToolOutputGuardrailExecuted.getParameterAnnotations(0).addAnnotation(Observes.class);
121+
var outputToolGuardrailObserverMethod = MethodDescriptor.ofMethod(
122+
GUARDRAIL_METRICS_OBSERVER_SUPPORT_CLASS,
123+
"onToolOutputGuardrailExecuted", "V", ToolOutputGuardrailExecutedEvent.class);
124+
onToolOutputGuardrailExecuted.invokeStaticMethod(outputToolGuardrailObserverMethod,
125+
onToolOutputGuardrailExecuted.getMethodParam(0));
126+
onToolOutputGuardrailExecuted.returnVoid();
98127
}
99128
}
100129

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/ErrorHandlingTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import dev.langchain4j.service.UserMessage;
3333
import io.quarkiverse.langchain4j.RegisterAiService;
3434
import io.quarkiverse.langchain4j.ToolBox;
35+
import io.quarkiverse.langchain4j.guardrails.ToolGuardrailException;
3536
import io.quarkiverse.langchain4j.guardrails.ToolInputGuardrail;
3637
import io.quarkiverse.langchain4j.guardrails.ToolInputGuardrailRequest;
3738
import io.quarkiverse.langchain4j.guardrails.ToolInputGuardrailResult;
@@ -96,9 +97,9 @@ void testInputGuardrail_throwsRuntimeException() {
9697
@Test
9798
@ActivateRequestContext
9899
void testInputGuardrail_throwsNullPointerException() {
99-
// NPE propagates directly
100100
assertThatThrownBy(() -> aiService.chat("test", "nullPointerTool - anything"))
101-
.isInstanceOf(NullPointerException.class);
101+
.isInstanceOf(ToolGuardrailException.class)
102+
.hasMessageContaining("null");
102103

103104
assertThat(NullPointerExceptionGuardrail.executionCount).isEqualTo(1);
104105
assertThat(tools.nullPointerToolExecuted).isFalse();

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailChainOrderingTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ public ToolInputGuardrailResult validate(ToolInputGuardrailRequest request) {
491491
ExecutionTracker.recordInput("InputGuardrail3");
492492

493493
if (request.arguments().contains("fatal-at-3")) {
494-
return ToolInputGuardrailResult.failure(
494+
return ToolInputGuardrailResult.fatal(
495495
"Fatal failure at guardrail 3",
496496
new SecurityException("Unauthorized"));
497497
}
@@ -602,7 +602,7 @@ public ToolOutputGuardrailResult validate(ToolOutputGuardrailRequest request) {
602602
ExecutionTracker.recordOutput("OutputGuardrail3");
603603

604604
if (request.resultText().contains("fatal-at-3")) {
605-
return ToolOutputGuardrailResult.failure(
605+
return ToolOutputGuardrailResult.fatal(
606606
"Fatal failure at guardrail 3",
607607
new SecurityException("Data leak detected"));
608608
}

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/MoreRealWorldToolGuardrailsTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,14 +563,14 @@ public ToolInputGuardrailResult validate(ToolInputGuardrailRequest request) {
563563

564564
// Check if user is authenticated
565565
if (user == null) {
566-
return ToolInputGuardrailResult.failure(
566+
return ToolInputGuardrailResult.fatal(
567567
"Authentication required. Please log in to use this tool.",
568568
new SecurityException("Unauthenticated access attempt"));
569569
}
570570

571571
// Check if user has required role
572572
if (!ADMIN_USERS.contains(user)) {
573-
return ToolInputGuardrailResult.failure(
573+
return ToolInputGuardrailResult.fatal(
574574
"Insufficient permissions. Administrator role required.",
575575
new SecurityException("Unauthorized access attempt by user: " + user));
576576
}

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/ToolGuardrailWithObjectTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ public ToolInputGuardrailResult validate(ToolInputGuardrailRequest request) {
445445

446446
return ToolInputGuardrailResult.success();
447447
} catch (Exception e) {
448-
return ToolInputGuardrailResult.failure("Failed to validate customer ID: " + e.getMessage(), e);
448+
return ToolInputGuardrailResult.fatal("Failed to validate customer ID: " + e.getMessage(), e);
449449
}
450450
}
451451

@@ -481,7 +481,7 @@ public ToolOutputGuardrailResult validate(ToolOutputGuardrailRequest request) {
481481

482482
return ToolOutputGuardrailResult.success();
483483
} catch (Exception e) {
484-
return ToolOutputGuardrailResult.failure("Failed to filter sensitive fields: " + e.getMessage(), e);
484+
return ToolOutputGuardrailResult.fatal("Failed to filter sensitive fields: " + e.getMessage(), e);
485485
}
486486
}
487487

@@ -529,7 +529,7 @@ public ToolOutputGuardrailResult validate(ToolOutputGuardrailRequest request) {
529529
limitedSize = originalSize;
530530
return ToolOutputGuardrailResult.success();
531531
} catch (Exception e) {
532-
return ToolOutputGuardrailResult.failure("Failed to limit customer list: " + e.getMessage(), e);
532+
return ToolOutputGuardrailResult.fatal("Failed to limit customer list: " + e.getMessage(), e);
533533
}
534534
}
535535

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/ToolGuardrailWithUniTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ public ToolInputGuardrailResult validate(ToolInputGuardrailRequest request) {
436436

437437
return ToolInputGuardrailResult.success();
438438
} catch (Exception e) {
439-
return ToolInputGuardrailResult.failure("Failed to validate message: " + e.getMessage(), e);
439+
return ToolInputGuardrailResult.fatal("Failed to validate message: " + e.getMessage(), e);
440440
}
441441
}
442442

@@ -467,7 +467,7 @@ public ToolInputGuardrailResult validate(ToolInputGuardrailRequest request) {
467467

468468
return ToolInputGuardrailResult.success();
469469
} catch (Exception e) {
470-
return ToolInputGuardrailResult.failure("Failed to validate product ID: " + e.getMessage(), e);
470+
return ToolInputGuardrailResult.fatal("Failed to validate product ID: " + e.getMessage(), e);
471471
}
472472
}
473473

@@ -501,7 +501,7 @@ public ToolOutputGuardrailResult validate(ToolOutputGuardrailRequest request) {
501501

502502
return ToolOutputGuardrailResult.success();
503503
} catch (Exception e) {
504-
return ToolOutputGuardrailResult.failure("Failed to filter prices: " + e.getMessage(), e);
504+
return ToolOutputGuardrailResult.fatal("Failed to filter prices: " + e.getMessage(), e);
505505
}
506506
}
507507

@@ -541,7 +541,7 @@ public ToolOutputGuardrailResult validate(ToolOutputGuardrailRequest request) {
541541
.resultText(modifiedJson)
542542
.build());
543543
} catch (Exception e) {
544-
return ToolOutputGuardrailResult.failure(
544+
return ToolOutputGuardrailResult.fatal(
545545
"Failed to transform inventory status: " + e.getMessage(), e);
546546
}
547547
}

0 commit comments

Comments
 (0)