Skip to content

Commit 6177666

Browse files
committed
MultiBackendJobManager: refresh bearer before _JobStartTask #817
1 parent af9cd1a commit 6177666

File tree

5 files changed

+313
-25
lines changed

5 files changed

+313
-25
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
### Fixed
1919

20+
- Proactively refresh access/bearer token in `MultiBackendJobManager` before launching a job start thread ([#817](https://github.com/Open-EO/openeo-python-client/issues/817))
21+
2022

2123
## [0.45.0] - 2025-09-17
2224

openeo/extra/job_management/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def __init__(
244244
)
245245
self._thread = None
246246
self._worker_pool = None
247+
# Generic cache
248+
self._cache = {}
247249

248250
def add_backend(
249251
self,
@@ -650,6 +652,8 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
650652
# start job if not yet done by callback
651653
try:
652654
job_con = job.connection
655+
# Proactively refresh bearer token (because task in thread will not be able to do that)
656+
self._refresh_bearer_token(connection=job_con)
653657
task = _JobStartTask(
654658
root_url=job_con.root_url,
655659
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
@@ -670,6 +674,19 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
670674
df.loc[i, "status"] = "skipped"
671675
stats["start_job skipped"] += 1
672676

677+
def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60):
678+
"""
679+
Helper to proactively refresh access token of connection
680+
(but not too often, based on `max_age`).
681+
"""
682+
# TODO: be smarter about timing, e.g. by inspecting expiry of current token?
683+
now = time.time()
684+
key = f"connection-{id(connection)}-refresh-time"
685+
if self._cache.get(key, 0) + max_age < now:
686+
refreshed = connection.try_access_token_refresh()
687+
if refreshed:
688+
self._cache[key] = now
689+
673690
def _process_threadworker_updates(
674691
self,
675692
worker_pool: _JobManagerWorkerThreadPool,

openeo/rest/auth/testing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def token_callback_resource_owner_password_credentials(self, params: dict, conte
143143
assert params["scope"] == self.expected_fields["scope"]
144144
return self._build_token_response()
145145

146+
def token_callback_block_400(self, params: dict, context):
147+
"""Failing callback with 400 Bad Request"""
148+
context.status_code = 400
149+
return "block_400"
150+
146151
def device_code_callback(self, request: requests_mock.request._RequestObjectProxy, context):
147152
params = self._get_query_params(query=request.text)
148153
assert params["client_id"] == self.expected_client_id

openeo/rest/connection.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,28 @@ def authenticate_bearer_token(self, bearer_token: str) -> Connection:
665665
self._oidc_auth_renewer = None
666666
return self
667667

668+
def try_access_token_refresh(self, *, reason: Optional[str] = None) -> bool:
669+
"""
670+
Try to get a fresh access token if possible.
671+
Returns whether a new access token was obtained.
672+
"""
673+
reason = f" Reason: {reason}" if reason else ""
674+
if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer:
675+
try:
676+
self._authenticate_oidc(
677+
authenticator=self._oidc_auth_renewer,
678+
provider_id=self._oidc_auth_renewer.provider_info.id,
679+
store_refresh_token=False,
680+
oidc_auth_renewer=self._oidc_auth_renewer,
681+
)
682+
_log.info(f"Obtained new access token (grant {self._oidc_auth_renewer.grant_type!r}).{reason}")
683+
return True
684+
except OpenEoClientException as auth_exc:
685+
_log.error(
686+
f"Failed to obtain new access token (grant {self._oidc_auth_renewer.grant_type!r}): {auth_exc!r}.{reason}"
687+
)
688+
return False
689+
668690
def request(
669691
self,
670692
method: str,
@@ -690,24 +712,11 @@ def _request():
690712
api_exc.http_status_code in {HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN}
691713
and api_exc.code == "TokenInvalid"
692714
):
693-
# Auth token expired: can we refresh?
694-
if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer:
695-
msg = f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
696-
try:
697-
self._authenticate_oidc(
698-
authenticator=self._oidc_auth_renewer,
699-
provider_id=self._oidc_auth_renewer.provider_info.id,
700-
store_refresh_token=False,
701-
oidc_auth_renewer=self._oidc_auth_renewer,
702-
)
703-
_log.info(f"{msg} Obtained new access token (grant {self._oidc_auth_renewer.grant_type!r}).")
704-
except OpenEoClientException as auth_exc:
705-
_log.error(
706-
f"{msg} Failed to obtain new access token (grant {self._oidc_auth_renewer.grant_type!r}): {auth_exc!r}."
707-
)
708-
else:
709-
# Retry request.
710-
return _request()
715+
# Retry if we can refresh the access token
716+
if self.try_access_token_refresh(
717+
reason=f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
718+
):
719+
return _request()
711720
raise
712721

713722
def describe_account(self) -> dict:

0 commit comments

Comments
 (0)