1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import os
15- from typing import Literal , Optional
15+ from typing import Literal , Optional , Union , Any
1616from google import genai
1717from google .genai import types
1818import termcolor
3333console = Console ()
3434
3535
36+ # Built-in Computer Use tools will return "EnvState".
37+ # Custom provided functions will return "dict".
38+ FunctionResponseT = Union [EnvState , dict ]
39+
40+
41+ def multiply_numbers (x : float , y : float ) -> dict :
42+ """Multiplies two numbers."""
43+ return {"result" : x * y }
44+
45+
3646class BrowserAgent :
3747 def __init__ (
3848 self ,
@@ -50,8 +60,10 @@ def __init__(
5060 location = os .environ .get ("VERTEXAI_LOCATION" ),
5161 http_options = types .HttpOptions (
5262 api_version = "v1alpha" ,
53- base_url = "https://generativelanguage.googleapis.com" ,
54- )
63+ base_url = os .environ .get (
64+ "GEMINI_API_SERVER" , "https://generativelanguage.googleapis.com"
65+ ),
66+ ),
5567 )
5668 self ._contents : list [Content ] = [
5769 Content (
@@ -61,21 +73,36 @@ def __init__(
6173 ],
6274 )
6375 ]
76+
77+ # Exclude any predefined functions here.
78+ excluded_predefined_functions = []
79+
80+ # Add your own custom functions here.
81+ custom_functions = [
82+ # For example:
83+ types .FunctionDeclaration .from_callable (
84+ client = self ._client , callable = multiply_numbers
85+ )
86+ ]
87+
6488 self ._generate_content_config = GenerateContentConfig (
6589 temperature = 1 ,
6690 top_p = 0.95 ,
6791 top_k = 40 ,
6892 max_output_tokens = 8192 ,
6993 tools = [
7094 types .Tool (
71- computer_use = types .ComputerUse (
72- environment = types .Environment .ENVIRONMENT_BROWSER
73- )
74- )
95+ computer_use = types .ToolComputerUse (
96+ environment = types .Environment .ENVIRONMENT_BROWSER ,
97+ excluded_predefined_functions = excluded_predefined_functions ,
98+ ),
99+ ),
100+ types .Tool (function_declarations = custom_functions ),
75101 ],
102+ thinking_config = types .ThinkingConfig (include_thoughts = True ),
76103 )
77104
78- def handle_action (self , action : types .FunctionCall ) -> EnvState :
105+ def handle_action (self , action : types .FunctionCall ) -> FunctionResponseT :
79106 """Handles the action and returns the environment state."""
80107 if action .name == "open_web_browser" :
81108 return self ._browser_computer .open_web_browser ()
@@ -96,7 +123,7 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
96123 elif action .name == "type_text_at" :
97124 x = self .normalize_x (action .args ["x" ])
98125 y = self .normalize_y (action .args ["y" ])
99- press_enter = action .args .get ("press_enter" , True )
126+ press_enter = action .args .get ("press_enter" , False )
100127 clear_before_typing = action .args .get ("clear_before_typing" , True )
101128 return self ._browser_computer .type_text_at (
102129 x = x ,
@@ -110,7 +137,7 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
110137 elif action .name == "scroll_at" :
111138 x = self .normalize_x (action .args ["x" ])
112139 y = self .normalize_y (action .args ["y" ])
113- magnitude = action .args .get ("magnitude" , 200 )
140+ magnitude = action .args .get ("magnitude" , 800 )
114141 direction = action .args ["direction" ]
115142
116143 if direction in ("up" , "down" ):
@@ -147,6 +174,9 @@ def handle_action(self, action: types.FunctionCall) -> EnvState:
147174 destination_x = destination_x ,
148175 destination_y = destination_y ,
149176 )
177+ # Handle the custom function declarations here.
178+ elif action .name == multiply_numbers .__name__ :
179+ return multiply_numbers (x = action .args ["x" ], y = action .args ["y" ])
150180 else :
151181 raise ValueError (f"Unsupported function: { action } " )
152182
@@ -189,12 +219,13 @@ def get_text(self, candidate: Candidate) -> Optional[str]:
189219 text .append (part .text )
190220 return " " .join (text ) or None
191221
192- def get_function_call (self , candidate : Candidate ) -> Optional [types .FunctionCall ]:
222+ def extract_function_calls (self , candidate : Candidate ) -> list [types .FunctionCall ]:
193223 """Extracts the function call from the candidate."""
224+ ret = []
194225 for part in candidate .content .parts :
195226 if part .function_call :
196- return part .function_call
197- return None
227+ ret . append ( part .function_call )
228+ return ret
198229
199230 def run_one_iteration (self ) -> Literal ["COMPLETE" , "CONTINUE" ]:
200231 # Generate a response from the model.
@@ -204,89 +235,100 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
204235 except Exception as e :
205236 return "COMPLETE"
206237
238+ if not response .candidates :
239+ print ("Response has no candidates!" )
240+ print (response )
241+ raise ValueError ("Empty response" )
242+
207243 # Extract the text and function call from the response.
208244 candidate = response .candidates [0 ]
209- reasoning = self .get_text (candidate )
210- function_call = self .get_function_call (candidate )
211-
212- # Append the model turn.
245+ # Append the model turn to conversation history.
213246 self ._contents .append (candidate .content )
214247
215- if not function_call or not function_call .name :
248+ reasoning = self .get_text (candidate )
249+ function_calls = self .extract_function_calls (candidate )
250+ if not function_calls :
216251 print (f"Agent Loop Complete: { reasoning } " )
217252 return "COMPLETE"
218253
219- # Print the function call and any reasoning.
220- function_call_str = f"Name: { function_call .name } "
221- if function_call .args :
222- function_call_str += f"\n Args:"
223- for key , value in function_call .args .items ():
224- function_call_str += f"\n { key } : { value } "
254+ function_call_strs = []
255+ for function_call in function_calls :
256+ # Print the function call and any reasoning.
257+ function_call_str = f"Name: { function_call .name } "
258+ if function_call .args :
259+ function_call_str += f"\n Args:"
260+ for key , value in function_call .args .items ():
261+ function_call_str += f"\n { key } : { value } "
262+ function_call_strs .append (function_call_str )
263+
225264 table = Table (expand = True )
226265 table .add_column ("Gemini Reasoning" , header_style = "magenta" , ratio = 1 )
227- table .add_column ("Function Call" , header_style = "cyan" , ratio = 1 )
228- table .add_row (
229- reasoning ,
230- function_call_str ,
231- )
266+ table .add_column ("Function Call(s)" , header_style = "cyan" , ratio = 1 )
267+ table .add_row (reasoning , "\n " .join (function_call_strs ))
232268 console .print (table )
233269 print ()
234270
235- if safety := function_call .args .get ("safety_decision" ):
236- if safety ["decision" ] == "block" :
237- termcolor .cprint (
238- "Terminating loop due to safety block!" ,
239- color = "yellow" ,
240- attrs = ["bold" ],
271+ function_responses = []
272+ for function_call in function_calls :
273+ if function_call .args and (
274+ safety := function_call .args .get ("safety_decision" )
275+ ):
276+ decision = self ._get_safety_confirmation (safety )
277+ if decision == "TERMINATE" :
278+ print ("Terminating agent loop" )
279+ return "COMPLETE"
280+ with console .status ("Sending command to Computer..." , spinner_style = None ):
281+ fc_result = self .handle_action (function_call )
282+ if isinstance (fc_result , EnvState ):
283+ function_responses .append (
284+ FunctionResponse (
285+ name = function_call .name ,
286+ response = {
287+ "image" : {
288+ "mimetype" : "image/png" ,
289+ "data" : base64 .b64encode (fc_result .screenshot ).decode (
290+ "utf-8"
291+ ),
292+ },
293+ "url" : fc_result .url ,
294+ },
295+ )
241296 )
242- print (safety ["explanation" ])
243- return "COMPLETE"
244- elif safety ["decision" ] == "require_confirmation" :
245- termcolor .cprint (
246- "Safety service requires explicit confirmation!" ,
247- color = "yellow" ,
248- attrs = ["bold" ],
297+ elif isinstance (fc_result , dict ):
298+ function_responses .append (
299+ FunctionResponse (name = function_call .name , response = fc_result )
249300 )
250- print (safety ["explanation" ])
251- decision = ""
252- while decision .lower () not in ("y" , "n" , "ye" , "yes" , "no" ):
253- decision = input ("Do you wish to proceed? [Y]es/[n]o\n " )
254- if decision .lower () in ("n" , "no" ):
255- print ("Terminating agent loop." )
256- return "COMPLETE"
257- print ("Proceeding with agent loop.\n " )
258-
259- with console .status ("Sending command to Computer..." , spinner_style = None ):
260- environment_state = self .handle_action (function_call )
261301
262302 self ._contents .append (
263303 Content (
264304 role = "user" ,
265- parts = [
266- Part (
267- function_response = FunctionResponse (
268- name = function_call .name ,
269- response = {
270- "image" : {
271- "mimetype" : "image/png" ,
272- "data" : base64 .b64encode (
273- environment_state .screenshot
274- ).decode ("utf-8" ),
275- },
276- "url" : environment_state .url ,
277- },
278- )
279- )
280- ],
305+ parts = [Part (function_response = fr ) for fr in function_responses ],
281306 )
282307 )
283308 return "CONTINUE"
284309
310+ def _get_safety_confirmation (
311+ self , safety : dict [str , Any ]
312+ ) -> Literal ["CONTINUE" , "TERMINATE" ]:
313+ if safety ["decision" ] != "require_confirmation" :
314+ raise ValueError (f"Unknown safety decision: safety['decision']" )
315+ termcolor .cprint (
316+ "Safety service requires explicit confirmation!" ,
317+ color = "yellow" ,
318+ attrs = ["bold" ],
319+ )
320+ print (safety ["explanation" ])
321+ decision = ""
322+ while decision .lower () not in ("y" , "n" , "ye" , "yes" , "no" ):
323+ decision = input ("Do you wish to proceed? [Y]es/[n]o\n " )
324+ if decision .lower () in ("n" , "no" ):
325+ return "TERMINATE"
326+ return "CONTINUE"
327+
285328 def agent_loop (self ):
286- while True :
329+ status = "CONTINUE"
330+ while status == "CONTINUE" :
287331 status = self .run_one_iteration ()
288- if status == "COMPLETE" :
289- return
290332
291333 def normalize_x (self , x : int ) -> int :
292334 return int (x / 1000 * self ._browser_computer .screen_size ()[0 ])
0 commit comments