Commit bc51f59

KavithaSiva, cloud-sdk-js, ZhongpinWang, and deekshas8 authored
feat: Add convenience function for LlamaGuard Filter (#555)
* chore: add new filter utility for LlamaGuard
* chore: update default
* chore: fix test
* chore: cleanup
* chore: refactor
* chore: minor doc fix
* chore: adapt changes as per slack discussion
* chore: add orchestration client test
* chore: add documentation
* chore: add release note
* fix: Changes from lint
* chore: fox typo
* chore: add mutiple filters in sample code
* chore: fix typo
* Update .changeset/hip-tools-smile.md
  Co-authored-by: Deeksha Sinha <[email protected]>
* review
* minor
* fix: type tests
* fix: Changes from lint
* reivew
* readme

---------

Co-authored-by: cloud-sdk-js <[email protected]>
Co-authored-by: Zhongpin Wang <[email protected]>
Co-authored-by: Deeksha Sinha <[email protected]>
1 parent b1d244c commit bc51f59

10 files changed: +287 -51 lines changed

.changeset/hip-tools-smile.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+'@sap-ai-sdk/orchestration': minor
+---
+
+[New Functionality] Introduce `buildLlamaGuardFilter()` convenience function to build Llama guard filters.

packages/orchestration/README.md

Lines changed: 39 additions & 17 deletions
@@ -336,20 +336,12 @@ Use the orchestration client with filtering to restrict content that is passed t
 
 This feature allows filtering both the [input](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/input-filtering) and [output](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/output-filtering) of a model based on content safety criteria.
 
-#### Azure Content Filter
-
-Use `buildAzureContentSafetyFilter()` function to build an Azure content filter for both input and output.
-Each category of the filter can be assigned a specific severity level, which corresponds to an Azure threshold value.
-
-| Severity Level | Azure Threshold Value |
-| ----------------------- | --------------------- |
-| `ALLOW_SAFE` | 0 |
-| `ALLOW_SAFE_LOW` | 2 |
-| `ALLOW_SAFE_LOW_MEDIUM` | 4 |
-| `ALLOW_ALL` | 6 |
+The following example demonstrates how to use content filtering with the orchestration client.
+See the sections below for details on the available content filters and how to build them.
 
 ```ts
-import { OrchestrationClient, ContentFilters } from '@sap-ai-sdk/orchestration';
+import { OrchestrationClient } from '@sap-ai-sdk/orchestration';
+
 const llm = {
   model_name: 'gpt-4o',
   model_params: { max_tokens: 50, temperature: 0.1 }
@@ -358,16 +350,14 @@ const templating = {
   template: [{ role: 'user', content: '{{?input}}' }]
 };
 
-const filter = buildAzureContentSafetyFilter({
-  Hate: 'ALLOW_SAFE_LOW',
-  Violence: 'ALLOW_SAFE_LOW_MEDIUM'
-});
+const filter = ... // Use a build function to create a content filter
+
 const orchestrationClient = new OrchestrationClient({
   llm,
   templating,
   filtering: {
     input: {
-      filters: [filter]
+      filters: [filter] // Multiple filters can be applied
     },
     output: {
       filters: [filter]
@@ -385,6 +375,38 @@ try {
 }
 ```
 
+Multiple filters can be applied at the same time for both input and output filtering.
+
+#### Azure Content Filter
+
+Use `buildAzureContentSafetyFilter()` function to build an Azure content filter.
+Each category of the filter can be assigned a specific severity level, which corresponds to an Azure threshold value.
+
+| Severity Level | Azure Threshold Value |
+| ----------------------- | --------------------- |
+| `ALLOW_SAFE` | 0 |
+| `ALLOW_SAFE_LOW` | 2 |
+| `ALLOW_SAFE_LOW_MEDIUM` | 4 |
+| `ALLOW_ALL` | 6 |
+
+```ts
+const filter = buildAzureContentSafetyFilter({
+  Hate: 'ALLOW_SAFE_LOW',
+  Violence: 'ALLOW_SAFE_LOW_MEDIUM'
+});
+```
+
+#### Llama Guard Filter
+
+Use `buildLlamaGuardFilter()` function to build a Llama Guard content filter.
+
+Available categories can be found with autocompletion.
+Pass the categories as arguments to the function to enable them.
+
+```ts
+const filter = buildLlamaGuardFilter('hate', 'violent_crimes');
+```
+
 #### Error Handling
 
 Both `chatCompletion()` and `getContent()` methods can throw errors.

packages/orchestration/src/index.ts

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@ export type {
   DocumentGroundingServiceConfig,
   DocumentGroundingServiceFilter,
   LlmModelParams,
+  LlamaGuardCategory,
   AzureContentFilter,
   AzureFilterThreshold
 } from './orchestration-types.js';
@@ -24,6 +25,7 @@ export { OrchestrationClient } from './orchestration-client.js';
 export {
   buildAzureContentFilter,
   buildAzureContentSafetyFilter,
+  buildLlamaGuardFilter,
   buildDocumentGroundingConfig
 } from './util/index.js';
 

packages/orchestration/src/orchestration-client.test.ts

Lines changed: 55 additions & 1 deletion
@@ -13,7 +13,8 @@ import { OrchestrationResponse } from './orchestration-response.js';
 import {
   constructCompletionPostRequestFromJsonModuleConfig,
   constructCompletionPostRequest,
-  buildAzureContentSafetyFilter
+  buildAzureContentSafetyFilter,
+  buildLlamaGuardFilter
 } from './util/index.js';
 import type { CompletionPostResponse } from './client/api/schema/index.js';
 import type {
@@ -206,6 +207,59 @@ describe('orchestration service client', () => {
     expect(response.data).toEqual(mockResponse);
   });
 
+  it('calls chatCompletion with filter configuration supplied using multiple convenience functions', async () => {
+    const llamaFilter = buildLlamaGuardFilter('self_harm');
+    const azureContentFilter = buildAzureContentSafetyFilter({
+      Sexual: 'ALLOW_SAFE'
+    });
+    const config: OrchestrationModuleConfig = {
+      llm: {
+        model_name: 'gpt-35-turbo-16k',
+        model_params: { max_tokens: 50, temperature: 0.1 }
+      },
+      templating: {
+        template: [
+          {
+            role: 'user',
+            content: 'Create {{?number}} paraphrases of {{?phrase}}'
+          }
+        ]
+      },
+      filtering: {
+        input: {
+          filters: [llamaFilter, azureContentFilter]
+        },
+        output: {
+          filters: [llamaFilter, azureContentFilter]
+        }
+      }
+    };
+    const prompt = {
+      inputParams: { phrase: 'I like myself.', number: '20' }
+    };
+    const mockResponse = await parseMockResponse<CompletionPostResponse>(
+      'orchestration',
+      'orchestration-chat-completion-multiple-filter-config.json'
+    );
+
+    mockInference(
+      {
+        data: constructCompletionPostRequest(config, prompt)
+      },
+      {
+        data: mockResponse,
+        status: 200
+      },
+      {
+        url: 'inference/deployments/1234/completion'
+      }
+    );
+    const response = await new OrchestrationClient(config).chatCompletion(
+      prompt
+    );
+    expect(response.data).toEqual(mockResponse);
+  });
+
   it('calls chatCompletion with filtering configuration', async () => {
     const config: OrchestrationModuleConfig = {
       llm: {

packages/orchestration/src/orchestration-types.ts

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@ import type {
   FilteringStreamOptions,
   GlobalStreamOptions,
   GroundingModuleConfig,
+  LlamaGuard38B,
   MaskingModuleConfig,
   LlmModuleConfig as OriginalLlmModuleConfig,
   TemplatingModuleConfig
@@ -201,3 +202,8 @@ export const supportedAzureFilterThresholds = {
  *
  */
 export type AzureFilterThreshold = keyof typeof supportedAzureFilterThresholds;
+
+/**
+ * The filter categories supported for Llama guard filter.
+ */
+export type LlamaGuardCategory = keyof LlamaGuard38B;
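
Both exported types in this file derive a string-literal union from the keys of another type. The following is a minimal, self-contained sketch of that `keyof` pattern, not the SDK's actual definitions. The threshold values come from the severity table in the README diff; `thresholdsSketch`, `GuardSchemaSketch`, and `toThreshold` are hypothetical names, and the assumption that the generated `LlamaGuard38B` type consists of optional boolean properties is not confirmed by this commit.

```typescript
// Assumed stand-in for `supportedAzureFilterThresholds`; values from the README table.
const thresholdsSketch = {
  ALLOW_SAFE: 0,
  ALLOW_SAFE_LOW: 2,
  ALLOW_SAFE_LOW_MEDIUM: 4,
  ALLOW_ALL: 6
} as const;

// Union of the object's keys: 'ALLOW_SAFE' | 'ALLOW_SAFE_LOW' | ...
type ThresholdName = keyof typeof thresholdsSketch;

// Hypothetical stand-in for the generated `LlamaGuard38B` schema type.
interface GuardSchemaSketch {
  hate?: boolean;
  privacy?: boolean;
  self_harm?: boolean;
}

// Union of the interface's property names: 'hate' | 'privacy' | 'self_harm'
type GuardCategory = keyof GuardSchemaSketch;

// Hypothetical helper: look up the numeric threshold for a severity name.
function toThreshold(name: ThresholdName): number {
  return thresholdsSketch[name];
}

const category: GuardCategory = 'privacy';
console.log(category, toThreshold('ALLOW_SAFE_LOW_MEDIUM')); // prints "privacy 4"
```

Because the unions are computed from the source types, a category added to the schema would become a valid argument without editing the union by hand.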

packages/orchestration/src/util/filtering.test.ts

Lines changed: 32 additions & 1 deletion
@@ -1,6 +1,7 @@
 import {
   buildAzureContentFilter,
-  buildAzureContentSafetyFilter
+  buildAzureContentSafetyFilter,
+  buildLlamaGuardFilter
 } from './filtering.js';
 import { constructCompletionPostRequest } from './module-config.js';
 import type { OrchestrationModuleConfig } from '../orchestration-types.js';
@@ -211,4 +212,34 @@ describe('Content filter util', () => {
       );
     });
   });
+
+  describe('Llama Guard filter', () => {
+    it('builds filter config with custom config', async () => {
+      const filterConfig = buildLlamaGuardFilter('elections', 'hate');
+      const expectedFilterConfig = {
+        type: 'llama_guard_3_8b',
+        config: {
+          elections: true,
+          hate: true
+        }
+      };
+      expect(filterConfig).toEqual(expectedFilterConfig);
+    });
+
+    it('builds filter config without duplicates', async () => {
+      const filterConfig = buildLlamaGuardFilter(
+        'non_violent_crimes',
+        'privacy',
+        'non_violent_crimes'
+      );
+      const expectedFilterConfig = {
+        type: 'llama_guard_3_8b',
+        config: {
+          non_violent_crimes: true,
+          privacy: true
+        }
+      };
+      expect(filterConfig).toEqual(expectedFilterConfig);
+    });
+  });
 });

packages/orchestration/src/util/filtering.ts

Lines changed: 23 additions & 3 deletions
@@ -3,15 +3,17 @@ import type {
   AzureContentSafety,
   AzureContentSafetyFilterConfig,
   InputFilteringConfig,
+  LlamaGuard38BFilterConfig,
   OutputFilteringConfig
 } from '../client/api/schema/index.js';
 import type {
   AzureContentFilter,
-  AzureFilterThreshold
+  AzureFilterThreshold,
+  LlamaGuardCategory
 } from '../orchestration-types.js';
 
 /**
- * Convenience function to create Azure content filters.
+ * Convenience function to build Azure content filter.
  * @param filter - Filtering configuration for Azure filter. If skipped, the default Azure content filter configuration is used.
  * @returns An object with the Azure filtering configuration.
  * @deprecated Since 1.8.0. Use {@link buildAzureContentSafetyFilter()} instead.
@@ -33,10 +35,11 @@ export function buildAzureContentFilter(
 }
 
 /**
- * Convenience function to create Azure content filters.
+ * Convenience function to build Azure content filter.
  * @param config - Configuration for Azure content safety filter.
  * If skipped, the default configuration of `ALLOW_SAFE_LOW` is used for all filter categories.
  * @returns Filter config object.
+ * @example "buildAzureContentSafetyFilter({ Hate: 'ALLOW_SAFE', Violence: 'ALLOW_SAFE_LOW_MEDIUM' })"
  */
 export function buildAzureContentSafetyFilter(
   config?: AzureContentFilter
@@ -58,3 +61,20 @@ export function buildAzureContentSafetyFilter(
     })
   };
 }
+
+/**
+ * Convenience function to build Llama guard filter.
+ * @param categories - Categories to be enabled for filtering. Provide at least one category.
+ * @returns Filter config object.
+ * @example "buildLlamaGuardFilter('self_harm', 'hate')"
+ */
+export function buildLlamaGuardFilter(
+  ...categories: [LlamaGuardCategory, ...LlamaGuardCategory[]]
+): LlamaGuard38BFilterConfig {
+  return {
+    type: 'llama_guard_3_8b',
+    config: Object.fromEntries(
+      [...categories].map(category => [category, true])
+    )
+  };
+}
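
To try the behavior of the new builder without the SDK installed, here is a minimal standalone sketch of the same technique: a rest parameter typed as a non-empty tuple, folded into an object whose keys are the enabled categories. It is an approximation under stated assumptions, not the SDK implementation. `CategorySketch`, `FilterSketch`, and `buildLlamaGuardFilterSketch` are hypothetical names, the category union lists only the categories that appear in this commit, and a plain loop stands in for `Object.fromEntries` so the block compiles without ES2019 typings.

```typescript
// Assumed subset of categories; the SDK derives the full list from `LlamaGuard38B`.
type CategorySketch =
  | 'hate'
  | 'violent_crimes'
  | 'non_violent_crimes'
  | 'privacy'
  | 'self_harm'
  | 'elections';

// Hypothetical stand-in for `LlamaGuard38BFilterConfig`.
interface FilterSketch {
  type: 'llama_guard_3_8b';
  config: Partial<Record<CategorySketch, true>>;
}

// The tuple type `[C, ...C[]]` makes a call with zero arguments a compile-time error.
function buildLlamaGuardFilterSketch(
  ...categories: [CategorySketch, ...CategorySketch[]]
): FilterSketch {
  const config: Partial<Record<CategorySketch, true>> = {};
  for (const category of categories) {
    // Assigning the same key twice is harmless, so duplicates collapse.
    config[category] = true;
  }
  return { type: 'llama_guard_3_8b', config };
}

const sketchFilter = buildLlamaGuardFilterSketch('hate', 'violent_crimes');
console.log(JSON.stringify(sketchFilter));
// prints {"type":"llama_guard_3_8b","config":{"hate":true,"violent_crimes":true}}
```

Duplicates need no explicit handling because the categories become object keys, which matches the "builds filter config without duplicates" test in this commit.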
