| 
16 | 16 | """CLI-specific wrappers around core functions."""  | 
17 | 17 | 
 
  | 
18 | 18 | import asyncio  | 
19 |  | -import logging  | 
20 | 19 | import os  | 
21 |  | -import sys  | 
22 |  | -from dataclasses import dataclass  | 
23 |  | -from functools import partial  | 
24 | 20 | from pathlib import Path  | 
25 |  | -from types import TracebackType  | 
26 | 21 | 
 
  | 
27 |  | -import crypt4gh.keys  | 
28 | 22 | import typer  | 
29 |  | -from ghga_service_commons.utils import crypt  | 
30 | 23 | 
 
  | 
31 | 24 | from ghga_connector import exceptions  | 
32 | 25 | from ghga_connector.config import CONFIG, set_runtime_config  | 
33 | 26 | from ghga_connector.constants import C4GH  | 
34 |  | -from ghga_connector.core import (  | 
35 |  | -    CLIMessageDisplay,  | 
36 |  | -    WorkPackageClient,  | 
37 |  | -    async_client,  | 
38 |  | -)  | 
39 |  | -from ghga_connector.core.downloading.api_calls import DownloadClient  | 
40 |  | -from ghga_connector.core.downloading.batch_processing import FileInfo, FileStager  | 
41 |  | -from ghga_connector.core.main import (  | 
42 |  | -    decrypt_file,  | 
43 |  | -    download_file,  | 
44 |  | -    get_work_package_token,  | 
45 |  | -    upload_file,  | 
46 |  | -)  | 
47 |  | - | 
48 |  | - | 
49 |  | -@dataclass  | 
50 |  | -class WorkPackageInformation:  | 
51 |  | -    """Wraps decrypted work package token and id to pass to other functions"""  | 
52 |  | - | 
53 |  | -    decrypted_token: str  | 
54 |  | -    package_id: str  | 
55 |  | - | 
56 |  | - | 
57 |  | -def strtobool(value: str) -> bool:  | 
58 |  | -    """Inplace replacement for distutils.utils"""  | 
59 |  | -    return value.lower() in ("y", "yes", "on", "1", "true", "t")  | 
60 |  | - | 
61 |  | - | 
62 |  | -def exception_hook(  | 
63 |  | -    type_: BaseException,  | 
64 |  | -    value: BaseException,  | 
65 |  | -    traceback: TracebackType | None,  | 
66 |  | -):  | 
67 |  | -    """When debug mode is NOT enabled, gets called to perform final error handling  | 
68 |  | -    before program exits  | 
69 |  | -    """  | 
70 |  | -    message = (  | 
71 |  | -        "An error occurred. Rerun command"  | 
72 |  | -        + " with --debug at the end to see more information."  | 
73 |  | -    )  | 
74 |  | - | 
75 |  | -    if value.args:  | 
76 |  | -        message += f"\n{value.args[0]}"  | 
77 |  | - | 
78 |  | -    CLIMessageDisplay.failure(message)  | 
79 |  | - | 
80 |  | - | 
81 |  | -def modify_for_debug(debug: bool):  | 
82 |  | -    """Enable debug logging and configure exception printing if debug=True"""  | 
83 |  | -    if debug:  | 
84 |  | -        # enable debug logging  | 
85 |  | -        logging.basicConfig(level=logging.DEBUG)  | 
86 |  | -        sys.excepthook = partial(exception_hook)  | 
87 |  | - | 
88 |  | - | 
89 |  | -def get_work_package_information(my_private_key: bytes):  | 
90 |  | -    """Fetch a work package id and work package token and decrypt the token"""  | 
91 |  | -    # get work package access token and id from user input  | 
92 |  | -    work_package_id, work_package_token = get_work_package_token(max_tries=3)  | 
93 |  | -    decrypted_token = crypt.decrypt(data=work_package_token, key=my_private_key)  | 
94 |  | -    return WorkPackageInformation(  | 
95 |  | -        decrypted_token=decrypted_token, package_id=work_package_id  | 
96 |  | -    )  | 
97 |  | - | 
 | 27 | +from ghga_connector.core import CLIMessageDisplay, async_client  | 
 | 28 | +from ghga_connector.core.main import async_download, decrypt_file, upload_file  | 
 | 29 | +from ghga_connector.core.utils import modify_for_debug, strtobool  | 
98 | 30 | 
 
  | 
99 | 31 | cli = typer.Typer(no_args_is_help=True)  | 
100 | 32 | 
 
  | 
@@ -203,113 +135,6 @@ def download(  # noqa: PLR0913  | 
203 | 135 |     )  | 
204 | 136 | 
 
  | 
205 | 137 | 
 
  | 
