Skip to content

Commit 0d85038

Browse files
committed
add some API tests for hybrid
1 parent b9f3709 commit 0d85038

File tree

3 files changed

+52
-21
lines changed

3 files changed

+52
-21
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/**
2+
* @license
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
import {
18+
getAI,
19+
getGenerativeModel,
20+
} from './api';
21+
import { expect } from 'chai';
22+
import { InferenceMode } from './public-types';
23+
import { getFullApp } from '../test-utils/get-fake-firebase-services';
24+
import { DEFAULT_HYBRID_IN_CLOUD_MODEL } from './constants';
25+
import { factory } from './factory-browser';
26+
27+
/**
28+
* Browser-only top level API tests using a factory that provides
29+
* a ChromeAdapter.
30+
*/
31+
describe('Top level API', () => {
32+
describe('getAI()', () => {
33+
it('getGenerativeModel with HybridParams sets a default model', () => {
34+
const ai = getAI(getFullApp({ apiKey: 'key', appId: 'id'}, factory));
35+
const genModel = getGenerativeModel(ai, {
36+
mode: InferenceMode.ONLY_ON_DEVICE,
37+
});
38+
expect(genModel.model).to.equal(
39+
`models/${DEFAULT_HYBRID_IN_CLOUD_MODEL}`
40+
);
41+
});
42+
it('getGenerativeModel with HybridParams honors a model override', () => {
43+
const ai = getAI(getFullApp({ apiKey: 'key', appId: 'id'}, factory));
44+
const genModel = getGenerativeModel(ai, {
45+
mode: InferenceMode.PREFER_ON_DEVICE,
46+
inCloudParams: { model: 'my-model' }
47+
});
48+
expect(genModel.model).to.equal('models/my-model');
49+
});
50+
});
51+
});

packages/ai/src/api.test.ts

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,6 @@ describe('Top level API', () => {
144144
expect(genModel).to.be.an.instanceOf(GenerativeModel);
145145
expect(genModel.model).to.equal('publishers/google/models/my-model');
146146
});
147-
it('getGenerativeModel with HybridParams sets a default model', () => {
148-
const genModel = getGenerativeModel(fakeAI, {
149-
mode: 'only_on_device'
150-
});
151-
expect(genModel.model).to.equal(
152-
`publishers/google/models/${DEFAULT_HYBRID_IN_CLOUD_MODEL}`
153-
);
154-
});
155-
it('getGenerativeModel with HybridParams honors a model override', () => {
156-
const genModel = getGenerativeModel(fakeAI, {
157-
mode: 'prefer_on_device',
158-
inCloudParams: { model: 'my-model' }
159-
});
160-
expect(genModel.model).to.equal('publishers/google/models/my-model');
161-
});
162147
it('getImagenModel throws if no model is provided', () => {
163148
try {
164149
getImagenModel(fakeAI, {} as ImagenModelParams);

packages/ai/src/factory-node.ts

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,5 @@ export function factory(
4242
const auth = container.getProvider('auth-internal');
4343
const appCheckProvider = container.getProvider('app-check-internal');
4444

45-
return new AIService(
46-
app,
47-
backend,
48-
auth,
49-
appCheckProvider
50-
);
45+
return new AIService(app, backend, auth, appCheckProvider);
5146
}

0 commit comments

Comments
 (0)