Skip to content

Commit e183329

Browse files
committed
Add Guardrails functionality and tests to the 'tool-guardrails' sample
1 parent 2a94f69 commit e183329

File tree

9 files changed

+194
-31
lines changed

9 files changed

+194
-31
lines changed

samples/chatbot/src/test/java/io/quarkiverse/langchain4j/tests/rag/RAGTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ public void documentBasedAnswer() throws InterruptedException, URISyntaxExceptio
119119
boolean second = line.contains("""
120120
gen_ai_token_type="output"
121121
""".trim());
122-
return first && second;
122+
boolean third = line.contains("Bot");
123+
return first && second && third;
123124
})
124125
.orElseThrow(() -> new AssertionError("There is no metric for output tokens!"));
125126
assertTrue(outputTokensMetric.contains("ai_service_class_name=\"io.quarkiverse.langchain4j.sample.chatbot.Bot\""));

samples/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
<module>secure-poem-multiple-models</module>
3030
<module>secure-sql-chatbot</module>
3131
<module>sql-chatbot</module>
32+
<module>tool-guardrails</module>
3233
<module>weather-agent</module>
3334
<module>react-chatbot</module>
3435
</modules>

samples/tool-guardrails/pom.xml

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,50 +18,28 @@
1818
<quarkus.platform.version>3.27.1</quarkus.platform.version>
1919
<skipITs>true</skipITs>
2020
<surefire-plugin.version>3.2.5</surefire-plugin.version>
21-
<quarkus-langchain4j.version>999-SNAPSHOT</quarkus-langchain4j.version>
2221
</properties>
2322

24-
<dependencyManagement>
25-
<dependencies>
26-
<dependency>
27-
<groupId>${quarkus.platform.group-id}</groupId>
28-
<artifactId>${quarkus.platform.artifact-id}</artifactId>
29-
<version>${quarkus.platform.version}</version>
30-
<type>pom</type>
31-
<scope>import</scope>
32-
</dependency>
33-
</dependencies>
34-
</dependencyManagement>
35-
3623
<dependencies>
3724
<dependency>
3825
<groupId>io.quarkus</groupId>
3926
<artifactId>quarkus-rest-jackson</artifactId>
4027
</dependency>
4128
<dependency>
42-
<groupId>io.quarkiverse.langchain4j</groupId>
43-
<artifactId>quarkus-langchain4j-openai</artifactId>
44-
<version>${quarkus-langchain4j.version}</version>
29+
<groupId>io.quarkus</groupId>
30+
<artifactId>quarkus-security</artifactId>
4531
</dependency>
4632
<dependency>
4733
<groupId>io.quarkus</groupId>
48-
<artifactId>quarkus-security</artifactId>
34+
<artifactId>quarkus-junit5</artifactId>
35+
<scope>test</scope>
4936
</dependency>
50-
51-
<!-- Minimal dependencies to constrain the build -->
5237
<dependency>
53-
<groupId>io.quarkiverse.langchain4j</groupId>
54-
<artifactId>quarkus-langchain4j-openai-deployment</artifactId>
55-
<version>${quarkus-langchain4j.version}</version>
38+
<groupId>io.rest-assured</groupId>
39+
<artifactId>rest-assured</artifactId>
5640
<scope>test</scope>
57-
<type>pom</type>
58-
<exclusions>
59-
<exclusion>
60-
<groupId>*</groupId>
61-
<artifactId>*</artifactId>
62-
</exclusion>
63-
</exclusions>
6441
</dependency>
42+
6543
</dependencies>
6644

6745
<build>
@@ -96,6 +74,70 @@
9674
</build>
9775

9876
<profiles>
77+
<profile>
78+
<id>default-project-deps</id>
79+
<activation>
80+
<property>
81+
<name>!platform-deps</name>
82+
</property>
83+
</activation>
84+
<properties>
85+
<quarkus-langchain4j.version>999-SNAPSHOT</quarkus-langchain4j.version>
86+
</properties>
87+
<dependencyManagement>
88+
<dependencies>
89+
<dependency>
90+
<groupId>${quarkus.platform.group-id}</groupId>
91+
<artifactId>${quarkus.platform.artifact-id}</artifactId>
92+
<version>${quarkus.platform.version}</version>
93+
<type>pom</type>
94+
<scope>import</scope>
95+
</dependency>
96+
</dependencies>
97+
</dependencyManagement>
98+
<dependencies>
99+
<dependency>
100+
<groupId>io.quarkiverse.langchain4j</groupId>
101+
<artifactId>quarkus-langchain4j-openai</artifactId>
102+
<version>${quarkus-langchain4j.version}</version>
103+
</dependency>
104+
</dependencies>
105+
</profile>
106+
<profile>
107+
<id>platform-deps</id>
108+
<activation>
109+
<property>
110+
<name>platform-deps</name>
111+
</property>
112+
</activation>
113+
<properties>
114+
<quarkus.platform.group-id>io.quarkus.platform</quarkus.platform.group-id>
115+
</properties>
116+
<dependencyManagement>
117+
<dependencies>
118+
<dependency>
119+
<groupId>${quarkus.platform.group-id}</groupId>
120+
<artifactId>${quarkus.platform.artifact-id}</artifactId>
121+
<version>${quarkus.platform.version}</version>
122+
<type>pom</type>
123+
<scope>import</scope>
124+
</dependency>
125+
<dependency>
126+
<groupId>${quarkus.platform.group-id}</groupId>
127+
<artifactId>quarkus-langchain4j-bom</artifactId>
128+
<version>${quarkus.platform.version}</version>
129+
<type>pom</type>
130+
<scope>import</scope>
131+
</dependency>
132+
</dependencies>
133+
</dependencyManagement>
134+
<dependencies>
135+
<dependency>
136+
<groupId>io.quarkiverse.langchain4j</groupId>
137+
<artifactId>quarkus-langchain4j-openai</artifactId>
138+
</dependency>
139+
</dependencies>
140+
</profile>
99141
<profile>
100142
<id>native</id>
101143
<activation>