206 |  | -def get_public_key(my_public_key_path: Path) -> bytes:  | 
207 |  | -    """Get the user's private key from the path supplied"""  | 
208 |  | -    if not my_public_key_path.is_file():  | 
209 |  | -        raise exceptions.PubKeyFileDoesNotExistError(public_key_path=my_public_key_path)  | 
210 |  | - | 
211 |  | -    return crypt4gh.keys.get_public_key(filepath=my_public_key_path)  | 
212 |  | - | 
213 |  | - | 
214 |  | -def get_private_key(my_private_key_path: Path, passphrase: str | None = None) -> bytes:  | 
215 |  | -    """Get the user's private key, using the passphrase if supplied/needed."""  | 
216 |  | -    if passphrase:  | 
217 |  | -        my_private_key = crypt4gh.keys.get_private_key(  | 
218 |  | -            filepath=my_private_key_path, callback=lambda: passphrase  | 
219 |  | -        )  | 
220 |  | -    else:  | 
221 |  | -        my_private_key = crypt4gh.keys.get_private_key(  | 
222 |  | -            filepath=my_private_key_path, callback=None  | 
223 |  | -        )  | 
224 |  | -    return my_private_key  | 
225 |  | - | 
226 |  | - | 
227 |  | -async def async_download(  | 
228 |  | -    *,  | 
229 |  | -    output_dir: Path,  | 
230 |  | -    my_public_key_path: Path,  | 
231 |  | -    my_private_key_path: Path,  | 
232 |  | -    passphrase: str | None = None,  | 
233 |  | -    overwrite: bool = False,  | 
234 |  | -):  | 
235 |  | -    """Download files asynchronously"""  | 
236 |  | -    if not output_dir.is_dir():  | 
237 |  | -        raise exceptions.DirectoryDoesNotExistError(directory=output_dir)  | 
238 |  | - | 
239 |  | -    my_public_key = get_public_key(my_public_key_path)  | 
240 |  | -    my_private_key = get_private_key(my_private_key_path, passphrase)  | 
241 |  | - | 
242 |  | -    CLIMessageDisplay.display("\nFetching work package token...")  | 
243 |  | -    work_package_information = get_work_package_information(  | 
244 |  | -        my_private_key=my_private_key  | 
245 |  | -    )  | 
246 |  | - | 
247 |  | -    async with async_client() as client, set_runtime_config(client=client):  | 
248 |  | -        CLIMessageDisplay.display("Retrieving API configuration information...")  | 
249 |  | -        work_package_client = WorkPackageClient(  | 
250 |  | -            access_token=work_package_information.decrypted_token,  | 
251 |  | -            client=client,  | 
252 |  | -            package_id=work_package_information.package_id,  | 
253 |  | -            my_private_key=my_private_key,  | 
254 |  | -            my_public_key=my_public_key,  | 
255 |  | -        )  | 
256 |  | - | 
257 |  | -        file_ids_with_extension = await work_package_client.get_package_files()  | 
258 |  | - | 
259 |  | -        download_client = DownloadClient(  | 
260 |  | -            client=client, work_package_client=work_package_client  | 
261 |  | -        )  | 
262 |  | - | 
263 |  | -        CLIMessageDisplay.display("Preparing files for download...")  | 
264 |  | -        stager = FileStager(  | 
265 |  | -            wanted_files=file_ids_with_extension,  | 
266 |  | -            output_dir=output_dir,  | 
267 |  | -            work_package_client=work_package_client,  | 
268 |  | -            download_client=download_client,  | 
269 |  | -            config=CONFIG,  | 
270 |  | -        )  | 
271 |  | -        while not stager.finished:  | 
272 |  | -            staged_files = await stager.get_staged_files()  | 
273 |  | -            for file_info in staged_files:  | 
274 |  | -                check_for_existing_file(file_info=file_info, overwrite=overwrite)  | 
275 |  | -                await download_file(  | 
276 |  | -                    download_client=download_client,  | 
277 |  | -                    file_info=file_info,  | 
278 |  | -                    max_concurrent_downloads=CONFIG.max_concurrent_downloads,  | 
279 |  | -                    part_size=CONFIG.part_size,  | 
280 |  | -                )  | 
281 |  | -                finalize_download(file_info)  | 
282 |  | -            staged_files.clear()  | 
283 |  | - | 
284 |  | - | 
285 |  | -def check_for_existing_file(*, file_info: FileInfo, overwrite: bool):  | 
286 |  | -    """Check if a file with the given name already exists and conditionally overwrite it."""  | 
287 |  | -    # check output file  | 
288 |  | -    output_file = file_info.path_once_complete  | 
289 |  | -    if output_file.exists():  | 
290 |  | -        if overwrite:  | 
291 |  | -            CLIMessageDisplay.display(  | 
292 |  | -                f"A file with name '{output_file}' already exists and will be overwritten."  | 
293 |  | -            )  | 
294 |  | -        else:  | 
295 |  | -            CLIMessageDisplay.failure(  | 
296 |  | -                f"A file with name '{output_file}' already exists. Skipping."  | 
297 |  | -            )  | 
298 |  | -            return  | 
299 |  | - | 
300 |  | -    output_file_ongoing = file_info.path_during_download  | 
301 |  | -    if output_file_ongoing.exists():  | 
302 |  | -        output_file_ongoing.unlink()  | 
303 |  | - | 
304 |  | - | 
305 |  | -def finalize_download(file_info: FileInfo):  | 
306 |  | -    """Rename a file after downloading and announce completion"""  | 
307 |  | -    file_info.path_during_download.rename(file_info.path_once_complete)  | 
308 |  | -    CLIMessageDisplay.success(  | 
309 |  | -        f"File with id '{file_info.file_id}' has been successfully downloaded."  | 
310 |  | -    )  | 
311 |  | - | 
312 |  | - | 
313 | 138 | @cli.command(no_args_is_help=True)  | 
314 | 139 | def decrypt(  # noqa: PLR0912, C901  | 
315 | 140 |     *,  | 
 | 
0 commit comments