Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions modules/hashes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib
import os.path

from modules import shared
from modules import shared, errors
import modules.cache

dump_cache = modules.cache.dump_cache
Expand Down Expand Up @@ -32,7 +32,7 @@ def sha256_from_cache(filename, title, use_addnet_hash=False):
cached_sha256 = hashes[title].get("sha256", None)
cached_mtime = hashes[title].get("mtime", 0)

if ondisk_mtime > cached_mtime or cached_sha256 is None:
if ondisk_mtime != cached_mtime or cached_sha256 is None:
return None

return cached_sha256
Expand Down Expand Up @@ -82,3 +82,31 @@ def addnet_hash_safetensors(b):

return hash_sha256.hexdigest()


def partial_hash_from_cache(filename, ignore_cache=False):
"""old hash that only looks at a small part of the file and is prone to collisions
kept for compatibility, don't use this for new things
"""
try:
filename = str(filename)
mtime = os.path.getmtime(filename)
hashes = cache('partial-hash')
cache_entry = hashes.get(filename, {})
cache_mtime = cache_entry.get("mtime", 0)
cache_hash = cache_entry.get("hash", None)
if mtime == cache_mtime and cache_hash and not ignore_cache:
return cache_hash

with open(filename, 'rb') as file:
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
partial_hash = m.hexdigest()
hashes[filename] = {'mtime': mtime, 'hash': partial_hash}
return partial_hash[0:8]

except FileNotFoundError:
pass
except Exception:
errors.report(f'Error calculating partial hash for {filename}', exc_info=True)
return 'NOFILE'
18 changes: 2 additions & 16 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ldm.modules.midas as midas

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.hashes import partial_hash_from_cache as model_hash # noqa: F401 for backwards compatibility
from modules.timer import Timer
from modules.shared import opts
import tomesd
Expand Down Expand Up @@ -87,7 +88,7 @@ def read_metadata():
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
self.hash = hashes.partial_hash_from_cache(filename)

self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None
Expand Down Expand Up @@ -200,21 +201,6 @@ def get_closet_checkpoint_match(search_string):
return None


def model_hash(filename):
"""old hash that only looks at a small part of the file and is prone to collisions"""

try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()

file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'


def select_checkpoint():
"""Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
Expand Down