Skip to content

Commit e68fb50

Browse files
committed
Add tests for streaming chat sesssions to ensure thought signatures are passed back.
1 parent 9cc0669 commit e68fb50

File tree

1 file changed

+130
-0
lines changed
  • firebase-ai/src/test/java/com/google/firebase/ai

1 file changed

+130
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.ai
18+
19+
import com.google.firebase.FirebaseApp
20+
import com.google.firebase.ai.common.APIController
21+
import com.google.firebase.ai.common.util.doBlocking
22+
import com.google.firebase.ai.type.RequestOptions
23+
import com.google.firebase.ai.type.TextPart
24+
import io.ktor.client.engine.mock.MockEngine
25+
import io.ktor.client.engine.mock.respond
26+
import io.ktor.http.HttpHeaders
27+
import io.ktor.http.HttpStatusCode
28+
import io.ktor.http.headersOf
29+
import kotlin.time.Duration.Companion.seconds
30+
import kotlinx.coroutines.flow.collect
31+
import kotlinx.coroutines.withTimeout
32+
import org.junit.Before
33+
import org.junit.Test
34+
import org.mockito.Mockito
35+
import com.google.common.truth.Truth.assertThat
36+
import com.google.firebase.ai.type.content
37+
38+
class ChatTest {
39+
private val TEST_CLIENT_ID = "test"
40+
private val TEST_APP_ID = "1:android:12345"
41+
private val TEST_VERSION = 1
42+
43+
private var mockFirebaseApp: FirebaseApp = Mockito.mock<FirebaseApp>()
44+
45+
@Before
46+
fun setup() {
47+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
48+
}
49+
50+
@Test
51+
fun `sendMessageStream preserves thoughtSignature in history`() = doBlocking {
52+
val mockResponse = """
53+
[
54+
{
55+
"candidates": [
56+
{
57+
"content": {
58+
"role": "model",
59+
"parts": [
60+
{
61+
"text": "This is a thought.",
62+
"thought": true,
63+
"thoughtSignature": "thought1"
64+
}
65+
]
66+
}
67+
}
68+
]
69+
},
70+
{
71+
"candidates": [
72+
{
73+
"content": {
74+
"role": "model",
75+
"parts": [
76+
{
77+
"text": "This is not a thought."
78+
}
79+
]
80+
}
81+
}
82+
]
83+
}
84+
]
85+
""".trimIndent()
86+
val mockEngine = MockEngine {
87+
respond(
88+
mockResponse,
89+
HttpStatusCode.OK,
90+
headersOf(HttpHeaders.ContentType, "application/json")
91+
)
92+
}
93+
94+
val apiController =
95+
APIController(
96+
"super_cool_test_key",
97+
"gemini-2.5-flash",
98+
RequestOptions(timeout = 5.seconds),
99+
mockEngine,
100+
TEST_CLIENT_ID,
101+
mockFirebaseApp,
102+
TEST_VERSION,
103+
TEST_APP_ID,
104+
null,
105+
)
106+
107+
val generativeModel =
108+
GenerativeModel(
109+
"gemini-2.5-flash",
110+
controller = apiController
111+
)
112+
val chat = Chat(generativeModel)
113+
114+
withTimeout(5.seconds) {
115+
chat.sendMessageStream("my test prompt").collect()
116+
}
117+
118+
val history = chat.history
119+
assertThat(history).hasSize(2)
120+
val modelResponse = history[1]
121+
assertThat(modelResponse.role).isEqualTo("model")
122+
assertThat(modelResponse.parts).hasSize(2)
123+
val thoughtPart = modelResponse.parts[0] as TextPart
124+
assertThat(thoughtPart.isThought).isTrue()
125+
assertThat(thoughtPart.thoughtSignature).isEqualTo("thought1")
126+
val textPart = modelResponse.parts[1] as TextPart
127+
assertThat(textPart.isThought).isFalse()
128+
assertThat(textPart.thoughtSignature).isNull()
129+
}
130+
}

0 commit comments

Comments
 (0)