@@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
197197
198198 cache_enabled = shared .opts .sd_vae_checkpoint_cache > 0
199199
200+ loaded = False
200201 if vae_file :
201202 if cache_enabled and vae_file in checkpoints_loaded :
202203 # use vae checkpoint cache
203204 print (f"Loading VAE weights { vae_source } : cached { get_filename (vae_file )} " )
204205 store_base_vae (model )
205- _load_vae_dict (model , checkpoints_loaded [vae_file ])
206+ loaded = _load_vae_dict (model , checkpoints_loaded [vae_file ])
206207 else :
207208 assert os .path .isfile (vae_file ), f"VAE { vae_source } doesn't exist: { vae_file } "
208209 print (f"Loading VAE weights { vae_source } : { vae_file } " )
209210 store_base_vae (model )
210211
211212 vae_dict_1 = load_vae_dict (vae_file , map_location = shared .weight_load_location )
212- _load_vae_dict (model , vae_dict_1 )
213+ loaded = _load_vae_dict (model , vae_dict_1 )
213214
214- if cache_enabled :
215+ if loaded and cache_enabled :
215216 # cache newly loaded vae
216217 checkpoints_loaded [vae_file ] = vae_dict_1 .copy ()
217218
218219 # clean up cache if limit is reached
219- if cache_enabled :
220+ if loaded and cache_enabled :
220221 while len (checkpoints_loaded ) > shared .opts .sd_vae_checkpoint_cache + 1 : # we need to count the current model
221222 checkpoints_loaded .popitem (last = False ) # LRU
222223
223224 # If vae used is not in dict, update it
224225 # It will be removed on refresh though
225226 vae_opt = get_filename (vae_file )
226- if vae_opt not in vae_dict :
227+ if loaded and vae_opt not in vae_dict :
227228 vae_dict [vae_opt ] = vae_file
228229
229230 elif loaded_vae_file :
230231 restore_base_vae (model )
232+ loaded = True
231233
232- loaded_vae_file = vae_file
234+ if loaded :
235+ loaded_vae_file = vae_file
233236 model .base_vae = base_vae
234237 model .loaded_vae_file = loaded_vae_file
238+ return loaded
235239
236240
237241# don't call this from outside
238242def _load_vae_dict (model , vae_dict_1 ):
243+ conv_out = model .first_stage_model .state_dict ().get ("encoder.conv_out.weight" )
244+ # check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
245+ if conv_out .shape != vae_dict_1 ["encoder.conv_out.weight" ].shape :
246+ print ("Failed to load VAE. Size mismatched!" )
247+ return False
248+
239249 model .first_stage_model .load_state_dict (vae_dict_1 )
240250 model .first_stage_model .to (devices .dtype_vae )
251+ return True
241252
242253
243254def clear_loaded_vae ():
@@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
270281
271282 sd_hijack .model_hijack .undo_hijack (sd_model )
272283
273- load_vae (sd_model , vae_file , vae_source )
284+ loaded = load_vae (sd_model , vae_file , vae_source )
274285
275286 sd_hijack .model_hijack .hijack (sd_model )
276287
@@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
279290
280291 script_callbacks .model_loaded_callback (sd_model )
281292
282- print ("VAE weights loaded." )
293+ if loaded :
294+ print ("VAE weights loaded." )
283295 return sd_model
0 commit comments