Skip to content

Commit 4ef7e9a

Browse files
authored
Merge pull request #27 from google/ericpts/07_16
Prepare code for 07/16 iteration.
2 parents e5f1e7d + 0b3310c commit 4ef7e9a

File tree

9 files changed

+333
-77
lines changed

9 files changed

+333
-77
lines changed

agent.py

Lines changed: 114 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15-
from typing import Literal, Optional
15+
from typing import Literal, Optional, Union, Any
1616
from google import genai
1717
from google.genai import types
1818
import termcolor
@@ -33,6 +33,16 @@
3333
console = Console()
3434

3535

36+
# Built-in Computer Use tools will return "EnvState".
37+
# Custom (user-provided) functions will return "dict".
38+
FunctionResponseT = Union[EnvState, dict]
39+
40+
41+
def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers.

    Example custom function exposed to the model alongside the built-in
    Computer Use tools.

    Args:
        x: The first factor.
        y: The second factor.

    Returns:
        A dict with a single "result" key holding the product of x and y.
    """
    product = x * y
    return dict(result=product)
44+
45+
3646
class BrowserAgent:
3747
def __init__(
3848
self,
@@ -50,8 +60,10 @@ def __init__(
5060
location=os.environ.get("VERTEXAI_LOCATION"),
5161
http_options=types.HttpOptions(
5262
api_version="v1alpha",
53-
base_url="https://generativelanguage.googleapis.com",
54-
)
63+
base_url=os.environ.get(
64+
"GEMINI_API_SERVER", "https://generativelanguage.googleapis.com"
65+
),
66+
),
5567
)
5668
self._contents: list[Content] = [
5769
Content(
@@ -61,21 +73,36 @@ def __init__(
6173
],
6274
)
6375
]
76+
77+
# Exclude any predefined functions here.
78+
excluded_predefined_functions = []
79+
80+
# Add your own custom functions here.
81+
custom_functions = [
82+
# For example:
83+
types.FunctionDeclaration.from_callable(
84+
client=self._client, callable=multiply_numbers
85+
)
86+
]
87+
6488
self._generate_content_config = GenerateContentConfig(
6589
temperature=1,
6690
top_p=0.95,
6791
top_k=40,
6892
max_output_tokens=8192,
6993
tools=[
7094
types.Tool(
71-
computer_use=types.ComputerUse(
72-
environment=types.Environment.ENVIRONMENT_BROWSER
73-
)
74-
)
95+
computer_use=types.ToolComputerUse(
96+
environment=types.Environment.ENVIRONMENT_BROWSER,
97+
excluded_predefined_functions=excluded_predefined_functions,
98+
),
99+
),
100+
types.Tool(function_declarations=custom_functions),
75101
],
102+
thinking_config=types.ThinkingConfig(include_thoughts=True),
76103
)
77104

78-
def handle_action(self, action: types.FunctionCall) -> EnvState:
105+
def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
79106
"""Handles the action and returns the environment state."""
80107
if action.name == "open_web_browser":
81108
return self._browser_computer.open_web_browser()
@@ -96,7 +123,7 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
96123
elif action.name == "type_text_at":
97124
x = self.normalize_x(action.args["x"])
98125
y = self.normalize_y(action.args["y"])
99-
press_enter = action.args.get("press_enter", True)
126+
press_enter = action.args.get("press_enter", False)
100127
clear_before_typing = action.args.get("clear_before_typing", True)
101128
return self._browser_computer.type_text_at(
102129
x=x,
@@ -110,7 +137,7 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
110137
elif action.name == "scroll_at":
111138
x = self.normalize_x(action.args["x"])
112139
y = self.normalize_y(action.args["y"])
113-
magnitude = action.args.get("magnitude", 200)
140+
magnitude = action.args.get("magnitude", 800)
114141
direction = action.args["direction"]
115142

116143
if direction in ("up", "down"):
@@ -147,6 +174,9 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
147174
destination_x=destination_x,
148175
destination_y=destination_y,
149176
)
177+
# Handle the custom function declarations here.
178+
elif action.name == multiply_numbers.__name__:
179+
return multiply_numbers(x=action.args["x"], y=action.args["y"])
150180
else:
151181
raise ValueError(f"Unsupported function: {action}")
152182

@@ -189,12 +219,13 @@ def get_text(self, candidate: Candidate) -> Optional[str]:
189219
text.append(part.text)
190220
return " ".join(text) or None
191221

192-
def get_function_call(self, candidate: Candidate) -> Optional[types.FunctionCall]:
222+
def extract_function_calls(self, candidate: Candidate) -> list[types.FunctionCall]:
193223
"""Extracts the function call from the candidate."""
224+
ret = []
194225
for part in candidate.content.parts:
195226
if part.function_call:
196-
return part.function_call
197-
return None
227+
ret.append(part.function_call)
228+
return ret
198229

199230
def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
200231
# Generate a response from the model.
@@ -204,89 +235,100 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
204235
except Exception as e:
205236
return "COMPLETE"
206237

238+
if not response.candidates:
239+
print("Response has no candidates!")
240+
print(response)
241+
raise ValueError("Empty response")
242+
207243
# Extract the text and function call from the response.
208244
candidate = response.candidates[0]
209-
reasoning = self.get_text(candidate)
210-
function_call = self.get_function_call(candidate)
211-
212-
# Append the model turn.
245+
# Append the model turn to conversation history.
213246
self._contents.append(candidate.content)
214247

215-
if not function_call or not function_call.name:
248+
reasoning = self.get_text(candidate)
249+
function_calls = self.extract_function_calls(candidate)
250+
if not function_calls:
216251
print(f"Agent Loop Complete: {reasoning}")
217252
return "COMPLETE"
218253

219-
# Print the function call and any reasoning.
220-
function_call_str = f"Name: {function_call.name}"
221-
if function_call.args:
222-
function_call_str += f"\nArgs:"
223-
for key, value in function_call.args.items():
224-
function_call_str += f"\n {key}: {value}"
254+
function_call_strs = []
255+
for function_call in function_calls:
256+
# Print the function call and any reasoning.
257+
function_call_str = f"Name: {function_call.name}"
258+
if function_call.args:
259+
function_call_str += f"\nArgs:"
260+
for key, value in function_call.args.items():
261+
function_call_str += f"\n {key}: {value}"
262+
function_call_strs.append(function_call_str)
263+
225264
table = Table(expand=True)
226265
table.add_column("Gemini Reasoning", header_style="magenta", ratio=1)
227-
table.add_column("Function Call", header_style="cyan", ratio=1)
228-
table.add_row(
229-
reasoning,
230-
function_call_str,
231-
)
266+
table.add_column("Function Call(s)", header_style="cyan", ratio=1)
267+
table.add_row(reasoning, "\n".join(function_call_strs))
232268
console.print(table)
233269
print()
234270

235-
if safety := function_call.args.get("safety_decision"):
236-
if safety["decision"] == "block":
237-
termcolor.cprint(
238-
"Terminating loop due to safety block!",
239-
color="yellow",
240-
attrs=["bold"],
271+
function_responses = []
272+
for function_call in function_calls:
273+
if function_call.args and (
274+
safety := function_call.args.get("safety_decision")
275+
):
276+
decision = self._get_safety_confirmation(safety)
277+
if decision == "TERMINATE":
278+
print("Terminating agent loop")
279+
return "COMPLETE"
280+
with console.status("Sending command to Computer...", spinner_style=None):
281+
fc_result = self.handle_action(function_call)
282+
if isinstance(fc_result, EnvState):
283+
function_responses.append(
284+
FunctionResponse(
285+
name=function_call.name,
286+
response={
287+
"image": {
288+
"mimetype": "image/png",
289+
"data": base64.b64encode(fc_result.screenshot).decode(
290+
"utf-8"
291+
),
292+
},
293+
"url": fc_result.url,
294+
},
295+
)
241296
)
242-
print(safety["explanation"])
243-
return "COMPLETE"
244-
elif safety["decision"] == "require_confirmation":
245-
termcolor.cprint(
246-
"Safety service requires explicit confirmation!",
247-
color="yellow",
248-
attrs=["bold"],
297+
elif isinstance(fc_result, dict):
298+
function_responses.append(
299+
FunctionResponse(name=function_call.name, response=fc_result)
249300
)
250-
print(safety["explanation"])
251-
decision = ""
252-
while decision.lower() not in ("y", "n", "ye", "yes", "no"):
253-
decision = input("Do you wish to proceed? [Y]es/[n]o\n")
254-
if decision.lower() in ("n", "no"):
255-
print("Terminating agent loop.")
256-
return "COMPLETE"
257-
print("Proceeding with agent loop.\n")
258-
259-
with console.status("Sending command to Computer...", spinner_style=None):
260-
environment_state = self.handle_action(function_call)
261301

262302
self._contents.append(
263303
Content(
264304
role="user",
265-
parts=[
266-
Part(
267-
function_response=FunctionResponse(
268-
name=function_call.name,
269-
response={
270-
"image": {
271-
"mimetype": "image/png",
272-
"data": base64.b64encode(
273-
environment_state.screenshot
274-
).decode("utf-8"),
275-
},
276-
"url": environment_state.url,
277-
},
278-
)
279-
)
280-
],
305+
parts=[Part(function_response=fr) for fr in function_responses],
281306
)
282307
)
283308
return "CONTINUE"
284309

310+
def _get_safety_confirmation(
311+
self, safety: dict[str, Any]
312+
) -> Literal["CONTINUE", "TERMINATE"]:
313+
if safety["decision"] != "require_confirmation":
314+
raise ValueError(f"Unknown safety decision: safety['decision']")
315+
termcolor.cprint(
316+
"Safety service requires explicit confirmation!",
317+
color="yellow",
318+
attrs=["bold"],
319+
)
320+
print(safety["explanation"])
321+
decision = ""
322+
while decision.lower() not in ("y", "n", "ye", "yes", "no"):
323+
decision = input("Do you wish to proceed? [Y]es/[n]o\n")
324+
if decision.lower() in ("n", "no"):
325+
return "TERMINATE"
326+
return "CONTINUE"
327+
285328
def agent_loop(self):
286-
while True:
329+
status = "CONTINUE"
330+
while status == "CONTINUE":
287331
status = self.run_one_iteration()
288-
if status == "COMPLETE":
289-
return
290332

291333
def normalize_x(self, x: int) -> int:
292334
return int(x / 1000 * self._browser_computer.screen_size()[0])

apiserver/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ async def delete_session(
188188

189189

190190
# Static HTML5 to test the API.
191-
app.mount("/", StaticFiles(directory="static", html=True), name="static")
191+
static_dir = os.path.join(os.path.dirname(__file__), "static")
192+
app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
192193

193194
if __name__ == "__main__":
194195
port = int(os.environ.get("PORT", "8000"))

computers/playwright/playwright.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import termcolor
1515
import time
16+
import sys
1617
from ..computer import (
1718
Computer,
1819
EnvState,
@@ -144,15 +145,18 @@ def type_text_at(
144145
x: int,
145146
y: int,
146147
text: str,
147-
press_enter: bool = True,
148+
press_enter: bool = False,
148149
clear_before_typing: bool = True,
149150
) -> EnvState:
150151
self.highlight_mouse(x, y)
151152
self._page.mouse.click(x, y)
152153
self._page.wait_for_load_state()
153154

154155
if clear_before_typing:
155-
self.key_combination(["Control", "A"])
156+
if sys.platform == "darwin":
157+
self.key_combination(["Command", "A"])
158+
else:
159+
self.key_combination(["Control", "A"])
156160
self.key_combination(["Delete"])
157161

158162
self._page.keyboard.type(text)

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
CLOUD_RUN_SCREEN_SIZE = (1920, 1080)
22-
PLAYWRIGHT_SCREEN_SIZE = (1440, 810)
22+
PLAYWRIGHT_SCREEN_SIZE = (1920, 1080)
2323

2424

2525
def main() -> int:

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
termcolor==3.1.0
22
pydantic==2.11.4
3-
./sdk/google_genai-1.14.0-py3-none-any.whl
3+
./sdk/google_genai-1.25.0-py3-none-any.whl
44
playwright==1.52.0
55
browserbase==1.3.0
66
rich
-161 KB
Binary file not shown.
204 KB
Binary file not shown.

0 commit comments

Comments
 (0)