Skip to content

Commit b9b99b4

Browse files
authored
Adding unit tests (#12)
* Adding unit tests, coverage up to 82%
* Updating package-lock file
1 parent 7925ba3 commit b9b99b4

21 files changed

+7727
-5251
lines changed

package-lock.json

Lines changed: 5508 additions & 5238 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
"eslint": "^8.0.0",
5656
"prettier": "^3.0.0",
5757
"rimraf": "^5.0.0",
58+
"@vitest/coverage-v8": "^1.6.0",
5859
"typescript": "^5.0.0",
5960
"vitest": "^1.0.0"
6061
},
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
/**
2+
* Unit tests for GuardrailsBaseClient shared helpers.
3+
*
4+
* These tests focus on the guardrail orchestration helpers that organize
5+
* contexts, apply PII masking, and coordinate guardrail execution for the
6+
* higher-level clients.
7+
*/
8+
9+
import { describe, it, expect, vi, beforeEach } from 'vitest';
10+
import { GuardrailsBaseClient, GuardrailResultsImpl, StageGuardrails } from '../../base-client';
11+
import { GuardrailTripwireTriggered } from '../../exceptions';
12+
import { GuardrailLLMContext, GuardrailResult } from '../../types';
13+
14+
/**
 * Shape of the guardrail test doubles used in this file: a vitest mock
 * function paired with the definition metadata that names the guardrail.
 */
interface MockGuardrail {
  /** Vitest mock function standing in for the guardrail's run function. */
  run: ReturnType<typeof vi.fn>;
  /** Guardrail definition: a required name plus optional free-form metadata. */
  definition: { name: string; metadata?: Record<string, unknown> };
}
21+
22+
/**
 * Test-only subclass of GuardrailsBaseClient.
 *
 * Adds public setters that write test doubles directly onto the instance
 * (`context`, `guardrails`, `pipeline`) and supplies trivial versions of
 * `createDefaultContext` and `overrideResources` (presumably hooks the base
 * class expects subclasses to provide; confirm against base-client).
 */
class TestGuardrailsClient extends GuardrailsBaseClient {
  /** Overwrites the instance's `context` property with `ctx`. */
  public setContext(ctx: GuardrailLLMContext): void {
    Object.assign(this, { context: ctx });
  }

  /** Overwrites the instance's `guardrails` property with `guardrails`. */
  public setGuardrails(guardrails: StageGuardrails): void {
    Object.assign(this, { guardrails });
  }

  /** Overwrites the instance's `pipeline` property with `pipeline`. */
  public setPipeline(pipeline: any): void {
    Object.assign(this, { pipeline });
  }

  /** Returns a context whose `guardrailLlm` is an empty placeholder object. */
  protected createDefaultContext(): GuardrailLLMContext {
    const placeholderLlm = {} as any;
    return { guardrailLlm: placeholderLlm };
  }

  /** Intentionally a no-op: the unit tests need no resource overrides. */
  protected overrideResources(): void {
    // Not needed for unit tests
  }
}
43+
44+
/**
 * Builds a MockGuardrail for tests.
 *
 * @param name - Value stored as `definition.name`.
 * @param implementation - Function wrapped in `vi.fn` and exposed as `run`,
 *   so tests can both drive the guardrail's result and inspect its calls.
 * @returns The mock guardrail (no `metadata` is set on the definition).
 */
const createGuardrail = (
  name: string,
  implementation: (ctx: any, text: string) => GuardrailResult | Promise<GuardrailResult>
): MockGuardrail => {
  const run = vi.fn(implementation);
  return { definition: { name }, run };
};
51+
52+
/**
 * Exercises the helper methods of GuardrailsBaseClient through the
 * TestGuardrailsClient subclass defined in this file.
 */
