# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "aiofiles",
#     "httpx-retries",
#     "polars",
# ]
# ///
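
"""Mirror a wort sourmash signature collection into a local directory.

Compares the archive's SOURMASH-MANIFEST.parquet against the files already
present under basedir (optionally re-verifying their sha256 with
--full-check) and downloads only what is missing, a bounded number of
files at a time.
"""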
import argparse
import asyncio
import hashlib
import pathlib
import shutil

import aiofiles
import httpx
from httpx_retries import RetryTransport
import polars as pl


ARCHIVE_URL = "https://farm.cse.ucdavis.edu/~irber"


async def main(args):
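    # scan the remote parquet manifest lazily: only the two columns needed
    # for the anti-join below are fetched, deduplicated by location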
    manifest_url = f"{args.archive_url}/wort-{args.database}/SOURMASH-MANIFEST.parquet"
    manifest_df = (
        pl.scan_parquet(manifest_url)
        .select(["internal_location", "sha256"])
        .unique(subset=["internal_location"])
        # .head(5)  # uncomment to test with a small subset
    )

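    # semaphore shared by all download tasks to cap concurrent transfers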
    limiter = asyncio.Semaphore(args.max_downloaders)

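    # index every file already present under basedir, keyed by its path
    # relative to basedir (matches the manifest's internal_location)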
    already_mirrored_locations = set()
    for root, dirs, files in args.basedir.walk(top_down=True):
        for name in files:
            already_mirrored_locations.add(
                str(root.relative_to(args.basedir) / name)
            )
    print(f"{len(already_mirrored_locations)} files already on disk")

    if args.full_check:
        # re-hash every mirrored file so corrupted or truncated copies are
        # re-downloaded, not just missing ones
        # NOTE: hashing runs sequentially; a TaskGroup could parallelize it
        internal_locations = []
        sha256_sums = []

        for location in already_mirrored_locations:
            async with aiofiles.open(args.basedir / location, mode="rb") as f:
                h = hashlib.new("sha256")
                while (chunk := await f.read(1024 * 1024)) != b"":
                    h.update(chunk)
                sha256 = h.hexdigest()

            internal_locations.append(location)
            sha256_sums.append(sha256)
    else:
        internal_locations = list(already_mirrored_locations)

    print(f"{len(internal_locations)} local files indexed")

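    # describe the local state as a lazy frame; under --full-check the
    # sha256 column joins too, so corrupted files are treated as missing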
    already_mirrored = {"internal_location": internal_locations}
    join_columns = ["internal_location"]
    schema = {"internal_location": pl.String}

    if args.full_check:
        already_mirrored["sha256"] = sha256_sums
        schema["sha256"] = pl.String
        join_columns.append("sha256")

    already_mirrored_df = pl.from_dict(already_mirrored, schema=schema).lazy()
    print(already_mirrored_df.collect())

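    # anti-join: keep only manifest rows with no matching local entry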
    to_mirror_df = manifest_df.join(already_mirrored_df, on=join_columns, how="anti")

    print(to_mirror_df.collect())

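    # fan out the downloads: RetryTransport retries transient HTTP failures,
    # the TaskGroup owns the tasks, and the semaphore caps concurrency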
    try:
        async with httpx.AsyncClient(
            timeout=30.0,
            base_url=f"{args.archive_url}/wort-{args.database}/",
            transport=RetryTransport(),
        ) as client:
            async with asyncio.TaskGroup() as tg:
                for location, sha256 in to_mirror_df.collect().iter_rows():
                    tg.create_task(
                        download_sig(
                            location,
                            sha256,
                            args.basedir,
                            client,
                            limiter,
                            args.dry_run,
                        )
                    )
    except* Exception as eg:
        # print a short summary of each failure instead of a full traceback
        print(*[str(e)[:50] for e in eg.exceptions])


async def download_sig(location, sha256, basedir, client, limiter, dry_run):
    async with limiter:
        if dry_run:
            print(f"download: {location}")
            return

        async with client.stream("GET", location) as response:
            response.raise_for_status()
            h = hashlib.new("sha256")
            total_bytes = 0
            # stream into a temp file inside basedir so an interrupted
            # transfer never leaves a partial file at the final location
            async with aiofiles.tempfile.NamedTemporaryFile(
                dir=basedir, delete=False
            ) as f:
                async for chunk in response.aiter_raw(1024 * 1024):
                    h.update(chunk)
                    await f.write(chunk)
                    total_bytes += len(chunk)
                tmp_path = pathlib.Path(f.name)

        if sha256 != h.hexdigest():
            # TODO: retry instead of giving up?
            print(f"download failed! expected {sha256}, got {h.hexdigest()}")
            await asyncio.to_thread(tmp_path.unlink)
            return

        # the temp file is complete, closed, and on the same filesystem as
        # the destination, so this rename is atomic: no incomplete files
        dest = basedir / location
        await asyncio.to_thread(dest.parent.mkdir, parents=True, exist_ok=True)
        await asyncio.to_thread(tmp_path.replace, dest)
        print(f"completed {location}, {total_bytes:,} bytes")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dry-run stays on by default for safety; pass --no-dry-run to download
    parser.add_argument(
        "-d", "--dry-run", default=True, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("-a", "--archive-url", default=ARCHIVE_URL)
    parser.add_argument("-m", "--max-downloaders", type=int, default=30)
    parser.add_argument("-f", "--full-check", default=False, action="store_true")
    parser.add_argument(
        "database", nargs="?", default="img", choices=["full", "img", "genomes", "sra"]
    )
    parser.add_argument("basedir", type=pathlib.Path)

    args = parser.parse_args()

    asyncio.run(main(args))
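
# Example invocation (script filename is hypothetical; the PEP 723 metadata
# above lets `uv run` resolve the dependencies):
#
#   uv run mirror_wort.py img /mirror/wort-img              # dry run (default)
#   uv run mirror_wort.py --no-dry-run img /mirror/wort-img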