Skip to content

Commit c630e63

Browse files
committed
image embedding data cache
1 parent 4823909 commit c630e63

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

modules/shared_options.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@
291291
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
292292
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
293293
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
294+
"textual_inversion_image_embedding_data_cache": OptionInfo(False, 'Cache the data of image embeddings').info('potentially increase TI load time at the cost some disk space'),
294295
}))
295296

296297
options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {

modules/textual_inversion/textual_inversion.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313
from PIL import Image, PngImagePlugin
1414

15-
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
15+
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes, cache
1616
import modules.textual_inversion.dataset
1717
from modules.textual_inversion.learn_schedule import LearnRateScheduler
1818

@@ -116,6 +116,7 @@ def __init__(self):
116116
self.expected_shape = -1
117117
self.embedding_dirs = {}
118118
self.previously_displayed_embeddings = ()
119+
self.image_embedding_cache = cache.cache('image-embedding')
119120

120121
def add_embedding_dir(self, path):
121122
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
@@ -154,6 +155,31 @@ def get_expected_shape(self):
154155
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
155156
return vec.shape[1]
156157

158+
def read_embedding_from_image(self, path, name):
159+
try:
160+
ondisk_mtime = os.path.getmtime(path)
161+
162+
if (cache_embedding := self.image_embedding_cache.get(path)) and ondisk_mtime == cache_embedding.get('mtime', 0):
163+
# cache will only be used if the file has not been modified time matches
164+
return cache_embedding.get('data', None), cache_embedding.get('name', None)
165+
166+
embed_image = Image.open(path)
167+
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
168+
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
169+
name = data.get('name', name)
170+
elif data := extract_image_data_embed(embed_image):
171+
name = data.get('name', name)
172+
173+
if data is None or shared.opts.textual_inversion_image_embedding_data_cache:
174+
# data of image embeddings only will be cached if the option textual_inversion_image_embedding_data_cache is enabled
175+
# results of images that are not embeddings will allways be cached to reduce unnecessary future disk reads
176+
self.image_embedding_cache[path] = {'data': data, 'name': None if data is None else name, 'mtime': ondisk_mtime}
177+
178+
return data, name
179+
except Exception:
180+
errors.report(f"Error loading embedding {path}", exc_info=True)
181+
return None, None
182+
157183
def load_from_file(self, path, filename):
158184
name, ext = os.path.splitext(filename)
159185
ext = ext.upper()
@@ -163,17 +189,10 @@ def load_from_file(self, path, filename):
163189
if second_ext.upper() == '.PREVIEW':
164190
return
165191

166-
embed_image = Image.open(path)
167-
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
168-
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
169-
name = data.get('name', name)
170-
else:
171-
data = extract_image_data_embed(embed_image)
172-
if data:
173-
name = data.get('name', name)
174-
else:
175-
# if data is None, means this is not an embedding, just a preview image
176-
return
192+
data, name = self.read_embedding_from_image(path, name)
193+
if data is None:
194+
return
195+
177196
elif ext in ['.BIN', '.PT']:
178197
data = torch.load(path, map_location="cpu")
179198
elif ext in ['.SAFETENSORS']:
@@ -191,7 +210,6 @@ def load_from_file(self, path, filename):
191210
else:
192211
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
193212

194-
195213
def load_from_dir(self, embdir):
196214
if not os.path.isdir(embdir.path):
197215
return

0 commit comments

Comments
 (0)