Skip to content

Commit 6728b6e

Browse files
committed
change download from full to streaming for transformers
1 parent ec49df8 commit 6728b6e

File tree

2 files changed: 125 additions, 10 deletions

inference/core/roboflow_api.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from inference.core.env import (
3333
API_BASE_URL,
34+
ATOMIC_CACHE_WRITES_ENABLED,
3435
INTERNAL_WEIGHTS_URL_SUFFIX,
3536
MD5_VERIFICATION_ENABLED,
3637
MODEL_CACHE_DIR,
def get_from_url(
    url: str,
    json_response: bool = True,
    verify_content_length: bool = False,
    stream: bool = False,
) -> Union[Response, dict]:
    """Public entry point for HTTP GET requests against the Roboflow API.

    Thin re-export of the private ``_get_from_url``; see that function for
    the retry/verification details.

    Args:
        url: URL to fetch.
        json_response: When True, return the parsed JSON body as a dict.
        verify_content_length: When True, check the Content-Length header
            against the bytes actually received.
        stream: When True, return the Response opened with ``stream=True``;
            the caller is then responsible for consuming and closing it.

    Returns:
        A dict when ``json_response`` is True, otherwise the raw Response.
    """
    return _get_from_url(
        url=url,
        json_response=json_response,
        verify_content_length=verify_content_length,
        stream=stream,
    )
831+
832+
833+
def stream_url_to_cache(
    url: str,
    filename: str,
    model_id: str,
    chunk_size: int = 1024 * 1024,
) -> None:
    """Download ``url`` straight into the model cache in fixed-size chunks.

    Public re-export of ``_stream_url_to_cache`` so other modules can use
    streaming downloads without importing a private name.

    Args:
        url: Source URL of the artifact.
        filename: Name the file should have inside the cache.
        model_id: Model identifier that selects the cache directory.
        chunk_size: Number of bytes fetched per chunk (defaults to 1 MiB).
    """
    _stream_url_to_cache(
        url=url,
        filename=filename,
        model_id=model_id,
        chunk_size=chunk_size,
    )
828856

829857

@@ -837,13 +865,27 @@ def _get_from_url(
837865
url: str,
838866
json_response: bool = True,
839867
verify_content_length: bool = False,
868+
stream: bool = False,
840869
) -> Union[Response, dict]:
870+
"""
871+
Downloads data from URL with optional streaming support.
872+
873+
Args:
874+
url: URL to download from
875+
json_response: If True, parse response as JSON
876+
verify_content_length: If True, verify Content-Length header matches received data
877+
stream: If True, return Response with stream=True (caller must handle streaming)
878+
879+
Returns:
880+
Dict if json_response=True, Response object otherwise
881+
"""
841882
try:
842883
response = requests.get(
843884
wrap_url(url),
844885
headers=build_roboflow_api_headers(),
845886
timeout=ROBOFLOW_API_REQUEST_TIMEOUT,
846887
verify=ROBOFLOW_API_VERIFY_SSL,
888+
stream=stream,
847889
)
848890

849891
except (ConnectionError, Timeout, requests.exceptions.ConnectionError) as error:
@@ -859,6 +901,12 @@ def _get_from_url(
859901
raise RetryRequestError(message=str(error), inner_error=error) from error
860902
raise error
861903

904+
# For streaming responses, return immediately without consuming content
905+
# The caller is responsible for MD5 verification during streaming
906+
if stream:
907+
return response
908+
909+
# For non-streaming responses, verify MD5 and content-length as before
862910
if MD5_VERIFICATION_ENABLED and "x-goog-hash" in response.headers:
863911
x_goog_hash = response.headers["x-goog-hash"]
864912
md5_part = None
@@ -893,6 +941,76 @@ def _get_from_url(
893941
return response
894942

895943

944+
def _stream_url_to_cache(
    url: str,
    filename: str,
    model_id: str,
    chunk_size: int = 1024 * 1024,  # 1 MiB chunks
) -> None:
    """
    Stream download a file from URL directly to cache without loading into memory.

    The payload is written to a temporary ``.tmp`` file next to the final cache
    path and only moved into place after the download (and optional MD5
    verification) succeeds, so a partial download is never visible in the cache.

    Args:
        url: URL to download from
        filename: Target filename in cache
        model_id: Model ID for cache directory
        chunk_size: Size of chunks to download (default 1MB)

    Raises:
        RoboflowAPIUnsuccessfulRequestError: If MD5 verification fails.
        Exception: Any error raised by ``_get_from_url`` or while writing
            is re-raised after the temporary file is cleaned up.
    """
    # Imported lazily — presumably to avoid a circular import at module
    # load time; dump_bytes helpers were imported here but never used.
    from inference.core.cache.model_artifacts import (
        get_cache_file_path,
        initialise_cache,
    )

    initialise_cache(model_id=model_id)
    cache_file_path = get_cache_file_path(file=filename, model_id=model_id)

    response = _get_from_url(url, json_response=False, stream=True)

    # GCS exposes the object's MD5 in the x-goog-hash header as
    # "md5=<base64 digest>"; grab it up front so we can verify while streaming.
    expected_md5_digest = None
    if MD5_VERIFICATION_ENABLED and "x-goog-hash" in response.headers:
        for part in response.headers["x-goog-hash"].split(","):
            if part.strip().startswith("md5="):
                expected_md5_digest = base64.b64decode(part.strip()[4:])
                break

    computed_md5 = hashlib.md5() if MD5_VERIFICATION_ENABLED else None
    # NOTE(review): a fixed ".tmp" suffix means two concurrent downloads of
    # the same artifact would clobber each other's temp file — confirm that
    # callers serialize downloads per model.
    temp_file_path = f"{cache_file_path}.tmp"

    try:
        with open(temp_file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
                    if computed_md5 is not None:
                        computed_md5.update(chunk)

        if expected_md5_digest is not None and computed_md5 is not None:
            if expected_md5_digest != computed_md5.digest():
                os.remove(temp_file_path)
                # Name the file that failed verification instead of the
                # "(unknown)" placeholder the message previously carried.
                raise RoboflowAPIUnsuccessfulRequestError(
                    f"MD5 hash does not match for {filename}. "
                    f"Expected: {expected_md5_digest.hex()}, "
                    f"Got: {computed_md5.hexdigest()}"
                )

        # Publish the fully verified file at its final location.
        if ATOMIC_CACHE_WRITES_ENABLED:
            os.replace(temp_file_path, cache_file_path)
        else:
            if os.path.exists(cache_file_path):
                os.remove(cache_file_path)
            os.rename(temp_file_path, cache_file_path)

    except Exception:
        # Never leave a partial download behind; the MD5-mismatch branch may
        # already have removed the temp file, hence the existence check.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        raise
    finally:
        response.close()
1012+
1013+
8961014
def _add_params_to_url(url: str, params: List[Tuple[str, str]]) -> str:
8971015
if len(params) == 0:
8981016
return url

inference/models/transformers/transformers.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_roboflow_base_lora,
3030
get_roboflow_instant_model_data,
3131
get_roboflow_model_data,
32+
stream_url_to_cache,
3233
)
3334
from inference.core.utils.image_utils import load_image_rgb
3435

@@ -250,12 +251,9 @@ def download_model_artifacts_from_roboflow_api(
250251
filename = weights_url.split("?")[0].split("/")[-1]
251252
if filename.endswith(".npz"):
252253
continue
253-
model_weights_response = get_from_url(
254-
weights_url, json_response=False
255-
)
256-
save_bytes_in_cache(
257-
content=model_weights_response.content,
258-
file=filename,
254+
stream_url_to_cache(
255+
url=weights_url,
256+
filename=filename,
259257
model_id=self.endpoint,
260258
)
261259
if filename.endswith("tar.gz"):
@@ -382,12 +380,11 @@ def get_lora_base_from_roboflow(self, repo, revision) -> str:
382380
)
383381

384382
weights_url = api_data["weights"]["model"]
385-
model_weights_response = get_from_url(weights_url, json_response=False)
386383
filename = weights_url.split("?")[0].split("/")[-1]
387384
assert filename.endswith("tar.gz")
388-
save_bytes_in_cache(
389-
content=model_weights_response.content,
390-
file=filename,
385+
stream_url_to_cache(
386+
url=weights_url,
387+
filename=filename,
391388
model_id=base_dir,
392389
)
393390
tar_file_path = get_cache_file_path(filename, base_dir)

0 commit comments

Comments (0)