Skip to content

Commit de9497a

Browse files
committed
fix load_vae() to check size mismatch
1 parent e063757 commit de9497a

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

modules/sd_vae.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
197197

198198
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
199199

200+
loaded = False
200201
if vae_file:
201202
if cache_enabled and vae_file in checkpoints_loaded:
202203
# use vae checkpoint cache
203204
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
204205
store_base_vae(model)
205-
_load_vae_dict(model, checkpoints_loaded[vae_file])
206+
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
206207
else:
207208
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
208209
print(f"Loading VAE weights {vae_source}: {vae_file}")
209210
store_base_vae(model)
210211

211212
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
212-
_load_vae_dict(model, vae_dict_1)
213+
loaded = _load_vae_dict(model, vae_dict_1)
213214

214-
if cache_enabled:
215+
if loaded and cache_enabled:
215216
# cache newly loaded vae
216217
checkpoints_loaded[vae_file] = vae_dict_1.copy()
217218

218219
# clean up cache if limit is reached
219-
if cache_enabled:
220+
if loaded and cache_enabled:
220221
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
221222
checkpoints_loaded.popitem(last=False) # LRU
222223

223224
# If vae used is not in dict, update it
224225
# It will be removed on refresh though
225226
vae_opt = get_filename(vae_file)
226-
if vae_opt not in vae_dict:
227+
if loaded and vae_opt not in vae_dict:
227228
vae_dict[vae_opt] = vae_file
228229

229230
elif loaded_vae_file:
230231
restore_base_vae(model)
232+
loaded = True
231233

232-
loaded_vae_file = vae_file
234+
if loaded:
235+
loaded_vae_file = vae_file
233236
model.base_vae = base_vae
234237
model.loaded_vae_file = loaded_vae_file
238+
return loaded
235239

236240

237241
# don't call this from outside
238242
def _load_vae_dict(model, vae_dict_1):
243+
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
244+
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
245+
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
246+
print("Failed to load VAE. Size mismatched!")
247+
return False
248+
239249
model.first_stage_model.load_state_dict(vae_dict_1)
240250
model.first_stage_model.to(devices.dtype_vae)
251+
return True
241252

242253

243254
def clear_loaded_vae():
@@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
270281

271282
sd_hijack.model_hijack.undo_hijack(sd_model)
272283

273-
load_vae(sd_model, vae_file, vae_source)
284+
loaded = load_vae(sd_model, vae_file, vae_source)
274285

275286
sd_hijack.model_hijack.hijack(sd_model)
276287

@@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
279290

280291
script_callbacks.model_loaded_callback(sd_model)
281292

282-
print("VAE weights loaded.")
293+
if loaded:
294+
print("VAE weights loaded.")
283295
return sd_model

0 commit comments

Comments
 (0)