Skip to content

Commit a009d51

Browse files
authored
Update get_services method to expose async service clients (#1037)
1 parent 179d430 commit a009d51

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

google/ads/googleads/client.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
_logger = logging.getLogger(__name__)
3737

3838
_SERVICE_CLIENT_TEMPLATE = "{}Client"
39+
_ASYNC_SERVICE_CLIENT_TEMPLATE = "{}AsyncClient"
3940

4041
_VALID_API_VERSIONS = ["v22", "v21", "v20", "v19"]
4142
_MESSAGE_TYPES = ["common", "enums", "errors", "resources", "services"]
@@ -360,6 +361,7 @@ def get_service(
360361
name: str,
361362
version: str = _DEFAULT_VERSION,
362363
interceptors: Union[list, None] = None,
364+
is_async: bool = False,
363365
) -> Any:
364366
"""Returns a service client instance for the specified service_name.
365367

@@ -372,6 +374,8 @@ def get_service(
372374
interceptors: an optional list of interceptors to include in
373375
requests. NOTE: this parameter is not intended for non-Google use
374376
and is not officially supported.
377+
is_async: whether or not to retrieve the async version of the
378+
service client being requested.
375379

376380
Returns:
377381
A service client instance associated with the given service_name.
@@ -391,8 +395,14 @@ def get_service(
391395

392396
try:
393397
service_module: Any = import_module(f"{services_path}.{snaked}")
398+
399+
if is_async:
400+
service_name = _ASYNC_SERVICE_CLIENT_TEMPLATE.format(name)
401+
else:
402+
service_name = _SERVICE_CLIENT_TEMPLATE.format(name)
403+
394404
service_client_class: Any = util.get_nested_attr(
395-
service_module, _SERVICE_CLIENT_TEMPLATE.format(name)
405+
service_module, service_name
396406
)
397407
except (AttributeError, ModuleNotFoundError):
398408
raise ValueError(

tests/client_test.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -490,18 +490,49 @@ def test_load_from_string_versioned(self):
490490
def test_get_service(self):
491491
# Retrieve service names for all defined service clients.
492492
for ver in valid_versions:
493-
services_path = f"google.ads.googleads.{ver}"
494-
service_names = [
495-
f'{name.rsplit("ServiceClient")[0]}Service'
496-
for name in dir(import_module(services_path))
497-
if "ServiceClient" in name
493+
services_filepath = f"google/ads/googleads/{ver}/services/services"
494+
# Retrieve list of all the services that exist under the
495+
# {version}/services/services directory.
496+
service_dir_names = [
497+
name for name in os.listdir(services_filepath) if name.endswith("_service")
498498
]
499499

500-
client = self._create_test_client()
500+
client = self._create_test_client(version=ver)
501+
502+
for dir_name in service_dir_names:
503+
# Converts from snake case to title case, for example:
504+
# google_ads_service --> GoogleAdsService
505+
service_name = ''.join(
506+
[part.capitalize() for part in dir_name.split("_")]
507+
)
508+
509+
# Load each service module
510+
svc = client.get_service(service_name)
511+
self.assertEqual(svc.__class__.__name__, f"{service_name}Client")
512+
513+
def test_get_async_service(self):
514+
# Retrieve service names for all defined service clients.
515+
for ver in valid_versions:
516+
services_filepath = f"google/ads/googleads/{ver}/services/services"
517+
# Retrieve list of all the services that exist under the
518+
# {version}/services/services directory.
519+
service_dir_names = [
520+
name for name in os.listdir(services_filepath) if name.endswith("_service")
521+
]
522+
523+
client = self._create_test_client(version=ver)
524+
525+
for dir_name in service_dir_names:
526+
# Converts from snake case to title case, for example:
527+
# google_ads_service --> GoogleAdsService
528+
service_name = ''.join(
529+
[part.capitalize() for part in dir_name.split("_")]
530+
)
531+
532+
# Load each service module
533+
svc = client.get_service(service_name, is_async=True)
534+
self.assertEqual(svc.__class__.__name__, f"{service_name}AsyncClient")
501535

502-
# Iterate through retrieval of all service clients by name.
503-
for service_name in service_names:
504-
client.get_service(service_name, version=ver)
505536

506537
def test_get_service_custom_endpoint(self):
507538
service_name = "GoogleAdsService"

0 commit comments

Comments
 (0)