Skip to content

Commit 015e20a

Browse files
committed
fix load_vae() to check size mismatch
1 parent e063757 commit 015e20a

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

modules/sd_vae.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,42 +202,52 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
202202
# use vae checkpoint cache
203203
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
204204
store_base_vae(model)
205-
_load_vae_dict(model, checkpoints_loaded[vae_file])
205+
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
206206
else:
207207
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
208208
print(f"Loading VAE weights {vae_source}: {vae_file}")
209209
store_base_vae(model)
210210

211211
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
212-
_load_vae_dict(model, vae_dict_1)
212+
loaded = _load_vae_dict(model, vae_dict_1)
213213

214-
if cache_enabled:
214+
if loaded and cache_enabled:
215215
# cache newly loaded vae
216216
checkpoints_loaded[vae_file] = vae_dict_1.copy()
217217

218218
# clean up cache if limit is reached
219-
if cache_enabled:
219+
if loaded and cache_enabled:
220220
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
221221
checkpoints_loaded.popitem(last=False) # LRU
222222

223223
# If vae used is not in dict, update it
224224
# It will be removed on refresh though
225225
vae_opt = get_filename(vae_file)
226-
if vae_opt not in vae_dict:
226+
if loaded and vae_opt not in vae_dict:
227227
vae_dict[vae_opt] = vae_file
228228

229229
elif loaded_vae_file:
230230
restore_base_vae(model)
231+
loaded = True
231232

232-
loaded_vae_file = vae_file
233+
if loaded:
234+
loaded_vae_file = vae_file
233235
model.base_vae = base_vae
234236
model.loaded_vae_file = loaded_vae_file
237+
return loaded
235238

236239

237240
# don't call this from outside
238241
def _load_vae_dict(model, vae_dict_1):
242+
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
243+
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
244+
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
245+
print("Failed to load VAE. Size mismatched!")
246+
return False
247+
239248
model.first_stage_model.load_state_dict(vae_dict_1)
240249
model.first_stage_model.to(devices.dtype_vae)
250+
return True
241251

242252

243253
def clear_loaded_vae():
@@ -270,7 +280,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
270280

271281
sd_hijack.model_hijack.undo_hijack(sd_model)
272282

273-
load_vae(sd_model, vae_file, vae_source)
283+
loaded = load_vae(sd_model, vae_file, vae_source)
274284

275285
sd_hijack.model_hijack.hijack(sd_model)
276286

@@ -279,5 +289,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
279289

280290
script_callbacks.model_loaded_callback(sd_model)
281291

282-
print("VAE weights loaded.")
292+
if loaded:
293+
print("VAE weights loaded.")
283294
return sd_model

0 commit comments

Comments
 (0)