Skip to content

Commit bdcc4a0

Browse files
committed
Flatten CLI and further refactor top of download path
1 parent 6a20f99 commit bdcc4a0

File tree

9 files changed

+249
-242
lines changed

9 files changed

+249
-242
lines changed

src/ghga_connector/cli.py

Lines changed: 2 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -19,39 +19,17 @@
1919
import logging
2020
import os
2121
import sys
22-
from dataclasses import dataclass
2322
from functools import partial
2423
from pathlib import Path
2524
from types import TracebackType
2625

27-
import crypt4gh.keys
2826
import typer
29-
from ghga_service_commons.utils import crypt
3027

3128
from ghga_connector import exceptions
3229
from ghga_connector.config import CONFIG, set_runtime_config
3330
from ghga_connector.constants import C4GH
34-
from ghga_connector.core import (
35-
CLIMessageDisplay,
36-
WorkPackageClient,
37-
async_client,
38-
)
39-
from ghga_connector.core.downloading.api_calls import DownloadClient
40-
from ghga_connector.core.downloading.batch_processing import FileInfo, FileStager
41-
from ghga_connector.core.main import (
42-
decrypt_file,
43-
download_file,
44-
get_work_package_token,
45-
upload_file,
46-
)
47-
48-
49-
@dataclass
50-
class WorkPackageInformation:
51-
"""Wraps decrypted work package token and id to pass to other functions"""
52-
53-
decrypted_token: str
54-
package_id: str
31+
from ghga_connector.core import CLIMessageDisplay, async_client
32+
from ghga_connector.core.main import async_download, decrypt_file, upload_file
5533

5634

5735
def strtobool(value: str) -> bool:
@@ -86,16 +64,6 @@ def modify_for_debug(debug: bool):
8664
sys.excepthook = partial(exception_hook)
8765

8866

89-
def get_work_package_information(my_private_key: bytes):
90-
"""Fetch a work package id and work package token and decrypt the token"""
91-
# get work package access token and id from user input
92-
work_package_id, work_package_token = get_work_package_token(max_tries=3)
93-
decrypted_token = crypt.decrypt(data=work_package_token, key=my_private_key)
94-
return WorkPackageInformation(
95-
decrypted_token=decrypted_token, package_id=work_package_id
96-
)
97-
98-
9967
cli = typer.Typer(no_args_is_help=True)
10068

10169

@@ -203,113 +171,6 @@ def download( # noqa: PLR0913
203171
)
204172

205173

