Skip to content

Commit 8f45d13

Browse files
committed
restore org_dtype != compute dtype case
1 parent cb1dabd commit 8f45d13

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

extensions-builtin/Lora/networks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ def restore_weights_backup(obj, field, weight):
391391
setattr(obj, field, None)
392392
return
393393

394-
getattr(obj, field).copy_(weight)
394+
old_weight = getattr(obj, field)
395+
old_weight.copy_(weight.to(dtype=old_weight.dtype))
395396

396397

397398
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):

0 commit comments

Comments
 (0)