|
25 | 25 | from types import TracebackType |
26 | 26 |
|
27 | 27 | import crypt4gh.keys |
28 | | -import httpx |
29 | 28 | import typer |
30 | 29 | from ghga_service_commons.utils import crypt |
31 | 30 |
|
32 | 31 | from ghga_connector import exceptions |
33 | 32 | from ghga_connector.config import CONFIG, set_runtime_config |
| 33 | +from ghga_connector.constants import C4GH |
34 | 34 | from ghga_connector.core import ( |
35 | 35 | CLIMessageDisplay, |
36 | | - WorkPackageAccessor, |
| 36 | + WorkPackageClient, |
37 | 37 | async_client, |
38 | 38 | ) |
39 | | -from ghga_connector.core.downloading.batch_processing import FileStager |
| 39 | +from ghga_connector.core.downloading.api_calls import DownloadClient |
| 40 | +from ghga_connector.core.downloading.batch_processing import FileInfo, FileStager |
40 | 41 | from ghga_connector.core.main import ( |
41 | 42 | decrypt_file, |
42 | 43 | download_file, |
43 | | - get_wps_token, |
| 44 | + get_work_package_token, |
44 | 45 | upload_file, |
45 | 46 | ) |
46 | 47 |
|
47 | 48 |
|
48 | | -@dataclass |
49 | | -class DownloadParameters: |
50 | | - """Contains parameters returned by API calls to prepare information needed for download""" |
51 | | - |
52 | | - file_ids_with_extension: dict[str, str] |
53 | | - work_package_accessor: WorkPackageAccessor |
54 | | - |
55 | | - |
56 | 49 | @dataclass |
57 | 50 | class WorkPackageInformation: |
58 | 51 | """Wraps decrypted work package token and id to pass to other functions""" |
@@ -93,33 +86,10 @@ def modify_for_debug(debug: bool): |
93 | 86 | sys.excepthook = partial(exception_hook) |
94 | 87 |
|
95 | 88 |
|
96 | | -async def retrieve_download_parameters( |
97 | | - *, |
98 | | - client: httpx.AsyncClient, |
99 | | - my_private_key: bytes, |
100 | | - my_public_key: bytes, |
101 | | - work_package_information: WorkPackageInformation, |
102 | | -) -> DownloadParameters: |
103 | | - """Run necessary API calls to configure file download""" |
104 | | - work_package_accessor = WorkPackageAccessor( |
105 | | - access_token=work_package_information.decrypted_token, |
106 | | - client=client, |
107 | | - package_id=work_package_information.package_id, |
108 | | - my_private_key=my_private_key, |
109 | | - my_public_key=my_public_key, |
110 | | - ) |
111 | | - file_ids_with_extension = await work_package_accessor.get_package_files() |
112 | | - |
113 | | - return DownloadParameters( |
114 | | - file_ids_with_extension=file_ids_with_extension, |
115 | | - work_package_accessor=work_package_accessor, |
116 | | - ) |
117 | | - |
118 | | - |
119 | 89 | def get_work_package_information(my_private_key: bytes): |
120 | 90 | """Fetch a work package id and work package token and decrypt the token""" |
121 | 91 | # get work package access token and id from user input |
122 | | - work_package_id, work_package_token = get_wps_token(max_tries=3) |
| 92 | + work_package_id, work_package_token = get_work_package_token(max_tries=3) |
123 | 93 | decrypted_token = crypt.decrypt(data=work_package_token, key=my_private_key) |
124 | 94 | return WorkPackageInformation( |
125 | 95 | decrypted_token=decrypted_token, package_id=work_package_id |
@@ -276,38 +246,70 @@ async def async_download( |
276 | 246 |
|
277 | 247 | async with async_client() as client, set_runtime_config(client=client): |
278 | 248 | CLIMessageDisplay.display("Retrieving API configuration information...") |
279 | | - parameters = await retrieve_download_parameters( |
| 249 | + work_package_client = WorkPackageClient( |
| 250 | + access_token=work_package_information.decrypted_token, |
280 | 251 | client=client, |
| 252 | + package_id=work_package_information.package_id, |
281 | 253 | my_private_key=my_private_key, |
282 | 254 | my_public_key=my_public_key, |
283 | | - work_package_information=work_package_information, |
| 255 | + ) |
| 256 | + |
| 257 | + file_ids_with_extension = await work_package_client.get_package_files() |
| 258 | + |
| 259 | + download_client = DownloadClient( |
| 260 | + client=client, work_package_client=work_package_client |
284 | 261 | ) |
285 | 262 |
|
286 | 263 | CLIMessageDisplay.display("Preparing files for download...") |
287 | 264 | stager = FileStager( |
288 | | - wanted_file_ids=list(parameters.file_ids_with_extension), |
| 265 | + wanted_files=file_ids_with_extension, |
289 | 266 | output_dir=output_dir, |
290 | | - work_package_accessor=parameters.work_package_accessor, |
291 | | - client=client, |
| 267 | + work_package_client=work_package_client, |
| 268 | + download_client=download_client, |
292 | 269 | config=CONFIG, |
293 | 270 | ) |
294 | 271 | while not stager.finished: |
295 | 272 | staged_files = await stager.get_staged_files() |
296 | | - for file_id in staged_files: |
297 | | - CLIMessageDisplay.display(f"Downloading file with id '{file_id}'...") |
| 273 | + for file_info in staged_files: |
| 274 | + check_for_existing_file(file_info=file_info, overwrite=overwrite) |
298 | 275 | await download_file( |
299 | | - client=client, |
300 | | - file_id=file_id, |
301 | | - file_extension=parameters.file_ids_with_extension[file_id], |
302 | | - output_dir=output_dir, |
| 276 | + download_client=download_client, |
| 277 | + file_info=file_info, |
303 | 278 | max_concurrent_downloads=CONFIG.max_concurrent_downloads, |
304 | 279 | part_size=CONFIG.part_size, |
305 | | - work_package_accessor=parameters.work_package_accessor, |
306 | | - overwrite=overwrite, |
307 | 280 | ) |
| 281 | + finalize_download(file_info) |
308 | 282 | staged_files.clear() |
309 | 283 |
|
310 | 284 |
|
| 285 | +def check_for_existing_file(*, file_info: FileInfo, overwrite: bool): |
| 286 | + """Check if a file with the given name already exists and conditionally overwrite it.""" |
| 287 | + # check output file |
| 288 | + output_file = file_info.path_once_complete |
| 289 | + if output_file.exists(): |
| 290 | + if overwrite: |
| 291 | + CLIMessageDisplay.display( |
| 292 | + f"A file with name '{output_file}' already exists and will be overwritten." |
| 293 | + ) |
| 294 | + else: |
| 295 | + CLIMessageDisplay.failure( |
| 296 | + f"A file with name '{output_file}' already exists. Skipping." |
| 297 | + ) |
| 298 | + return |
| 299 | + |
| 300 | + output_file_ongoing = file_info.path_during_download |
| 301 | + if output_file_ongoing.exists(): |
| 302 | + output_file_ongoing.unlink() |
| 303 | + |
| 304 | + |
| 305 | +def finalize_download(file_info: FileInfo): |
| 306 | + """Rename a file after downloading and announce completion""" |
| 307 | + file_info.path_during_download.rename(file_info.path_once_complete) |
| 308 | + CLIMessageDisplay.success( |
| 309 | + f"File with id '{file_info.file_id}' has been successfully downloaded." |
| 310 | + ) |
| 311 | + |
| 312 | + |
311 | 313 | @cli.command(no_args_is_help=True) |
312 | 314 | def decrypt( # noqa: PLR0912, C901 |
313 | 315 | *, |
@@ -356,7 +358,7 @@ def decrypt( # noqa: PLR0912, C901 |
356 | 358 | skipped_files = [] |
357 | 359 | file_count = 0 |
358 | 360 | for input_file in input_dir.iterdir(): |
359 | | - if not input_file.is_file() or input_file.suffix != ".c4gh": |
| 361 | + if not input_file.is_file() or input_file.suffix != C4GH: |
360 | 362 | skipped_files.append(str(input_file)) |
361 | 363 | continue |
362 | 364 |
|
@@ -396,7 +398,7 @@ def decrypt( # noqa: PLR0912, C901 |
396 | 398 |
|
397 | 399 | if skipped_files: |
398 | 400 | CLIMessageDisplay.display( |
399 | | - "The following files were skipped as they are not .c4gh files:" |
| 401 | + f"The following files were skipped as they are not {C4GH} files:" |
400 | 402 | ) |
401 | 403 | for file in skipped_files: |
402 | 404 | CLIMessageDisplay.display(f"- {file}") |
|
0 commit comments