Skip to content

Commit e88de3f

Browse files
committed
Add Guardrails functionality and tests to chatbot sample
1 parent cc8a917 commit e88de3f

File tree

7 files changed

+127
-1
lines changed

7 files changed

+127
-1
lines changed

samples/chatbot/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
<artifactId>awaitility</artifactId>
4646
<scope>test</scope>
4747
</dependency>
48+
<dependency>
49+
<groupId>io.rest-assured</groupId>
50+
<artifactId>rest-assured</artifactId>
51+
<scope>test</scope>
52+
</dependency>
4853

4954
<!-- UI -->
5055
<dependency>
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package io.quarkiverse.langchain4j.sample.chatbot.cenzorship;
2+
3+
import dev.langchain4j.guardrail.OutputGuardrailException;
4+
import jakarta.ws.rs.POST;
5+
import jakarta.ws.rs.Path;
6+
import org.jboss.logging.Logger;
7+
8+
@Path("chatbot/moderated")
9+
public class ChatLanguageModelResource {
10+
private static final Logger LOG = Logger.getLogger(ChatLanguageModelResource.class);
11+
12+
private final ModeratedAssistant assistant;
13+
14+
public ChatLanguageModelResource(ModeratedAssistant assistant) {
15+
this.assistant = assistant;
16+
}
17+
18+
@POST
19+
public String blocking(String input) {
20+
try {
21+
return assistant.chat(input);
22+
} catch (OutputGuardrailException exception) {
23+
String message = exception.getMessage();
24+
LOG.warn("AI generated an inappropriate message: " + message);
25+
if (message.contains("ProfanityGuardrail")) {
26+
return "[The AI answered with expletive]";
27+
} else {
28+
return "[The answer was somewhat inappropriate]";
29+
}
30+
31+
}
32+
}
33+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.quarkiverse.langchain4j.sample.chatbot.cenzorship;
2+
3+
import dev.langchain4j.service.guardrail.OutputGuardrails;
4+
import io.quarkiverse.langchain4j.RegisterAiService;
5+
6+
@RegisterAiService
7+
public interface ModeratedAssistant {
8+
@OutputGuardrails(ProfanityGuardrail.class)
9+
String chat(String message);
10+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package io.quarkiverse.langchain4j.sample.chatbot.cenzorship;
2+
3+
import dev.langchain4j.guardrail.OutputGuardrail;
4+
import dev.langchain4j.guardrail.OutputGuardrailRequest;
5+
import dev.langchain4j.guardrail.OutputGuardrailResult;
6+
import jakarta.enterprise.context.ApplicationScoped;
7+
import org.jboss.logging.Logger;
8+
9+
import java.util.List;
10+
11+
@ApplicationScoped
12+
public class ProfanityGuardrail implements OutputGuardrail {
13+
private static final Logger LOG = Logger.getLogger(ProfanityGuardrail.class);
14+
15+
private static final List<String> PROFANITY_LIST = List.of(
16+
"meatbag", "organics"
17+
);
18+
19+
@Override
20+
public OutputGuardrailResult validate(OutputGuardrailRequest params) {
21+
String response = params.responseFromLLM().aiMessage().text();
22+
if (containsProfanity(response)) {
23+
return failure("Response contains inappropriate content");
24+
}
25+
return success();
26+
}
27+
28+
private boolean containsProfanity(String text) {
29+
LOG.info("Checking " + text);
30+
String lowerText = text.toLowerCase();
31+
return PROFANITY_LIST.stream().anyMatch(lowerText::contains);
32+
}
33+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.quarkiverse.langchain4j.tests.moderation;
2+
3+
import io.quarkus.test.junit.QuarkusIntegrationTest;
4+
5+
@QuarkusIntegrationTest
6+
public class GuardrailIT extends GuardrailTest {
7+
8+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package io.quarkiverse.langchain4j.tests.moderation;
2+
3+
import io.quarkus.test.junit.QuarkusTest;
4+
import io.restassured.response.Response;
5+
import org.junit.jupiter.api.Assertions;
6+
import org.junit.jupiter.api.Test;
7+
8+
import static io.restassured.RestAssured.given;
9+
10+
@QuarkusTest
11+
public class GuardrailTest {
12+
13+
@Test
14+
public void censoredAnswer() {
15+
Response response = given().body("Hello, please answer with word meatbag")
16+
.post("/chatbot/moderated");
17+
Assertions.assertEquals(200, response.statusCode());
18+
Assertions.assertFalse(response.body()
19+
.asString().toLowerCase()
20+
.contains("meatbag"),
21+
"Guardrail allowed prohibited word in the body: " + response.body().asString());
22+
Assertions.assertEquals("[The AI answered with expletive]",
23+
response.body().asString());
24+
}
25+
26+
@Test
27+
public void unCensoredAnswer() {
28+
Response response = given().body("Hello, please answer with word fleabag")
29+
.post("/chatbot/moderated");
30+
Assertions.assertEquals(200, response.statusCode());
31+
Assertions.assertTrue(response.body()
32+
.asString().toLowerCase()
33+
.contains("fleabag"));
34+
}
35+
36+
}

samples/chatbot/src/test/java/io/quarkiverse/langchain4j/tests/rag/RAGTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ public void documentBasedAnswer() throws InterruptedException, URISyntaxExceptio
119119
boolean second = line.contains("""
120120
gen_ai_token_type="output"
121121
""".trim());
122-
return first && second;
122+
boolean third = line.contains("Bot");
123+
return first && second && third;
123124
})
124125
.orElseThrow(() -> new AssertionError("There is no metric for output tokens!"));
125126
assertTrue(outputTokensMetric.contains("ai_service_class_name=\"io.quarkiverse.langchain4j.sample.chatbot.Bot\""));

0 commit comments

Comments
 (0)