206-
def get_public_key(my_public_key_path: Path) -> bytes:
207-
"""Get the user's private key from the path supplied"""
208-
if not my_public_key_path.is_file():
209-
raise exceptions.PubKeyFileDoesNotExistError(public_key_path=my_public_key_path)
210-
211-
return crypt4gh.keys.get_public_key(filepath=my_public_key_path)
212-
213-
214-
def get_private_key(my_private_key_path: Path, passphrase: str | None = None) -> bytes:
215-
"""Get the user's private key, using the passphrase if supplied/needed."""
216-
if passphrase:
217-
my_private_key = crypt4gh.keys.get_private_key(
218-
filepath=my_private_key_path, callback=lambda: passphrase
219-
)
220-
else:
221-
my_private_key = crypt4gh.keys.get_private_key(
222-
filepath=my_private_key_path, callback=None
223-
)
224-
return my_private_key
225-
226-
227-
async def async_download(
228-
*,
229-
output_dir: Path,
230-
my_public_key_path: Path,
231-
my_private_key_path: Path,
232-
passphrase: str | None = None,
233-
overwrite: bool = False,
234-
):
235-
"""Download files asynchronously"""
236-
if not output_dir.is_dir():
237-
raise exceptions.DirectoryDoesNotExistError(directory=output_dir)
238-
239-
my_public_key = get_public_key(my_public_key_path)
240-
my_private_key = get_private_key(my_private_key_path, passphrase)
241-
242-
CLIMessageDisplay.display("\nFetching work package token...")
243-
work_package_information = get_work_package_information(
244-
my_private_key=my_private_key
245-
)
246-
247-
async with async_client() as client, set_runtime_config(client=client):
248-
CLIMessageDisplay.display("Retrieving API configuration information...")
249-
work_package_client = WorkPackageClient(
250-
access_token=work_package_information.decrypted_token,
251-
client=client,
252-
package_id=work_package_information.package_id,
253-
my_private_key=my_private_key,
254-
my_public_key=my_public_key,
255-
)
256-
257-
file_ids_with_extension = await work_package_client.get_package_files()
258-
259-
download_client = DownloadClient(
260-
client=client, work_package_client=work_package_client
261-
)
262-
263-
CLIMessageDisplay.display("Preparing files for download...")
264-
stager = FileStager(
265-
wanted_files=file_ids_with_extension,
266-
output_dir=output_dir,
267-
work_package_client=work_package_client,
268-
download_client=download_client,
269-
config=CONFIG,
270-
)
271-
while not stager.finished:
272-
staged_files = await stager.get_staged_files()
273-
for file_info in staged_files:
274-
check_for_existing_file(file_info=file_info, overwrite=overwrite)
275-
await download_file(
276-
download_client=download_client,
277-
file_info=file_info,
278-
max_concurrent_downloads=CONFIG.max_concurrent_downloads,
279-
part_size=CONFIG.part_size,
280-
)
281-
finalize_download(file_info)
282-
staged_files.clear()
283-
284-
285-
def check_for_existing_file(*, file_info: FileInfo, overwrite: bool):
286-
"""Check if a file with the given name already exists and conditionally overwrite it."""
287-
# check output file
288-
output_file = file_info.path_once_complete
289-
if output_file.exists():
290-
if overwrite:
291-
CLIMessageDisplay.display(
292-
f"A file with name '{output_file}' already exists and will be overwritten."
293-
)
294-
else:
295-
CLIMessageDisplay.failure(
296-
f"A file with name '{output_file}' already exists. Skipping."
297-
)
298-
return
299-
300-
output_file_ongoing = file_info.path_during_download
301-
if output_file_ongoing.exists():
302-
output_file_ongoing.unlink()
303-
304-
305-
def finalize_download(file_info: FileInfo):
306-
"""Rename a file after downloading and announce completion"""
307-
file_info.path_during_download.rename(file_info.path_once_complete)
308-
CLIMessageDisplay.success(
309-
f"File with id '{file_info.file_id}' has been successfully downloaded."
310-
)
311-
312-
313174
@cli.command(no_args_is_help=True)
314175
def decrypt( # noqa: PLR0912, C901
315176
*,

src/ghga_connector/core/downloading/api_calls.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
from ghga_connector.config import get_download_api_url
2626
from ghga_connector.constants import TIMEOUT_LONG
2727
from ghga_connector.core import RetryHandler
28-
from ghga_connector.core.api_calls.utils import modify_headers_for_cache_refresh
28+
from ghga_connector.core.api_calls.utils import (
29+
is_service_healthy,
30+
modify_headers_for_cache_refresh,
31+
)
2932
from ghga_connector.core.work_package import WorkPackageClient
3033

3134
from .structs import RetryResponse
@@ -100,6 +103,14 @@ async def _retrieve_drs_object_using_params(
100103

101104
return _handle_drs_object_response(url=url, response=response)
102105

106+
def check_download_api_is_reachable(self):
107+
"""Verify that the download API is reachable.
108+
109+
Raises an `ApiNotReachableError` if it is not reachable.
110+
"""
111+
if not is_service_healthy(self._download_api_url):
112+
raise exceptions.ApiNotReachableError(api_url=self._download_api_url)
113+
103114
async def get_envelope_authorization_headers(
104115
self, *, file_id: str
105116
) -> httpx.Headers:

src/ghga_connector/core/downloading/batch_processing.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414
# limitations under the License.
1515
"""Module for batch processing related code"""
1616

17+
from collections.abc import AsyncGenerator
1718
from dataclasses import dataclass, field
1819
from pathlib import Path
1920
from time import perf_counter, sleep
2021

2122
from ghga_connector import exceptions
2223
from ghga_connector.config import Config, get_download_api_url
2324
from ghga_connector.constants import C4GH
24-
from ghga_connector.core import CLIMessageDisplay, WorkPackageClient
25+
from ghga_connector.core import CLIMessageDisplay, WorkPackageClient, utils
2526
from ghga_connector.core.api_calls import is_service_healthy
2627

2728
from .api_calls import DownloadClient
28-
from .structs import RetryResponse
29+
from .structs import FileInfo, RetryResponse
2930

3031

3132
@dataclass
@@ -60,36 +61,6 @@ def handle_response(self, *, response: str):
6061
raise exceptions.AbortBatchProcessError()
6162

6263

63-
@dataclass
64-
class FileInfo:
65-
"""Information about a file to be downloaded"""
66-
67-
file_id: str
68-
file_extension: str
69-
file_size: int
70-
output_dir: Path
71-
72-
@property
73-
def file_name(self) -> str:
74-
"""Construct file name with suffix, if given"""
75-
file_name = f"{self.file_id}"
76-
if self.file_extension:
77-
file_name = f"{self.file_id}{self.file_extension}"
78-
return file_name
79-
80-
@property
81-
def path_during_download(self) -> Path:
82-
"""The file path while the file download is still in progress"""
83-
# with_suffix() might overwrite existing suffixes, do this instead:
84-
output_file = self.path_once_complete
85-
return output_file.parent / (output_file.name + ".part")
86-
87-
@property
88-
def path_once_complete(self) -> Path:
89-
"""The file path once the download is complete"""
90-
return self.output_dir / f"{self.file_name}{C4GH}"
91-
92-
9364
class FileStager:
9465
"""Utility class to deal with file staging in batch processing."""
9566

@@ -130,7 +101,7 @@ def __init__(
130101
async def get_staged_files(self) -> list[FileInfo]:
131102
"""Get files that are already staged.
132103
133-
Returns a dict with file IDs as keys and FileInfo as values.
104+
Returns a list of `FileInfo` instances.
134105
These values contain the download URLs and file sizes.
135106
The dict should be cleared after these files have been downloaded.
136107
"""
@@ -221,3 +192,31 @@ def _handle_failures(self) -> bool:
221192
self._started_waiting = perf_counter() # reset the timer
222193
self._missing_files = [] # reset list of missing files
223194
return True
195+
196+
async def manage_file_downloads(self, overwrite: bool) -> AsyncGenerator[FileInfo]:
197+
"""Manages file downloads by handling errors, checking for existing files,
198+
printing messages to the display, and renaming files after they are downloaded.
199+
200+
Yields file information.
201+
"""
202+
while not self.finished:
203+
staged_files = await self.get_staged_files()
204+
for file_info in staged_files:
205+
utils.check_for_existing_file(file_info=file_info, overwrite=overwrite)
206+
try:
207+
file_id = file_info.file_id
208+
yield file_info
209+
except exceptions.GetEnvelopeError as error:
210+
CLIMessageDisplay.failure(
211+
f"The request to get an envelope for file '{file_id}' failed."
212+
)
213+
raise error
214+
except exceptions.DownloadError as error:
215+
CLIMessageDisplay.failure(
216+
f"Failed downloading with id '{file_id}'."
217+
)
218+
raise error
219+
file_info.path_during_download.rename(file_info.path_once_complete)
220+
CLIMessageDisplay.success(
221+
f"File with id '{file_info.file_id}' has been successfully downloaded."
222+
)

src/ghga_connector/core/downloading/downloader.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,11 @@
2323
from pathlib import Path
2424

2525
from ghga_connector import exceptions
26-
from ghga_connector.core import (
27-
CLIMessageDisplay,
28-
PartRange,
29-
calc_part_ranges,
30-
)
26+
from ghga_connector.core import CLIMessageDisplay, PartRange, calc_part_ranges
3127
from ghga_connector.core.tasks import TaskHandler
3228

3329
from ..progress_bar import DownloadProgressBar
34-
from .api_calls import (
35-
DownloadClient,
36-
extract_download_url,
37-
)
30+
from .api_calls import DownloadClient, extract_download_url
3831
from .structs import RetryResponse
3932

4033
logger = logging.getLogger(__name__)

src/ghga_connector/core/downloading/structs.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,43 @@
1616
"""Contains additional data structures needed by the download code"""
1717

1818
from dataclasses import dataclass
19+
from pathlib import Path
20+
21+
from ghga_connector.constants import C4GH
1922

2023

2124
@dataclass
2225
class RetryResponse:
2326
"""Response to download request if file is not yet staged"""
2427

2528
retry_after: int
29+
30+
31+
@dataclass
32+
class FileInfo:
33+
"""Information about a file to be downloaded"""
34+
35+
file_id: str
36+
file_extension: str
37+
file_size: int
38+
output_dir: Path
39+
40+
@property
41+
def file_name(self) -> str:
42+
"""Construct file name with suffix, if given"""
43+
file_name = f"{self.file_id}"
44+
if self.file_extension:
45+
file_name = f"{self.file_id}{self.file_extension}"
46+
return file_name
47+
48+
@property
49+
def path_during_download(self) -> Path:
50+
"""The file path while the file download is still in progress"""
51+
# with_suffix() might overwrite existing suffixes, do this instead:
52+
output_file = self.path_once_complete
53+
return output_file.parent / (output_file.name + ".part")
54+
55+
@property
56+
def path_once_complete(self) -> Path:
57+
"""The file path once the download is complete"""
58+
return self.output_dir / f"{self.file_name}{C4GH}"

0 commit comments

Comments
 (0)