1212import numpy as np
1313from PIL import Image , PngImagePlugin
1414
15- from modules import shared , devices , sd_hijack , sd_models , images , sd_samplers , sd_hijack_checkpoint , errors , hashes
15+ from modules import shared , devices , sd_hijack , sd_models , images , sd_samplers , sd_hijack_checkpoint , errors , hashes , cache
1616import modules .textual_inversion .dataset
1717from modules .textual_inversion .learn_schedule import LearnRateScheduler
1818
@@ -116,6 +116,7 @@ def __init__(self):
116116 self .expected_shape = - 1
117117 self .embedding_dirs = {}
118118 self .previously_displayed_embeddings = ()
119+ self .image_embedding_cache = cache .cache ('image-embedding' )
119120
120121 def add_embedding_dir (self , path ):
121122 self .embedding_dirs [path ] = DirWithTextualInversionEmbeddings (path )
@@ -154,6 +155,31 @@ def get_expected_shape(self):
154155 vec = shared .sd_model .cond_stage_model .encode_embedding_init_text ("," , 1 )
155156 return vec .shape [1 ]
156157
158+ def read_embedding_from_image (self , path , name ):
159+ try :
160+ ondisk_mtime = os .path .getmtime (path )
161+
162+ if (cache_embedding := self .image_embedding_cache .get (path )) and ondisk_mtime == cache_embedding .get ('mtime' , 0 ):
163+ # cache will only be used if the file has not been modified time matches
164+ return cache_embedding .get ('data' , None ), cache_embedding .get ('name' , None )
165+
166+ embed_image = Image .open (path )
167+ if hasattr (embed_image , 'text' ) and 'sd-ti-embedding' in embed_image .text :
168+ data = embedding_from_b64 (embed_image .text ['sd-ti-embedding' ])
169+ name = data .get ('name' , name )
170+ elif data := extract_image_data_embed (embed_image ):
171+ name = data .get ('name' , name )
172+
173+ if data is None or shared .opts .textual_inversion_image_embedding_data_cache :
174+ # data of image embeddings only will be cached if the option textual_inversion_image_embedding_data_cache is enabled
175+ # results of images that are not embeddings will allways be cached to reduce unnecessary future disk reads
176+ self .image_embedding_cache [path ] = {'data' : data , 'name' : None if data is None else name , 'mtime' : ondisk_mtime }
177+
178+ return data , name
179+ except Exception :
180+ errors .report (f"Error loading embedding { path } " , exc_info = True )
181+ return None , None
182+
157183 def load_from_file (self , path , filename ):
158184 name , ext = os .path .splitext (filename )
159185 ext = ext .upper ()
@@ -163,17 +189,10 @@ def load_from_file(self, path, filename):
163189 if second_ext .upper () == '.PREVIEW' :
164190 return
165191
166- embed_image = Image .open (path )
167- if hasattr (embed_image , 'text' ) and 'sd-ti-embedding' in embed_image .text :
168- data = embedding_from_b64 (embed_image .text ['sd-ti-embedding' ])
169- name = data .get ('name' , name )
170- else :
171- data = extract_image_data_embed (embed_image )
172- if data :
173- name = data .get ('name' , name )
174- else :
175- # if data is None, means this is not an embedding, just a preview image
176- return
192+ data , name = self .read_embedding_from_image (path , name )
193+ if data is None :
194+ return
195+
177196 elif ext in ['.BIN' , '.PT' ]:
178197 data = torch .load (path , map_location = "cpu" )
179198 elif ext in ['.SAFETENSORS' ]:
@@ -191,7 +210,6 @@ def load_from_file(self, path, filename):
191210 else :
192211 print (f"Unable to load Textual inversion embedding due to data issue: '{ name } '." )
193212
194-
195213 def load_from_dir (self , embdir ):
196214 if not os .path .isdir (embdir .path ):
197215 return
0 commit comments