@@ -211,3 +211,80 @@ def open_folder(path):
211211 subprocess .Popen (["explorer.exe" , subprocess .check_output (["wslpath" , "-w" , path ])])
212212 else :
213213 subprocess .Popen (["xdg-open" , path ])
214+
215+
216+ def load_file_from_url (
217+ url : str ,
218+ * ,
219+ model_dir : str ,
220+ progress : bool = True ,
221+ file_name : str | None = None ,
222+ hash_prefix : str | None = None ,
223+ re_download : bool = False ,
224+ ) -> str :
225+ """Download a file from `url` into `model_dir`, using the file present if possible.
226+ Returns the path to the downloaded file.
227+
228+ file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
229+ file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
230+ hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
231+ if the hash does not match, the temporary file is deleted and a ValueError is raised.
232+ re_download: forcibly re-download the file even if it already exists.
233+ """
234+ from urllib .parse import urlparse
235+ import requests
236+ try :
237+ from tqdm import tqdm
238+ except ImportError :
239+ class tqdm :
240+ def __init__ (self , * args , ** kwargs ):
241+ pass
242+
243+ def update (self , n = 1 , * args , ** kwargs ):
244+ pass
245+
246+ def __enter__ (self ):
247+ return self
248+
249+ def __exit__ (self , exc_type , exc_val , exc_tb ):
250+ pass
251+
252+ if not file_name :
253+ parts = urlparse (url )
254+ file_name = os .path .basename (parts .path )
255+
256+ cached_file = os .path .abspath (os .path .join (model_dir , file_name ))
257+
258+ if re_download or not os .path .exists (cached_file ):
259+ os .makedirs (model_dir , exist_ok = True )
260+ temp_file = os .path .join (model_dir , f"{ file_name } .tmp" )
261+ print (f'\n Downloading: "{ url } " to { cached_file } ' )
262+ response = requests .get (url , stream = True )
263+ response .raise_for_status ()
264+ total_size = int (response .headers .get ('content-length' , 0 ))
265+ with tqdm (total = total_size , unit = 'B' , unit_scale = True , desc = file_name , disable = not progress ) as progress_bar :
266+ with open (temp_file , 'wb' ) as file :
267+ for chunk in response .iter_content (chunk_size = 1024 ):
268+ if chunk :
269+ file .write (chunk )
270+ progress_bar .update (len (chunk ))
271+
272+ if hash_prefix and not compare_sha256 (temp_file , hash_prefix ):
273+ print (f"Hash mismatch for { temp_file } . Deleting the temporary file." )
274+ os .remove (temp_file )
275+ raise ValueError (f"File hash does not match the expected hash prefix { hash_prefix } !" )
276+
277+ os .rename (temp_file , cached_file )
278+ return cached_file
279+
280+
281+ def compare_sha256 (file_path : str , hash_prefix : str ) -> bool :
282+ """Check if the SHA256 hash of the file matches the given prefix."""
283+ import hashlib
284+ hash_sha256 = hashlib .sha256 ()
285+ blksize = 1024 * 1024
286+
287+ with open (file_path , "rb" ) as f :
288+ for chunk in iter (lambda : f .read (blksize ), b"" ):
289+ hash_sha256 .update (chunk )
290+ return hash_sha256 .hexdigest ().startswith (hash_prefix .strip ().lower ())
0 commit comments