Skip to content

Commit b370a7c

Browse files
committed
Refactor Download Path
1 parent bf91a90 commit b370a7c

File tree

21 files changed

+840
-710
lines changed

21 files changed

+840
-710
lines changed

src/ghga_connector/cli.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,27 @@
2525
from types import TracebackType
2626

2727
import crypt4gh.keys
28-
import httpx
2928
import typer
3029
from ghga_service_commons.utils import crypt
3130

3231
from ghga_connector import exceptions
3332
from ghga_connector.config import CONFIG, set_runtime_config
33+
from ghga_connector.constants import C4GH
3434
from ghga_connector.core import (
3535
CLIMessageDisplay,
36-
WorkPackageAccessor,
36+
WorkPackageClient,
3737
async_client,
3838
)
39-
from ghga_connector.core.downloading.batch_processing import FileStager
39+
from ghga_connector.core.downloading.api_calls import DownloadClient
40+
from ghga_connector.core.downloading.batch_processing import FileInfo, FileStager
4041
from ghga_connector.core.main import (
4142
decrypt_file,
4243
download_file,
43-
get_wps_token,
44+
get_work_package_token,
4445
upload_file,
4546
)
4647

4748

48-
@dataclass
49-
class DownloadParameters:
50-
"""Contains parameters returned by API calls to prepare information needed for download"""
51-
52-
file_ids_with_extension: dict[str, str]
53-
work_package_accessor: WorkPackageAccessor
54-
55-
5649
@dataclass
5750
class WorkPackageInformation:
5851
"""Wraps decrypted work package token and id to pass to other functions"""
@@ -93,33 +86,10 @@ def modify_for_debug(debug: bool):
9386
sys.excepthook = partial(exception_hook)
9487

9588

