Skip to content

Commit 8cafa08

Browse files
pvbouwelPeter Van Bouwel
andauthored
feature: handle expired token gracefully for artifact helper (#805)
* feature: handle expired token gracefully If a token is expired we can refresh it by doing another OpenEO interaction. By allowing 3 attempts for the STS service and by doing a very wide except clause we also make ourselves more robust against intermittent errors. * pr-feedback: address comments * pr-feedback: make sure we have a valid token in any case. * tests: mock /me endpoint --------- Co-authored-by: Peter Van Bouwel <[email protected]>
1 parent 3d71ce0 commit 8cafa08

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

openeo/extra/artifacts/_s3sts/sts.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import logging
4+
import time
5+
from random import randint
36
from typing import TYPE_CHECKING
47

58
if TYPE_CHECKING:
@@ -12,12 +15,16 @@
1215
from openeo.rest.connection import Connection
1316
from openeo.util import Rfc3339
1417

18+
_log = logging.getLogger(__name__)
19+
1520

1621
class OpenEOSTSClient:
22+
_MAX_STS_ATTEMPTS = 3
23+
1724
def __init__(self, config: S3STSConfig):
1825
self.config = config
1926

20-
def assume_from_openeo_connection(self, connection: Connection) -> AWSSTSCredentials:
27+
def assume_from_openeo_connection(self, connection: Connection, attempt: int = 0) -> AWSSTSCredentials:
2128
"""
2229
Takes an OpenEO connection object and returns temporary credentials to interact with S3
2330
"""
@@ -27,14 +34,31 @@ def assume_from_openeo_connection(self, connection: Connection) -> AWSSTSCredent
2734
raise ProviderSpecificException("Only connections that have BearerAuth can be used.")
2835
auth_token = auth.bearer.split("/")
2936

30-
return AWSSTSCredentials.from_assume_role_response(
31-
self._get_sts_client().assume_role_with_web_identity(
32-
RoleArn=self._get_aws_access_role(),
33-
RoleSessionName=f"artifact-helper-{Rfc3339().now_utc()}",
34-
WebIdentityToken=auth_token[2],
35-
DurationSeconds=43200,
37+
try:
38+
# Do an API call with OpenEO to trigger a refresh of our token if it were stale.
39+
connection.describe_account()
40+
return AWSSTSCredentials.from_assume_role_response(
41+
self._get_sts_client().assume_role_with_web_identity(
42+
RoleArn=self._get_aws_access_role(),
43+
RoleSessionName=f"artifact-helper-{Rfc3339().now_utc()}",
44+
WebIdentityToken=auth_token[2],
45+
DurationSeconds=43200,
46+
)
3647
)
37-
)
48+
except Exception as e:
49+
_log.warning("Failed to get credentials for STS access")
50+
51+
if attempt < self._MAX_STS_ATTEMPTS:
52+
# backoff with jitter
53+
max_sleep_ms = 500 * (2**attempt)
54+
sleep_ms = randint(0, max_sleep_ms)
55+
_log.info(f"Retrying STS access in {sleep_ms} ms")
56+
time.sleep(sleep_ms / 1000.0)
57+
attempt += 1
58+
_log.info(f"Retrying to get credentials for STS access {attempt}/{self._MAX_STS_ATTEMPTS}")
59+
return self.assume_from_openeo_connection(connection, attempt)
60+
else:
61+
raise RuntimeError("Could not get credentials from STS") from e
3862

3963
def _get_sts_client(self) -> STSClient:
4064
return self.config.build_client("sts")

tests/extra/artifacts/_s3sts/test_s3sts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def conn_with_s3sts_capabilities(
9292
requests_mock, extra_api_capabilities, advertised_s3sts_config
9393
) -> Iterator[Connection]:
9494
requests_mock.get(API_URL, json={"api_version": "1.0.0", **extra_api_capabilities})
95+
requests_mock.get(f"{API_URL}me", json={})
9596
conn = Connection(API_URL)
9697
conn.auth = BearerAuth("oidc/fake/token")
9798
yield conn

0 commit comments

Comments
 (0)