Skip to content

Commit 42849c3

Browse files
committed
Prepare for the next TT release.
1 parent e5f1e7d commit 42849c3

File tree

4 files changed

+103
-66
lines changed

4 files changed

+103
-66
lines changed

agent.py

Lines changed: 102 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15-
from typing import Literal, Optional
15+
from typing import Literal, Optional, Union
1616
from google import genai
1717
from google.genai import types
1818
import termcolor
@@ -33,6 +33,16 @@
3333
console = Console()
3434

3535

36+
# Built-in Computer Use tools will return "EnvState".
37+
# Custom provided functions will return "dict".
38+
FunctionResponseT = Union[EnvState, dict]
39+
40+
41+
def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers.

    Args:
        x: The first factor.
        y: The second factor.

    Returns:
        A dict with a single "result" key holding the product of x and y.
    """
    # The original body returned x + y, which contradicts the function's
    # name and docstring; the product is what callers are promised.
    return {"result": x * y}
44+
45+
3646
class BrowserAgent:
3747
def __init__(
3848
self,
@@ -50,8 +60,10 @@ def __init__(
5060
location=os.environ.get("VERTEXAI_LOCATION"),
5161
http_options=types.HttpOptions(
5262
api_version="v1alpha",
53-
base_url="https://generativelanguage.googleapis.com",
54-
)
63+
base_url=os.environ.get(
64+
"GEMINI_API_SERVER", "https://generativelanguage.googleapis.com"
65+
),
66+
),
5567
)
5668
self._contents: list[Content] = [
5769
Content(
@@ -61,21 +73,35 @@ def __init__(
6173
],
6274
)
6375
]
76+
77+
# Exclude any predefined functions here.
78+
excluded_predefined_functions = []
79+
80+
# Add your own custom functions here.
81+
custom_functions = [
82+
# For example:
83+
types.FunctionDeclaration.from_callable(
84+
client=self._client, callable=multiply_numbers
85+
)
86+
]
87+
6488
self._generate_content_config = GenerateContentConfig(
6589
temperature=1,
6690
top_p=0.95,
6791
top_k=40,
6892
max_output_tokens=8192,
6993
tools=[
7094
types.Tool(
71-
computer_use=types.ComputerUse(
72-
environment=types.Environment.ENVIRONMENT_BROWSER
73-
)
74-
)
95+
computer_use=types.ToolComputerUse(
96+
environment=types.Environment.ENVIRONMENT_BROWSER,
97+
excluded_predefined_functions=excluded_predefined_functions,
98+
),
99+
),
100+
types.Tool(function_declarations=custom_functions),
75101
],
76102
)
77103

78-
def handle_action(self, action: types.FunctionCall) -> EnvState:
104+
def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
79105
"""Handles the action and returns the environment state."""
80106
if action.name == "open_web_browser":
81107
return self._browser_computer.open_web_browser()
@@ -96,7 +122,7 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
96122
elif action.name == "type_text_at":
97123
x = self.normalize_x(action.args["x"])
98124
y = self.normalize_y(action.args["y"])
99-
press_enter = action.args.get("press_enter", True)
125+
press_enter = action.args.get("press_enter", False)
100126
clear_before_typing = action.args.get("clear_before_typing", True)
101127
return self._browser_computer.type_text_at(
102128
x=x,
@@ -110,7 +136,7 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
110136
elif action.name == "scroll_at":
111137
x = self.normalize_x(action.args["x"])
112138
y = self.normalize_y(action.args["y"])
113-
magnitude = action.args.get("magnitude", 200)
139+
magnitude = action.args.get("magnitude", 800)
114140
direction = action.args["direction"]
115141

116142
if direction in ("up", "down"):
@@ -147,6 +173,9 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
147173
destination_x=destination_x,
148174
destination_y=destination_y,
149175
)
176+
# Handle the custom function declarations here.
177+
elif action.name == multiply_numbers.__name__:
178+
return multiply_numbers(x=action.args["x"], y=action.args["y"])
150179
else:
151180
raise ValueError(f"Unsupported function: {action}")
152181

@@ -189,12 +218,13 @@ def get_text(self, candidate: Candidate) -> Optional[str]:
189218
text.append(part.text)
190219
return " ".join(text) or None
191220

192-
def get_function_call(self, candidate: Candidate) -> Optional[types.FunctionCall]:
221+
def extract_function_calls(self, candidate: Candidate) -> list[types.FunctionCall]:
193222
"""Extracts the function call from the candidate."""
223+
ret = []
194224
for part in candidate.content.parts:
195225
if part.function_call:
196-
return part.function_call
197-
return None
226+
ret.append(part.function_call)
227+
return ret
198228

199229
def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
200230
# Generate a response from the model.
@@ -204,44 +234,75 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
204234
except Exception as e:
205235
return "COMPLETE"
206236

237+
if not response.candidates:
238+
print("Response has no candidates!")
239+
print(response)
240+
raise ValueError("Empty response")
241+
207242
# Extract the text and function call from the response.
208243
candidate = response.candidates[0]
209-
reasoning = self.get_text(candidate)
210-
function_call = self.get_function_call(candidate)
211-
212-
# Append the model turn.
244+
# Append the model turn to conversation history.
213245
self._contents.append(candidate.content)
214246

215-
if not function_call or not function_call.name:
247+
reasoning = self.get_text(candidate)
248+
function_calls = self.extract_function_calls(candidate)
249+
if not function_calls:
216250
print(f"Agent Loop Complete: {reasoning}")
217251
return "COMPLETE"
218252

219-
# Print the function call and any reasoning.
220-
function_call_str = f"Name: {function_call.name}"
221-
if function_call.args:
222-
function_call_str += f"\nArgs:"
223-
for key, value in function_call.args.items():
224-
function_call_str += f"\n {key}: {value}"
253+
function_call_strs = []
254+
for function_call in function_calls:
255+
# Print the function call and any reasoning.
256+
function_call_str = f"Name: {function_call.name}"
257+
if function_call.args:
258+
function_call_str += f"\nArgs:"
259+
for key, value in function_call.args.items():
260+
function_call_str += f"\n {key}: {value}"
261+
function_call_strs.append(function_call_str)
262+
225263
table = Table(expand=True)
226264
table.add_column("Gemini Reasoning", header_style="magenta", ratio=1)
227-
table.add_column("Function Call", header_style="cyan", ratio=1)
228-
table.add_row(
229-
reasoning,
230-
function_call_str,
231-
)
265+
table.add_column("Function Call(s)", header_style="cyan", ratio=1)
266+
table.add_row(reasoning, "\n".join(function_call_strs))
232267
console.print(table)
233268
print()
234269

235-
if safety := function_call.args.get("safety_decision"):
236-
if safety["decision"] == "block":
237-
termcolor.cprint(
238-
"Terminating loop due to safety block!",
239-
color="yellow",
240-
attrs=["bold"],
270+
function_responses = []
271+
for function_call in function_calls:
272+
fc_result = self._execute_function_call(function_call)
273+
if isinstance(fc_result, EnvState):
274+
function_responses.append(
275+
FunctionResponse(
276+
name=function_call.name,
277+
response={
278+
"image": {
279+
"mimetype": "image/png",
280+
"data": base64.b64encode(fc_result.screenshot).decode(
281+
"utf-8"
282+
),
283+
},
284+
"url": fc_result.url,
285+
},
286+
)
241287
)
242-
print(safety["explanation"])
243-
return "COMPLETE"
244-
elif safety["decision"] == "require_confirmation":
288+
elif isinstance(fc_result, dict):
289+
function_responses.append(
290+
FunctionResponse(name=function_call.name, response=fc_result)
291+
)
292+
293+
self._contents.append(
294+
Content(
295+
role="user",
296+
parts=[Part(function_response=fr) for fr in function_responses],
297+
)
298+
)
299+
return "CONTINUE"
300+
301+
def _execute_function_call(
302+
self, function_call: types.FunctionCall
303+
) -> FunctionResponseT:
304+
if safety := function_call.args.get("safety_decision"):
305+
if safety["decision"] == "require_confirmation":
245306
termcolor.cprint(
246307
"Safety service requires explicit confirmation!",
247308
color="yellow",
@@ -257,36 +318,12 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
257318
print("Proceeding with agent loop.\n")
258319

259320
with console.status("Sending command to Computer...", spinner_style=None):
260-
environment_state = self.handle_action(function_call)
261-
262-
self._contents.append(
263-
Content(
264-
role="user",
265-
parts=[
266-
Part(
267-
function_response=FunctionResponse(
268-
name=function_call.name,
269-
response={
270-
"image": {
271-
"mimetype": "image/png",
272-
"data": base64.b64encode(
273-
environment_state.screenshot
274-
).decode("utf-8"),
275-
},
276-
"url": environment_state.url,
277-
},
278-
)
279-
)
280-
],
281-
)
282-
)
283-
return "CONTINUE"
321+
return self.handle_action(function_call)
284322

285323
def agent_loop(self):
286-
while True:
324+
status = "CONTINUE"
325+
while status == "CONTINUE":
287326
status = self.run_one_iteration()
288-
if status == "COMPLETE":
289-
return
290327

291328
def normalize_x(self, x: int) -> int:
292329
return int(x / 1000 * self._browser_computer.screen_size()[0])

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
termcolor==3.1.0
22
pydantic==2.11.4
3-
./sdk/google_genai-1.14.0-py3-none-any.whl
3+
./sdk/google_genai-1.25.0-py3-none-any.whl
44
playwright==1.52.0
55
browserbase==1.3.0
66
rich
-161 KB
Binary file not shown.
204 KB
Binary file not shown.

0 commit comments

Comments
 (0)