96-
async def retrieve_download_parameters(
97-
*,
98-
client: httpx.AsyncClient,
99-
my_private_key: bytes,
100-
my_public_key: bytes,
101-
work_package_information: WorkPackageInformation,
102-
) -> DownloadParameters:
103-
"""Run necessary API calls to configure file download"""
104-
work_package_accessor = WorkPackageAccessor(
105-
access_token=work_package_information.decrypted_token,
106-
client=client,
107-
package_id=work_package_information.package_id,
108-
my_private_key=my_private_key,
109-
my_public_key=my_public_key,
110-
)
111-
file_ids_with_extension = await work_package_accessor.get_package_files()
112-
113-
return DownloadParameters(
114-
file_ids_with_extension=file_ids_with_extension,
115-
work_package_accessor=work_package_accessor,
116-
)
117-
118-
11989
def get_work_package_information(my_private_key: bytes):
12090
"""Fetch a work package id and work package token and decrypt the token"""
12191
# get work package access token and id from user input
122-
work_package_id, work_package_token = get_wps_token(max_tries=3)
92+
work_package_id, work_package_token = get_work_package_token(max_tries=3)
12393
decrypted_token = crypt.decrypt(data=work_package_token, key=my_private_key)
12494
return WorkPackageInformation(
12595
decrypted_token=decrypted_token, package_id=work_package_id
@@ -276,38 +246,70 @@ async def async_download(
276246

277247
async with async_client() as client, set_runtime_config(client=client):
278248
CLIMessageDisplay.display("Retrieving API configuration information...")
279-
parameters = await retrieve_download_parameters(
249+
work_package_client = WorkPackageClient(
250+
access_token=work_package_information.decrypted_token,
280251
client=client,
252+
package_id=work_package_information.package_id,
281253
my_private_key=my_private_key,
282254
my_public_key=my_public_key,
283-
work_package_information=work_package_information,
255+
)
256+
257+
file_ids_with_extension = await work_package_client.get_package_files()
258+
259+
download_client = DownloadClient(
260+
client=client, work_package_client=work_package_client
284261
)
285262

286263
CLIMessageDisplay.display("Preparing files for download...")
287264
stager = FileStager(
288-
wanted_file_ids=list(parameters.file_ids_with_extension),
265+
wanted_files=file_ids_with_extension,
289266
output_dir=output_dir,
290-
work_package_accessor=parameters.work_package_accessor,
291-
client=client,
267+
work_package_client=work_package_client,
268+
download_client=download_client,
292269
config=CONFIG,
293270
)
294271
while not stager.finished:
295272
staged_files = await stager.get_staged_files()
296-
for file_id in staged_files:
297-
CLIMessageDisplay.display(f"Downloading file with id '{file_id}'...")
273+
for file_info in staged_files:
274+
check_for_existing_file(file_info=file_info, overwrite=overwrite)
298275
await download_file(
299-
client=client,
300-
file_id=file_id,
301-
file_extension=parameters.file_ids_with_extension[file_id],
302-
output_dir=output_dir,
276+
download_client=download_client,
277+
file_info=file_info,
303278
max_concurrent_downloads=CONFIG.max_concurrent_downloads,
304279
part_size=CONFIG.part_size,
305-
work_package_accessor=parameters.work_package_accessor,
306-
overwrite=overwrite,
307280
)
281+
finalize_download(file_info)
308282
staged_files.clear()
309283

310284

285+
def check_for_existing_file(*, file_info: FileInfo, overwrite: bool) -> bool:
    """Check for an existing output file and clear leftovers of earlier attempts.

    If `file_info.path_once_complete` already exists, the user is informed
    whether the file will be overwritten (`overwrite=True`) or skipped
    (`overwrite=False`). Unless the file is skipped, a partial file found at
    `file_info.path_during_download` is deleted.

    Args:
        file_info: Describes the file to download; only `path_once_complete`
            and `path_during_download` are read here.
        overwrite: Whether an already existing output file may be replaced.

    Returns:
        True if the download should proceed, False if the output file already
        exists and must not be overwritten. Callers have to skip the download
        when False is returned; otherwise the file announced as skipped would
        still be downloaded and replaced.
    """
    output_file = file_info.path_once_complete
    if output_file.exists():
        if not overwrite:
            CLIMessageDisplay.failure(
                f"A file with name '{output_file}' already exists. Skipping."
            )
            # Signal the skip explicitly; the partial file is left untouched.
            return False
        CLIMessageDisplay.display(
            f"A file with name '{output_file}' already exists and will be overwritten."
        )

    # Remove a partial file left behind by a previous, interrupted attempt.
    output_file_ongoing = file_info.path_during_download
    if output_file_ongoing.exists():
        output_file_ongoing.unlink()
    return True
303+
304+
305+
def finalize_download(file_info: FileInfo):
    """Move a finished download to its final path and announce completion.

    Args:
        file_info: Describes the downloaded file; `path_during_download` is
            moved to `path_once_complete` and `file_id` is used in the message.
    """
    # Use `replace` instead of `rename`: with pathlib, `rename` raises
    # FileExistsError on Windows if the target already exists, which would
    # break overwriting an existing file. `replace` overwrites on all platforms.
    # NOTE(review): assumes both paths are pathlib.Path objects - confirm
    # against the FileInfo definition.
    file_info.path_during_download.replace(file_info.path_once_complete)
    CLIMessageDisplay.success(
        f"File with id '{file_info.file_id}' has been successfully downloaded."
    )
311+
312+
311313
@cli.command(no_args_is_help=True)
312314
def decrypt( # noqa: PLR0912, C901
313315
*,
@@ -356,7 +358,7 @@ def decrypt( # noqa: PLR0912, C901
356358
skipped_files = []
357359
file_count = 0
358360
for input_file in input_dir.iterdir():
359-
if not input_file.is_file() or input_file.suffix != ".c4gh":
361+
if not input_file.is_file() or input_file.suffix != C4GH:
360362
skipped_files.append(str(input_file))
361363
continue
362364

@@ -396,7 +398,7 @@ def decrypt( # noqa: PLR0912, C901
396398

397399
if skipped_files:
398400
CLIMessageDisplay.display(
399-
"The following files were skipped as they are not .c4gh files:"
401+
f"The following files were skipped as they are not {C4GH} files:"
400402
)
401403
for file in skipped_files:
402404
CLIMessageDisplay.display(f"- {file}")

src/ghga_connector/config.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,18 @@
3232
__all__ = [
3333
"CONFIG",
3434
"Config",
35-
"get_dcs_api_url",
35+
"get_download_api_url",
3636
"get_ghga_pubkey",
37-
"get_ucs_api_url",
38-
"get_wps_api_url",
37+
"get_upload_api_url",
38+
"get_work_package_api_url",
3939
"set_runtime_config",
4040
]
4141

42-
ucs_api_url_var: ContextVar[str] = ContextVar("ucs_api_url", default="")
43-
dcs_api_url_var: ContextVar[str] = ContextVar("dcs_api_url", default="")
44-
wps_api_url_var: ContextVar[str] = ContextVar("wps_api_url", default="")
42+
upload_api_url_var: ContextVar[str] = ContextVar("upload_api_url_var", default="")
43+
download_api_url_var: ContextVar[str] = ContextVar("download_api_url_var", default="")
44+
work_package_api_url_var: ContextVar[str] = ContextVar(
45+
"work_package_api_url_var", default=""
46+
)
4547
ghga_pubkey_var: ContextVar[str] = ContextVar("ghga_pubkey", default="")
4648

4749

@@ -52,19 +54,19 @@ def _get_context_var(context_var: ContextVar) -> Any:
5254
return value
5355

5456

55-
def get_ucs_api_url() -> str:
56-
"""Get the UCS API URL."""
57-
return _get_context_var(ucs_api_url_var)
57+
def get_upload_api_url() -> str:
    """Return the Upload API URL held in `upload_api_url_var`.

    The value is looked up through `_get_context_var`.
    """
    upload_api_url = _get_context_var(upload_api_url_var)
    return upload_api_url
5860

5961

60-
def get_dcs_api_url() -> str:
61-
"""Get the DCS API URL."""
62-
return _get_context_var(dcs_api_url_var)
62+
def get_download_api_url() -> str:
    """Return the Download API URL held in `download_api_url_var`.

    The value is looked up through `_get_context_var`.
    """
    download_api_url = _get_context_var(download_api_url_var)
    return download_api_url
6365

6466

65-
def get_wps_api_url() -> str:
66-
"""Get the WPS API URL."""
67-
return _get_context_var(wps_api_url_var)
67+
def get_work_package_api_url() -> str:
    """Return the Work Package API URL held in `work_package_api_url_var`.

    The value is looked up through `_get_context_var`.
    """
    work_package_api_url = _get_context_var(work_package_api_url_var)
    return work_package_api_url
6870

6971

7072
def get_ghga_pubkey() -> str:
@@ -112,9 +114,9 @@ async def set_runtime_config(client: httpx.AsyncClient):
112114
113115
This sets the following values:
114116
- ghga_pubkey
115-
- wps_api_url
116-
- dcs_api_url
117-
- ucs_api_url
117+
- work_package_api_url
118+
- download_api_url
119+
- upload_api_url
118120
119121
Raises:
120122
WellKnownValueNotFound: If one of the well-known values is not found in the
@@ -135,9 +137,9 @@ async def set_runtime_config(client: httpx.AsyncClient):
135137

136138
async with (
137139
set_context_var(ghga_pubkey_var, values["crypt4gh_public_key"]),
138-
set_context_var(wps_api_url_var, values["wps_api_url"].rstrip("/")),
139-
set_context_var(dcs_api_url_var, values["dcs_api_url"].rstrip("/")),
140-
set_context_var(ucs_api_url_var, values["ucs_api_url"].rstrip("/")),
140+
set_context_var(work_package_api_url_var, values["wps_api_url"].rstrip("/")),
141+
set_context_var(download_api_url_var, values["dcs_api_url"].rstrip("/")),
142+
set_context_var(upload_api_url_var, values["ucs_api_url"].rstrip("/")),
141143
):
142144
yield
143145

src/ghga_connector/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
MAX_RETRIES = 5
2424
MAX_WAIT_TIME = 60 * 60
2525
CACHE_MIN_FRESH = 3
26+
C4GH = ".c4gh"

src/ghga_connector/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@
2828
from .http_translation import ResponseExceptionTranslator # noqa: F401
2929
from .message_display import CLIMessageDisplay, MessageColors # noqa: F401
3030
from .structs import PartRange # noqa: F401
31-
from .work_package import WorkPackageAccessor # noqa: F401
31+
from .work_package import WorkPackageClient # noqa: F401

src/ghga_connector/core/api_calls/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,16 @@ def check_url(api_url: str, *, timeout_in_seconds: int = 5) -> bool:
4242

4343
content = response.json()
4444
return "status" in content and content["status"].lower() == "ok"
45+
46+
47+
def modify_headers_for_cache_refresh(headers: httpx.Headers) -> None:
    """Add `max-age=0` to the Cache-Control header of the given headers.

    An existing, non-empty Cache-Control value is kept and `max-age=0` is
    appended to it after a comma; otherwise the header is set to `max-age=0`.

    The headers object is modified in place, so nothing is returned.
    """
    existing_value = headers.get("Cache-Control")
    directives = [existing_value, "max-age=0"] if existing_value else ["max-age=0"]
    headers["Cache-Control"] = ",".join(directives)

src/ghga_connector/core/downloading/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@
1515
#
1616
"""This subpackage contains functionality needed to download files from GHGA"""
1717

18-
from ..progress_bar import DownloadProgressBar # noqa: F401
19-
from .structs import RetryResponse, UrlAndHeaders, URLResponse # noqa: F401
18+
from ..progress_bar import DownloadProgressBar
19+
from .structs import RetryResponse, UrlAndHeaders
20+
21+
__all__ = ["DownloadProgressBar", "RetryResponse", "UrlAndHeaders"]

0 commit comments

Comments
 (0)