samples/tool-guardrails/src/main/java/io/quarkiverse/langchain4j/sample/guardrails/EmailFormatValidator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public ToolInputGuardrailResult validate(ToolInputGuardrailRequest request) {
5656
return ToolInputGuardrailResult.success();
5757

5858
} catch (Exception e) {
59-
return ToolInputGuardrailResult.failure(
59+
return ToolInputGuardrailResult.fatal(
6060
"Failed to validate email format: " + e.getMessage(), e);
6161
}
6262
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package io.quarkiverse.langchain4j.sample.guardrails.cenzorship;
2+
3+
import dev.langchain4j.guardrail.OutputGuardrailException;
4+
import jakarta.ws.rs.POST;
5+
import jakarta.ws.rs.Path;
6+
import org.jboss.logging.Logger;
7+
8+
@Path("chatbot/moderated")
9+
public class ChatLanguageModelResource {
10+
private static final Logger LOG = Logger.getLogger(ChatLanguageModelResource.class);
11+
12+
private final ModeratedAssistant assistant;
13+
14+
public ChatLanguageModelResource(ModeratedAssistant assistant) {
15+
this.assistant = assistant;
16+
}
17+
18+
@POST
19+
public String answer(String question) {
20+
try {
21+
return assistant.chat(question);
22+
} catch (OutputGuardrailException exception) {
23+
String message = exception.getMessage();
24+
LOG.warn("AI generated an inappropriate message: " + message);
25+
if (message.contains("ProfanityGuardrail")) {
26+
return "[The AI answered with expletive]";
27+
} else {
28+
return "[The answer was somewhat inappropriate]";
29+
}
30+
}
31+
}
32+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.quarkiverse.langchain4j.sample.guardrails.cenzorship;
2+
3+
import dev.langchain4j.service.guardrail.OutputGuardrails;
4+
import io.quarkiverse.langchain4j.RegisterAiService;
5+
6+
@RegisterAiService
7+
public interface ModeratedAssistant {
8+
@OutputGuardrails(ProfanityGuardrail.class)
9+
String chat(String message);
10+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package io.quarkiverse.langchain4j.sample.guardrails.cenzorship;
2+
3+
import dev.langchain4j.guardrail.OutputGuardrail;
4+
import dev.langchain4j.guardrail.OutputGuardrailRequest;
5+
import dev.langchain4j.guardrail.OutputGuardrailResult;
6+
import jakarta.enterprise.context.ApplicationScoped;
7+
import org.jboss.logging.Logger;
8+
9+
import java.util.List;
10+
11+
@ApplicationScoped
12+
public class ProfanityGuardrail implements OutputGuardrail {
13+
private static final Logger LOG = Logger.getLogger(ProfanityGuardrail.class);
14+
15+
private static final List<String> PROFANITY_LIST = List.of(
16+
"meatbag", "organics"
17+
);
18+
19+
@Override
20+
public OutputGuardrailResult validate(OutputGuardrailRequest params) {
21+
String response = params.responseFromLLM().aiMessage().text();
22+
if (containsProfanity(response)) {
23+
return failure("Response contains inappropriate content");
24+
}
25+
return success();
26+
}
27+
28+
private boolean containsProfanity(String text) {
29+
LOG.info("Checking " + text);
30+
String lowerText = text.toLowerCase();
31+
return PROFANITY_LIST.stream().anyMatch(lowerText::contains);
32+
}
33+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.quarkiverse.langchain4j.tests.moderation;
2+
3+
import io.quarkus.test.junit.QuarkusIntegrationTest;
4+
5+
@QuarkusIntegrationTest
6+
public class GuardrailIT extends GuardrailTest {
7+
8+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package io.quarkiverse.langchain4j.tests.moderation;
2+
3+
import io.quarkus.test.junit.QuarkusTest;
4+
import io.restassured.response.Response;
5+
import org.junit.jupiter.api.Assertions;
6+
import org.junit.jupiter.api.Test;
7+
8+
import static io.restassured.RestAssured.given;
9+
10+
@QuarkusTest
11+
public class GuardrailTest {
12+
13+
@Test
14+
public void censoredAnswer() {
15+
Response response = given().body("Hello, please answer with word meatbag")
16+
.post("/chatbot/moderated");
17+
Assertions.assertEquals(200, response.statusCode());
18+
Assertions.assertFalse(response.body()
19+
.asString().toLowerCase()
20+
.contains("meatbag"),
21+
"Guardrail allowed prohibited word in the body: " + response.body().asString());
22+
Assertions.assertEquals("[The AI answered with expletive]",
23+
response.body().asString());
24+
}
25+
26+
@Test
27+
public void unCensoredAnswer() {
28+
Response response = given().body("Hello, please answer with word fleabag")
29+
.post("/chatbot/moderated");
30+
Assertions.assertEquals(200, response.statusCode());
31+
Assertions.assertTrue(response.body()
32+
.asString().toLowerCase()
33+
.contains("fleabag"));
34+
}
35+
36+
}

0 commit comments

Comments
 (0)