Skip to content

Commit 3ac3c52

Browse files
Big refactor: flatten CLI (#119)
* Flatten CLI and further refactor top of download path * Fix test problems * Decouple download error handling again
1 parent 6a20f99 commit 3ac3c52

File tree

10 files changed

+304
-284
lines changed

10 files changed

+304
-284
lines changed

src/ghga_connector/cli.py

Lines changed: 3 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -16,85 +16,17 @@
1616
"""CLI-specific wrappers around core functions."""
1717

1818
import asyncio
19-
import logging
2019
import os
21-
import sys
22-
from dataclasses import dataclass
23-
from functools import partial
2420
from pathlib import Path
25-
from types import TracebackType
2621

27-
import crypt4gh.keys
2822
import typer
29-
from ghga_service_commons.utils import crypt
3023

3124
from ghga_connector import exceptions
3225
from ghga_connector.config import CONFIG, set_runtime_config
3326
from ghga_connector.constants import C4GH
34-
from ghga_connector.core import (
35-
CLIMessageDisplay,
36-
WorkPackageClient,
37-
async_client,
38-
)
39-
from ghga_connector.core.downloading.api_calls import DownloadClient
40-
from ghga_connector.core.downloading.batch_processing import FileInfo, FileStager
41-
from ghga_connector.core.main import (
42-
decrypt_file,
43-
download_file,
44-
get_work_package_token,
45-
upload_file,
46-
)
47-
48-
49-
@dataclass
50-
class WorkPackageInformation:
51-
"""Wraps decrypted work package token and id to pass to other functions"""
52-
53-
decrypted_token: str
54-
package_id: str
55-
56-
57-
def strtobool(value: str) -> bool:
58-
"""Inplace replacement for distutils.utils"""
59-
return value.lower() in ("y", "yes", "on", "1", "true", "t")
60-
61-
62-
def exception_hook(
63-
type_: BaseException,
64-
value: BaseException,
65-
traceback: TracebackType | None,
66-
):
67-
"""When debug mode is NOT enabled, gets called to perform final error handling
68-
before program exits
69-
"""
70-
message = (
71-
"An error occurred. Rerun command"
72-
+ " with --debug at the end to see more information."
73-
)
74-
75-
if value.args:
76-
message += f"\n{value.args[0]}"
77-
78-
CLIMessageDisplay.failure(message)
79-
80-
81-
def modify_for_debug(debug: bool):
82-
"""Enable debug logging and configure exception printing if debug=True"""
83-
if debug:
84-
# enable debug logging
85-
logging.basicConfig(level=logging.DEBUG)
86-
sys.excepthook = partial(exception_hook)
87-
88-
89-
def get_work_package_information(my_private_key: bytes):
90-
"""Fetch a work package id and work package token and decrypt the token"""
91-
# get work package access token and id from user input
92-
work_package_id, work_package_token = get_work_package_token(max_tries=3)
93-
decrypted_token = crypt.decrypt(data=work_package_token, key=my_private_key)
94-
return WorkPackageInformation(
95-
decrypted_token=decrypted_token, package_id=work_package_id
96-
)
97-
27+
from ghga_connector.core import CLIMessageDisplay, async_client
28+
from ghga_connector.core.main import async_download, decrypt_file, upload_file
29+
from ghga_connector.core.utils import modify_for_debug, strtobool
9830

9931
cli = typer.Typer(no_args_is_help=True)
10032

@@ -203,113 +135,6 @@ def download( # noqa: PLR0913
203135
)
204136

205137

206-
def get_public_key(my_public_key_path: Path) -> bytes:
207-
"""Get the user's private key from the path supplied"""
208-
if not my_public_key_path.is_file():
209-
raise exceptions.PubKeyFileDoesNotExistError(public_key_path=my_public_key_path)
210-
211-
return crypt4gh.keys.get_public_key(filepath=my_public_key_path)
212-
213-
214-
def get_private_key(my_private_key_path: Path, passphrase: str | None = None) -> bytes:
215-
"""Get the user's private key, using the passphrase if supplied/needed."""
216-
if passphrase:
217-
my_private_key = crypt4gh.keys.get_private_key(
218-
filepath=my_private_key_path, callback=lambda: passphrase
219-
)
220-
else:
221-
my_private_key = crypt4gh.keys.get_private_key(
222-
filepath=my_private_key_path, callback=None
223-
)
224-
return my_private_key
225-
226-
227-
async def async_download(
228-
*,
229-
output_dir: Path,
230-
my_public_key_path: Path,
231-
my_private_key_path: Path,
232-
passphrase: str | None = None,
233-
overwrite: bool = False,
234-
):
235-
"""Download files asynchronously"""
236-
if not output_dir.is_dir():
237-
raise exceptions.DirectoryDoesNotExistError(directory=output_dir)
238-
239-
my_public_key = get_public_key(my_public_key_path)
240-
my_private_key = get_private_key(my_private_key_path, passphrase)
241-
242-
CLIMessageDisplay.display("\nFetching work package token...")
243-
work_package_information = get_work_package_information(
244-
my_private_key=my_private_key
245-
)
246-
247-
async with async_client() as client, set_runtime_config(client=client):
248-
CLIMessageDisplay.display("Retrieving API configuration information...")
249-
work_package_client = WorkPackageClient(
250-
access_token=work_package_information.decrypted_token,
251-
client=client,
252-
package_id=work_package_information.package_id,
253-
my_private_key=my_private_key,
254-
my_public_key=my_public_key,
255-
)
256-
257-
file_ids_with_extension = await work_package_client.get_package_files()
258-
259-
download_client = DownloadClient(
260-
client=client, work_package_client=work_package_client
261-
)
262-
263-
CLIMessageDisplay.display("Preparing files for download...")
264-
stager = FileStager(
265-
wanted_files=file_ids_with_extension,
266-
output_dir=output_dir,
267-
work_package_client=work_package_client,
268-
download_client=download_client,
269-
config=CONFIG,
270-
)
271-
while not stager.finished:
272-
staged_files = await stager.get_staged_files()
273-
for file_info in staged_files:
274-
check_for_existing_file(file_info=file_info, overwrite=overwrite)
275-
await download_file(
276-
download_client=download_client,
277-
file_info=file_info,
278-
max_concurrent_downloads=CONFIG.max_concurrent_downloads,
279-
part_size=CONFIG.part_size,
280-
)
281-
finalize_download(file_info)
282-
staged_files.clear()
283-
284-
285-
def check_for_existing_file(*, file_info: FileInfo, overwrite: bool):
286-
"""Check if a file with the given name already exists and conditionally overwrite it."""
287-
# check output file
288-
output_file = file_info.path_once_complete
289-
if output_file.exists():
290-
if overwrite:
291-
CLIMessageDisplay.display(
292-
f"A file with name '{output_file}' already exists and will be overwritten."
293-
)
294-
else:
295-
CLIMessageDisplay.failure(
296-
f"A file with name '{output_file}' already exists. Skipping."
297-
)
298-
return
299-
300-
output_file_ongoing = file_info.path_during_download
301-
if output_file_ongoing.exists():
302-
output_file_ongoing.unlink()
303-
304-
305-
def finalize_download(file_info: FileInfo):
306-
"""Rename a file after downloading and announce completion"""
307-
file_info.path_during_download.rename(file_info.path_once_complete)
308-
CLIMessageDisplay.success(
309-
f"File with id '{file_info.file_id}' has been successfully downloaded."
310-
)
311-
312-
313138
@cli.command(no_args_is_help=True)
314139
def decrypt( # noqa: PLR0912, C901
315140
*,

src/ghga_connector/core/downloading/api_calls.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
from ghga_connector.config import get_download_api_url
2626
from ghga_connector.constants import TIMEOUT_LONG
2727
from ghga_connector.core import RetryHandler
28-
from ghga_connector.core.api_calls.utils import modify_headers_for_cache_refresh
28+
from ghga_connector.core.api_calls.utils import (
29+
is_service_healthy,
30+
modify_headers_for_cache_refresh,
31+
)
2932
from ghga_connector.core.work_package import WorkPackageClient
3033

3134
from .structs import RetryResponse
@@ -100,6 +103,14 @@ async def _retrieve_drs_object_using_params(
100103

101104
return _handle_drs_object_response(url=url, response=response)
102105

106+
def check_download_api_is_reachable(self):
107+
"""Verify that the download API is reachable.
108+
109+
Raises an `ApiNotReachableError` if it is not reachable.
110+
"""
111+
if not is_service_healthy(self._download_api_url):
112+
raise exceptions.ApiNotReachableError(api_url=self._download_api_url)
113+
103114
async def get_envelope_authorization_headers(
104115
self, *, file_id: str
105116
) -> httpx.Headers:

src/ghga_connector/core/downloading/batch_processing.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414
# limitations under the License.
1515
"""Module for batch processing related code"""
1616

17+
from collections.abc import AsyncGenerator
1718
from dataclasses import dataclass, field
1819
from pathlib import Path
1920
from time import perf_counter, sleep
2021

2122
from ghga_connector import exceptions
2223
from ghga_connector.config import Config, get_download_api_url
2324
from ghga_connector.constants import C4GH
24-
from ghga_connector.core import CLIMessageDisplay, WorkPackageClient
25+
from ghga_connector.core import CLIMessageDisplay, WorkPackageClient, utils
2526
from ghga_connector.core.api_calls import is_service_healthy
2627

2728
from .api_calls import DownloadClient
28-
from .structs import RetryResponse
29+
from .structs import FileInfo, RetryResponse
2930

3031

3132
@dataclass
@@ -60,36 +61,6 @@ def handle_response(self, *, response: str):
6061
raise exceptions.AbortBatchProcessError()
6162

6263

63-
@dataclass
64-
class FileInfo:
65-
"""Information about a file to be downloaded"""
66-
67-
file_id: str
68-
file_extension: str
69-
file_size: int
70-
output_dir: Path
71-
72-
@property
73-
def file_name(self) -> str:
74-
"""Construct file name with suffix, if given"""
75-
file_name = f"{self.file_id}"
76-
if self.file_extension:
77-
file_name = f"{self.file_id}{self.file_extension}"
78-
return file_name
79-
80-
@property
81-
def path_during_download(self) -> Path:
82-
"""The file path while the file download is still in progress"""
83-
# with_suffix() might overwrite existing suffixes, do this instead:
84-
output_file = self.path_once_complete
85-
return output_file.parent / (output_file.name + ".part")
86-
87-
@property
88-
def path_once_complete(self) -> Path:
89-
"""The file path once the download is complete"""
90-
return self.output_dir / f"{self.file_name}{C4GH}"
91-
92-
9364
class FileStager:
9465
"""Utility class to deal with file staging in batch processing."""
9566

@@ -130,7 +101,7 @@ def __init__(
130101
async def get_staged_files(self) -> list[FileInfo]:
131102
"""Get files that are already staged.
132103
133-
Returns a dict with file IDs as keys and FileInfo as values.
104+
Returns a list of `FileInfo` instances.
134105
These values contain the download URLs and file sizes.
135106
The dict should be cleared after these files have been downloaded.
136107
"""
@@ -221,3 +192,20 @@ def _handle_failures(self) -> bool:
221192
self._started_waiting = perf_counter() # reset the timer
222193
self._missing_files = [] # reset list of missing files
223194
return True
195+
196+
async def manage_file_downloads(self, overwrite: bool) -> AsyncGenerator[FileInfo]:
197+
"""Manages file downloads by checking for existing files before download,
198+
printing messages to the display, and renaming files after they are downloaded.
199+
200+
Yields file information.
201+
"""
202+
while not self.finished:
203+
staged_files = await self.get_staged_files()
204+
for file_info in staged_files:
205+
utils.check_for_existing_file(file_info=file_info, overwrite=overwrite)
206+
yield file_info
207+
file_info.path_during_download.rename(file_info.path_once_complete)
208+
CLIMessageDisplay.success(
209+
f"File with id '{file_info.file_id}' has been successfully downloaded."
210+
)
211+
staged_files.clear()

src/ghga_connector/core/downloading/downloader.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,39 @@
1919
import gc
2020
import logging
2121
from asyncio import PriorityQueue, Queue, Semaphore, Task
22+
from contextlib import contextmanager
2223
from io import BufferedWriter
2324
from pathlib import Path
2425

2526
from ghga_connector import exceptions
26-
from ghga_connector.core import (
27-
CLIMessageDisplay,
28-
PartRange,
29-
calc_part_ranges,
30-
)
27+
from ghga_connector.core import CLIMessageDisplay, PartRange, calc_part_ranges
3128
from ghga_connector.core.tasks import TaskHandler
3229

3330
from ..progress_bar import DownloadProgressBar
34-
from .api_calls import (
35-
DownloadClient,
36-
extract_download_url,
37-
)
38-
from .structs import RetryResponse
31+
from .api_calls import DownloadClient, extract_download_url
32+
from .structs import FileInfo, RetryResponse
3933

4034
logger = logging.getLogger(__name__)
4135
# TODO: [later] More better logging
4236

37+
__all__ = ["Downloader", "handle_download_errors"]
38+
39+
40+
@contextmanager
41+
def handle_download_errors(file_info: FileInfo):
42+
"""Used to handle download errors from `Downloader.download_file()`"""
43+
file_id = file_info.file_id
44+
try:
45+
yield
46+
except exceptions.GetEnvelopeError as error:
47+
CLIMessageDisplay.failure(
48+
f"The request to get an envelope for file '{file_id}' failed."
49+
)
50+
raise error
51+
except exceptions.DownloadError as error:
52+
CLIMessageDisplay.failure(f"Failed downloading with id '{file_id}'.")
53+
raise error
54+
4355

4456
class Downloader:
4557
"""Centralized high-level interface for downloading a single file.

0 commit comments

Comments
 (0)