describe('GuardrailsBaseClient helpers', () => {
  let client: TestGuardrailsClient;

  // Replaces the client's stage configuration with the given pre_flight
  // guardrails; the input and output stages are always left empty.
  const installPreflight = (...guards: MockGuardrail[]): void => {
    client.setGuardrails({
      pre_flight: guards,
      input: [],
      output: [],
    });
  };

  // Every test starts from a fresh client with a placeholder context and no
  // guardrails configured in any stage.
  beforeEach(() => {
    client = new TestGuardrailsClient();
    client.setContext({ guardrailLlm: {} } as GuardrailLLMContext);
    installPreflight();
  });

  describe('extractLatestUserMessage', () => {
    it('returns the latest user message and index for string content', () => {
      // The last user entry sits at index 3; the test expects its surrounding
      // whitespace to be trimmed.
      const messages = [
        { role: 'system', content: 'hi' },
        { role: 'user', content: ' first ' },
        { role: 'assistant', content: 'ok' },
        { role: 'user', content: ' second ' },
      ];

      const [latestText, latestIndex] = client.extractLatestUserMessage(messages);

      expect(latestText).toBe('second');
      expect(latestIndex).toBe(3);
    });

    it('handles responses API content parts', () => {
      const firstUserTurn = { role: 'user', content: [{ type: 'text', text: 'hello' }] };
      const secondUserTurn = {
        role: 'user',
        content: [
          { type: 'text', text: 'part1' },
          { type: 'text', text: 'part2' },
        ],
      };
      const messages = [firstUserTurn, secondUserTurn];

      const [latestText, latestIndex] = client.extractLatestUserMessage(messages);

      // The text parts of the latest user turn are expected joined by a space.
      expect(latestText).toBe('part1 part2');
      expect(latestIndex).toBe(1);
    });

    it('returns empty string when no user messages exist', () => {
      const [latestText, latestIndex] = client.extractLatestUserMessage([
        { role: 'assistant', content: 'hi' },
      ]);

      expect(latestText).toBe('');
      expect(latestIndex).toBe(-1);
    });
  });

  describe('applyPreflightModifications', () => {
    // Builds a non-tripwire result whose info carries only `detected_entities`.
    const entitiesResult = (detected: Record<string, string[]>): GuardrailResult => ({
      tripwireTriggered: false,
      info: { detected_entities: detected },
    });

    it('masks detected PII in string inputs', () => {
      const results: GuardrailResult[] = [entitiesResult({ EMAIL: ['[email protected]'] })];

      const masked = client.applyPreflightModifications(
        'Reach me at [email protected]',
        results
      ) as string;

      expect(masked).toBe('Reach me at <EMAIL>');
    });

    it('masks detected PII in the latest user message with structured content', () => {
      const assistantTurn = { role: 'assistant', content: 'hello' };
      const userTurn = {
        role: 'user',
        content: [
          { type: 'text', text: 'Call me at 123-456-7890' },
          { type: 'text', text: 'or email [email protected]' },
        ],
      };
      const messages = [assistantTurn, userTurn];

      const results: GuardrailResult[] = [
        entitiesResult({
          PHONE: ['123-456-7890'],
          EMAIL: ['[email protected]'],
        }),
      ];

      const masked = client.applyPreflightModifications(messages, results) as any[];
      const latestMessage = masked[1];

      expect(latestMessage.content[0].text).toBe('Call me at <PHONE>');
      expect(latestMessage.content[1].text).toBe('or email <EMAIL>');
      // Ensure assistant message unchanged
      expect(masked[0]).toEqual(messages[0]);
    });

    it('returns original payload when no detected entities exist', () => {
      const data = 'Nothing to mask';

      const result = client.applyPreflightModifications(data, []);

      expect(result).toBe(data);
    });
  });

  describe('runStageGuardrails', () => {
    const baseResult = {
      tripwireTriggered: false,
      info: {},
    };

    // Fresh guardrail named 'Tripwire' whose run always reports a tripwire.
    const tripwireGuardrail = (): MockGuardrail =>
      createGuardrail('Tripwire', async () => ({
        tripwireTriggered: true,
        info: { reason: 'bad' },
      }));

    // Default for this group: a single passing pre_flight guardrail. Tests
    // that need something else install their own guardrails.
    beforeEach(() => {
      installPreflight(createGuardrail('Test Guard', async () => ({ ...baseResult })));
    });

    it('executes guardrails and annotates info metadata', async () => {
      const results = await client.runStageGuardrails('pre_flight', 'payload');

      expect(results).toHaveLength(1);
      expect(results[0].info).toMatchObject({
        stage_name: 'pre_flight',
        guardrail_name: 'Test Guard',
      });
    });

    it('throws GuardrailTripwireTriggered when guardrail reports tripwire', async () => {
      installPreflight(tripwireGuardrail());

      const pending = client.runStageGuardrails('pre_flight', 'payload');

      await expect(pending).rejects.toBeInstanceOf(GuardrailTripwireTriggered);
    });

    it('suppresses tripwire errors when suppressTripwire=true', async () => {
      installPreflight(tripwireGuardrail());

      // Fourth argument is the suppressTripwire flag named in the test title.
      const results = await client.runStageGuardrails('pre_flight', 'payload', undefined, true);

      expect(results).toHaveLength(1);
      expect(results[0].tripwireTriggered).toBe(true);
    });

    it('rethrows execution errors when raiseGuardrailErrors=true', async () => {
      installPreflight(createGuardrail('Faulty', () => Promise.reject(new Error('boom'))));

      // Fifth argument is the raiseGuardrailErrors flag named in the test title.
      const pending = client.runStageGuardrails('pre_flight', 'payload', undefined, false, true);

      await expect(pending).rejects.toThrow('boom');
    });

    it('creates a conversation-aware context for prompt injection detection guardrails', async () => {
      const guardrail = createGuardrail('Prompt Injection Detection', async () => ({
        tripwireTriggered: false,
        info: {},
      }));
      installPreflight(guardrail);
      const contextSpy = vi.spyOn(client as any, 'createContextWithConversation');
      const conversation = [{ role: 'user', content: 'hi' }];

      await client.runStageGuardrails('pre_flight', 'payload', conversation, false, false);

      expect(contextSpy).toHaveBeenCalled();
    });
  });

  describe('handleLlmResponse', () => {
    it('appends LLM response to conversation history and returns guardrail results', async () => {
      const conversation = [{ role: 'user', content: 'hi' }];
      const outputResult = { tripwireTriggered: false, info: {} };
      // Stub the stage runner so only handleLlmResponse's own wiring is tested.
      const runSpy = vi.spyOn(client as any, 'runStageGuardrails');
      runSpy.mockResolvedValue([outputResult]);

      const assistantMessage = { role: 'assistant', content: 'All good' };
      const llmResponse: any = {
        choices: [{ message: assistantMessage }],
      };

      const response = await (client as any).handleLlmResponse(llmResponse, [], [], conversation);

      const expectedHistory = expect.arrayContaining([
        { role: 'user', content: 'hi' },
        { role: 'assistant', content: 'All good' },
      ]);
      expect(runSpy).toHaveBeenCalledWith('output', 'All good', expectedHistory, false);
      expect(response.guardrail_results).toBeInstanceOf(GuardrailResultsImpl);
      expect(response.guardrail_results.output).toEqual([outputResult]);
    });
  });
});

0 commit comments

Comments
 (0)