Skip to content

Commit dcc26d4

Browse files
committed
Move runtime config to contextvars to reduce value relay
1 parent ede1a75 commit dcc26d4

File tree

17 files changed

+208
-116
lines changed

17 files changed

+208
-116
lines changed

src/ghga_connector/cli.py

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@
2929
import typer
3030
from ghga_service_commons.utils import crypt
3131

32-
from ghga_connector.config import CONFIG
32+
from ghga_connector import exceptions
33+
from ghga_connector.config import CONFIG, set_runtime_config
3334
from ghga_connector.core import (
3435
CLIMessageDisplay,
3536
WorkPackageAccessor,
3637
async_client,
37-
exceptions,
3838
)
39-
from ghga_connector.core.api_calls import WKVSCaller
4039
from ghga_connector.core.downloading.batch_processing import FileStager
4140
from ghga_connector.core.main import (
4241
decrypt_file,
@@ -50,19 +49,10 @@
5049
class DownloadParameters:
5150
"""Contains parameters returned by API calls to prepare information needed for download"""
5251

53-
dcs_api_url: str
5452
file_ids_with_extension: dict[str, str]
5553
work_package_accessor: WorkPackageAccessor
5654

5755

58-
@dataclass
59-
class UploadParameters:
60-
"""Contains parameters returned by API calls to prepare information needed for upload"""
61-
62-
ucs_api_url: str
63-
server_pubkey: str
64-
65-
6656
@dataclass
6757
class WorkPackageInformation:
6858
"""Wraps decrypted work package token and id to pass to other functions"""
@@ -103,15 +93,6 @@ def modify_for_debug(debug: bool):
10393
sys.excepthook = partial(exception_hook)
10494

10595

106-
async def retrieve_upload_parameters(client: httpx.AsyncClient) -> UploadParameters:
107-
"""Configure httpx client and retrieve necessary parameters from WKVS"""
108-
wkvs_caller = WKVSCaller(client=client, wkvs_url=CONFIG.wkvs_api_url)
109-
ucs_api_url = await wkvs_caller.get_ucs_api_url()
110-
server_pubkey = await wkvs_caller.get_server_pubkey()
111-
112-
return UploadParameters(server_pubkey=server_pubkey, ucs_api_url=ucs_api_url)
113-
114-
11596
async def retrieve_download_parameters(
11697
*,
11798
client: httpx.AsyncClient,
@@ -120,23 +101,16 @@ async def retrieve_download_parameters(
120101
work_package_information: WorkPackageInformation,
121102
) -> DownloadParameters:
122103
"""Run necessary API calls to configure file download"""
123-
wkvs_caller = WKVSCaller(client=client, wkvs_url=CONFIG.wkvs_api_url)
124-
dcs_api_url = await wkvs_caller.get_dcs_api_url()
125-
wps_api_url = await wkvs_caller.get_wps_api_url()
126-
127104
work_package_accessor = WorkPackageAccessor(
128105
access_token=work_package_information.decrypted_token,
129-
api_url=wps_api_url,
130106
client=client,
131-
dcs_api_url=dcs_api_url,
132107
package_id=work_package_information.package_id,
133108
my_private_key=my_private_key,
134109
my_public_key=my_public_key,
135110
)
136111
file_ids_with_extension = await work_package_accessor.get_package_files()
137112

138113
return DownloadParameters(
139-
dcs_api_url=dcs_api_url,
140114
file_ids_with_extension=file_ids_with_extension,
141115
work_package_accessor=work_package_accessor,
142116
)
@@ -199,14 +173,11 @@ async def async_upload(
199173
passphrase: str | None = None,
200174
):
201175
"""Upload a file asynchronously"""
202-
async with async_client() as client:
203-
parameters = await retrieve_upload_parameters(client)
176+
async with async_client() as client, set_runtime_config(client=client):
204177
await upload_file(
205-
api_url=parameters.ucs_api_url,
206178
client=client,
207179
file_id=file_id,
208180
file_path=file_path,
209-
server_public_key=parameters.server_pubkey,
210181
my_public_key_path=my_public_key_path,
211182
my_private_key_path=my_private_key_path,
212183
passphrase=passphrase,
@@ -303,7 +274,7 @@ async def async_download(
303274
my_private_key=my_private_key
304275
)
305276

306-
async with async_client() as client:
277+
async with async_client() as client, set_runtime_config(client=client):
307278
CLIMessageDisplay.display("Retrieving API configuration information...")
308279
parameters = await retrieve_download_parameters(
309280
client=client,
@@ -315,7 +286,6 @@ async def async_download(
315286
CLIMessageDisplay.display("Preparing files for download...")
316287
stager = FileStager(
317288
wanted_file_ids=list(parameters.file_ids_with_extension),
318-
dcs_api_url=parameters.dcs_api_url,
319289
output_dir=output_dir,
320290
work_package_accessor=parameters.work_package_accessor,
321291
client=client,
@@ -326,7 +296,6 @@ async def async_download(
326296
for file_id in staged_files:
327297
CLIMessageDisplay.display(f"Downloading file with id '{file_id}'...")
328298
await download_file(
329-
api_url=parameters.dcs_api_url,
330299
client=client,
331300
file_id=file_id,
332301
file_extension=parameters.file_ids_with_extension[file_id],

src/ghga_connector/config.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,61 @@
1616

1717
"""Global Config Parameters"""
1818

19+
from contextlib import asynccontextmanager
20+
from contextvars import ContextVar
21+
from typing import Any
22+
23+
import httpx
1924
from hexkit.config import config_from_yaml
25+
from hexkit.utils import set_context_var
2026
from pydantic import Field, NonNegativeInt, PositiveInt
2127
from pydantic_settings import BaseSettings
2228

29+
from ghga_connector import exceptions
2330
from ghga_connector.constants import DEFAULT_PART_SIZE, MAX_RETRIES, MAX_WAIT_TIME
2431

32+
__all__ = [
33+
"CONFIG",
34+
"Config",
35+
"get_dcs_api_url",
36+
"get_ghga_pubkey",
37+
"get_ucs_api_url",
38+
"get_wps_api_url",
39+
"set_runtime_config",
40+
]
41+
42+
ucs_api_url_var: ContextVar[str] = ContextVar("ucs_api_url", default="")
43+
dcs_api_url_var: ContextVar[str] = ContextVar("dcs_api_url", default="")
44+
wps_api_url_var: ContextVar[str] = ContextVar("wps_api_url", default="")
45+
ghga_pubkey_var: ContextVar[str] = ContextVar("ghga_pubkey", default="")
46+
47+
48+
def _get_context_var(context_var: ContextVar) -> Any:
49+
value = context_var.get()
50+
if not value:
51+
raise ValueError(f"{context_var.name} is not set")
52+
return value
53+
54+
55+
def get_ucs_api_url() -> str:
56+
"""Get the UCS API URL."""
57+
return _get_context_var(ucs_api_url_var)
58+
59+
60+
def get_dcs_api_url() -> str:
61+
"""Get the DCS API URL."""
62+
return _get_context_var(dcs_api_url_var)
63+
64+
65+
def get_wps_api_url() -> str:
66+
"""Get the WPS API URL."""
67+
return _get_context_var(wps_api_url_var)
68+
69+
70+
def get_ghga_pubkey() -> str:
71+
"""Get the GHGA crypt4gh public key."""
72+
return _get_context_var(ghga_pubkey_var)
73+
2574

2675
@config_from_yaml(prefix="ghga_connector")
2776
class Config(BaseSettings):
@@ -55,3 +104,58 @@ class Config(BaseSettings):
55104

56105

57106
CONFIG = Config()
107+
108+
109+
@asynccontextmanager
110+
async def set_runtime_config(client: httpx.AsyncClient):
111+
"""Set runtime config as context vars to be accessed within a context manager.
112+
113+
This sets the following values:
114+
- ghga_pubkey
115+
- wps_api_url
116+
- dcs_api_url
117+
- ucs_api_url
118+
"""
119+
ghga_pubkey = await _get_wkvs_value(client, value_name="crypt4gh_public_key")
120+
wps_api_url = (await _get_wkvs_value(client, value_name="wps_api_url")).rstrip("/")
121+
dcs_api_url = (await _get_wkvs_value(client, value_name="dcs_api_url")).rstrip("/")
122+
ucs_api_url = (await _get_wkvs_value(client, value_name="ucs_api_url")).rstrip("/")
123+
124+
async with (
125+
set_context_var(ghga_pubkey_var, ghga_pubkey),
126+
set_context_var(wps_api_url_var, wps_api_url),
127+
set_context_var(dcs_api_url_var, dcs_api_url),
128+
set_context_var(ucs_api_url_var, ucs_api_url),
129+
):
130+
yield
131+
132+
133+
async def _get_wkvs_value(client: httpx.AsyncClient, *, value_name: str) -> Any:
134+
"""Retrieve a value from the well-known-value-service.
135+
136+
Args:
137+
value_name (str): the name of the value to be retrieved
138+
139+
Raises:
140+
WellKnownValueNotFound: when a 404 response is received from the WKVS
141+
KeyError: when a successful response is received but doesn't contain the expected value
142+
"""
143+
url = f"{CONFIG.wkvs_api_url}/values/{value_name}"
144+
145+
try:
146+
response = await client.get(url) # verify is True by default
147+
except httpx.RequestError as request_error:
148+
exceptions.raise_if_connection_failed(request_error=request_error, url=url)
149+
raise exceptions.RequestFailedError(url=url) from request_error
150+
151+
if response.status_code == 404:
152+
raise exceptions.WellKnownValueNotFound(value_name=value_name)
153+
154+
try:
155+
value = response.json()[value_name]
156+
except KeyError as err:
157+
raise KeyError(
158+
"Response from well-known-value-service did not include expected field"
159+
+ f" '{value_name}'"
160+
) from err
161+
return value

src/ghga_connector/core/api_calls/well_knowns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import httpx
2222

23-
from ghga_connector.core import exceptions
23+
from ghga_connector import exceptions
2424

2525

2626
@dataclass

src/ghga_connector/core/crypt/encryption.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import crypt4gh.lib
2727
from nacl.bindings import crypto_aead_chacha20poly1305_ietf_encrypt
2828

29+
from ghga_connector.config import get_ghga_pubkey
2930
from ghga_connector.core import get_segments, read_file_parts
3031

3132
from .abstract_bases import Encryptor
@@ -35,11 +36,10 @@
3536
class Crypt4GHEncryptor(Encryptor):
3637
"""Handles on the fly encryption and checksum calculation"""
3738

38-
def __init__( # noqa: PLR0913
39+
def __init__(
3940
self,
4041
part_size: int,
4142
private_key_path: Path,
42-
server_public_key: str,
4343
passphrase: str | None,
4444
checksums: Checksums = Checksums(),
4545
file_secret: bytes | None = None,
@@ -48,7 +48,7 @@ def __init__( # noqa: PLR0913
4848
self._checksums = checksums
4949
self._part_size = part_size
5050
self._private_key_path = private_key_path
51-
self._server_public_key = base64.b64decode(server_public_key)
51+
self._server_public_key = base64.b64decode(get_ghga_pubkey())
5252
self._passphrase = passphrase
5353
if file_secret is None:
5454
file_secret = os.urandom(32)

src/ghga_connector/core/downloading/api_calls.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import httpx
1919
from tenacity import RetryError
2020

21+
from ghga_connector import exceptions
22+
from ghga_connector.config import get_dcs_api_url
2123
from ghga_connector.constants import CACHE_MIN_FRESH, TIMEOUT_LONG
22-
from ghga_connector.core import RetryHandler, WorkPackageAccessor, exceptions
24+
from ghga_connector.core import RetryHandler, WorkPackageAccessor
2325

2426
from .structs import (
2527
RetryResponse,
@@ -63,7 +65,8 @@ async def get_envelope_authorization(
6365
a Crypt4GH envelope for file identified by `file_id`
6466
"""
6567
# build url
66-
url = f"{work_package_accessor.dcs_api_url}/objects/{file_id}/envelopes"
68+
dcs_api_url = get_dcs_api_url()
69+
url = f"{dcs_api_url}/objects/{file_id}/envelopes"
6770
headers = await _get_authorization(
6871
file_id=file_id, work_package_accessor=work_package_accessor
6972
)
@@ -78,7 +81,8 @@ async def get_file_authorization(
7881
object storage URL for file download
7982
"""
8083
# build URL
81-
url = f"{work_package_accessor.dcs_api_url}/objects/{file_id}"
84+
dcs_api_url = get_dcs_api_url()
85+
url = f"{dcs_api_url}/objects/{file_id}"
8286
headers = await _get_authorization(
8387
file_id=file_id,
8488
work_package_accessor=work_package_accessor,

src/ghga_connector/core/downloading/batch_processing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
import httpx
2222

23-
from ghga_connector.config import Config
23+
from ghga_connector import exceptions
24+
from ghga_connector.config import Config, get_dcs_api_url
2425
from ghga_connector.core import (
2526
CLIMessageDisplay,
2627
WorkPackageAccessor,
27-
exceptions,
2828
)
2929
from ghga_connector.core.api_calls import is_service_healthy
3030

@@ -70,11 +70,10 @@ def handle_response(self, *, response: str):
7070
class FileStager:
7171
"""Utility class to deal with file staging in batch processing."""
7272

73-
def __init__( # noqa: PLR0913
73+
def __init__(
7474
self,
7575
*,
7676
wanted_file_ids: list[str],
77-
dcs_api_url: str,
7877
output_dir: Path,
7978
work_package_accessor: WorkPackageAccessor,
8079
client: httpx.AsyncClient,
@@ -83,9 +82,9 @@ def __init__( # noqa: PLR0913
8382
"""Initialize the FileStager."""
8483
self.io_handler = CliIoHandler()
8584
existing_file_ids = set(self.io_handler.check_output(location=output_dir))
86-
if not is_service_healthy(dcs_api_url):
87-
raise exceptions.ApiNotReachableError(api_url=dcs_api_url)
88-
self.api_url = dcs_api_url
85+
self.api_url = get_dcs_api_url()
86+
if not is_service_healthy(self.api_url):
87+
raise exceptions.ApiNotReachableError(api_url=self.api_url)
8988
self.work_package_accessor = work_package_accessor
9089
self.max_wait_time = config.max_wait_time
9190
self.client = client

src/ghga_connector/core/downloading/downloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
import httpx
2727
from tenacity import RetryError
2828

29+
from ghga_connector import exceptions
2930
from ghga_connector.core import (
3031
CLIMessageDisplay,
3132
PartRange,
3233
ResponseExceptionTranslator,
3334
RetryHandler,
3435
WorkPackageAccessor,
3536
calc_part_ranges,
36-
exceptions,
3737
)
3838
from ghga_connector.core.tasks import TaskHandler
3939

0 commit comments

Comments
 (0)