@@ -202,42 +202,52 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
202202 # use vae checkpoint cache
203203 print (f"Loading VAE weights { vae_source } : cached { get_filename (vae_file )} " )
204204 store_base_vae (model )
205- _load_vae_dict (model , checkpoints_loaded [vae_file ])
205+ loaded = _load_vae_dict (model , checkpoints_loaded [vae_file ])
206206 else :
207207 assert os .path .isfile (vae_file ), f"VAE { vae_source } doesn't exist: { vae_file } "
208208 print (f"Loading VAE weights { vae_source } : { vae_file } " )
209209 store_base_vae (model )
210210
211211 vae_dict_1 = load_vae_dict (vae_file , map_location = shared .weight_load_location )
212- _load_vae_dict (model , vae_dict_1 )
212+ loaded = _load_vae_dict (model , vae_dict_1 )
213213
214- if cache_enabled :
214+ if loaded and cache_enabled :
215215 # cache newly loaded vae
216216 checkpoints_loaded [vae_file ] = vae_dict_1 .copy ()
217217
218218 # clean up cache if limit is reached
219- if cache_enabled :
219+ if loaded and cache_enabled :
220220 while len (checkpoints_loaded ) > shared .opts .sd_vae_checkpoint_cache + 1 : # we need to count the current model
221221 checkpoints_loaded .popitem (last = False ) # LRU
222222
223223 # If vae used is not in dict, update it
224224 # It will be removed on refresh though
225225 vae_opt = get_filename (vae_file )
226- if vae_opt not in vae_dict :
226+ if loaded and vae_opt not in vae_dict :
227227 vae_dict [vae_opt ] = vae_file
228228
229229 elif loaded_vae_file :
230230 restore_base_vae (model )
231+ loaded = True
231232
232- loaded_vae_file = vae_file
233+ if loaded :
234+ loaded_vae_file = vae_file
233235 model .base_vae = base_vae
234236 model .loaded_vae_file = loaded_vae_file
237+ return loaded
235238
236239
237240# don't call this from outside
238241def _load_vae_dict (model , vae_dict_1 ):
242+ conv_out = model .first_stage_model .state_dict ().get ("encoder.conv_out.weight" )
243+ # check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
244+ if conv_out .shape != vae_dict_1 ["encoder.conv_out.weight" ].shape :
245+ print ("Failed to load VAE. Size mismatched!" )
246+ return False
247+
239248 model .first_stage_model .load_state_dict (vae_dict_1 )
240249 model .first_stage_model .to (devices .dtype_vae )
250+ return True
241251
242252
243253def clear_loaded_vae ():
@@ -270,7 +280,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
270280
271281 sd_hijack .model_hijack .undo_hijack (sd_model )
272282
273- load_vae (sd_model , vae_file , vae_source )
283+ loaded = load_vae (sd_model , vae_file , vae_source )
274284
275285 sd_hijack .model_hijack .hijack (sd_model )
276286
@@ -279,5 +289,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
279289
280290 script_callbacks .model_loaded_callback (sd_model )
281291
282- print ("VAE weights loaded." )
292+ if loaded :
293+ print ("VAE weights loaded." )
283294 return sd_model
0 commit comments