Skip to content

Commit 160997e

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support fully override base_url and raw model name when none of the project, locations, api_key are configured
PiperOrigin-RevId: 819363064
1 parent 9c8147b commit 160997e

File tree

3 files changed

+203
-117
lines changed

3 files changed

+203
-117
lines changed

google/genai/_api_client.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def __init__(
547547
http_options: Optional[HttpOptionsOrDict] = None,
548548
):
549549
self.vertexai = vertexai
550+
self.custom_base_url = None
550551
if self.vertexai is None:
551552
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
552553
'true',
@@ -628,29 +629,37 @@ def __init__(
628629
)
629630
self.api_key = None
630631

632+
self.custom_base_url = (
633+
validated_http_options.base_url
634+
if validated_http_options.base_url
635+
else None
636+
)
637+
631638
# Skip fetching project from ADC if base url is provided in http options.
632639
if (
633640
not self.project
634641
and not self.api_key
635-
and not validated_http_options.base_url
642+
and not self.custom_base_url
636643
):
637644
credentials, self.project = load_auth(project=None)
638645
if not self._credentials:
639646
self._credentials = credentials
640647

641648
has_sufficient_auth = (self.project and self.location) or self.api_key
642649

643-
if not has_sufficient_auth and not validated_http_options.base_url:
650+
if not has_sufficient_auth and not self.custom_base_url:
644651
# Skip sufficient auth check if base url is provided in http options.
645652
raise ValueError(
646653
'Project and location or API key must be set when using the Vertex '
647654
'AI API.'
648655
)
649656
if self.api_key or self.location == 'global':
650657
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
651-
elif validated_http_options.base_url and not has_sufficient_auth:
658+
elif self.custom_base_url and not has_sufficient_auth:
652659
# Avoid setting default base url and api version if base_url provided.
653-
self._http_options.base_url = validated_http_options.base_url
660+
# API gateway proxy can use the auth in custom headers, not url.
661+
# Enable custom url if auth is not sufficient.
662+
self._http_options.base_url = self.custom_base_url
654663
else:
655664
self._http_options.base_url = (
656665
f'https://{self.location}-aiplatform.googleapis.com/'
@@ -897,6 +906,11 @@ def _use_aiohttp(self) -> bool:
897906
)
898907

899908
def _websocket_base_url(self) -> str:
909+
has_sufficient_auth = (self.project and self.location) or self.api_key
910+
if self.custom_base_url and not has_sufficient_auth:
911+
# API gateway proxy can use the auth in custom headers, not url.
912+
# Enable custom url if auth is not sufficient.
913+
return self.custom_base_url
900914
url_parts = urlparse(self._http_options.base_url)
901915
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
902916

google/genai/live.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,8 @@ async def connect(
980980
api_key = self._api_client.api_key
981981
version = self._api_client._http_options.api_version
982982
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
983-
headers = self._api_client._http_options.headers or {}
983+
original_headers = self._api_client._http_options.headers
984+
headers = original_headers.copy() if original_headers is not None else {}
984985

985986
request_dict = _common.convert_to_dict(
986987
live_converters._LiveConnectParameters_to_vertex(
@@ -1012,12 +1013,24 @@ async def connect(
10121013
bearer_token = creds.token
10131014
original_headers = self._api_client._http_options.headers
10141015
headers = original_headers.copy() if original_headers is not None else {}
1015-
headers['Authorization'] = f'Bearer {bearer_token}'
1016+
if not headers.get('Authorization'):
1017+
headers['Authorization'] = f'Bearer {bearer_token}'
10161018
version = self._api_client._http_options.api_version
1017-
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
1019+
1020+
has_sufficient_auth = (
1021+
self._api_client.project and self._api_client.location
1022+
)
1023+
if self._api_client.custom_base_url and not has_sufficient_auth:
1024+
# API gateway proxy can use the auth in custom headers, not url.
1025+
# Enable custom url if auth is not sufficient.
1026+
uri = self._api_client.custom_base_url
1027+
# Keep the model as is.
1028+
transformed_model = model
1029+
else:
1030+
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
10181031
location = self._api_client.location
10191032
project = self._api_client.project
1020-
if transformed_model.startswith('publishers/'):
1033+
if transformed_model.startswith('publishers/') and project and location:
10211034
transformed_model = (
10221035
f'projects/{project}/locations/{location}/' + transformed_model
10231036
)

0 commit comments

Comments
 (0)