Skip to content

Commit 2c367f8

Browse files
committed
up
1 parent 66b6922 commit 2c367f8

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/models/test_modeling_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
XFormersAttnProcessor,
4848
)
4949
from diffusers.models.auto_model import AutoModel
50+
from diffusers.models.modeling_outputs import BaseOutput
5051
from diffusers.training_utils import EMAModel
5152
from diffusers.utils import (
5253
SAFE_WEIGHTS_INDEX_NAME,
@@ -109,7 +110,7 @@ def check_if_lora_correctly_set(model) -> bool:
109110

110111

111112
def normalize_output(out):
112-
out0 = out[0] if isinstance(out, tuple) else out
113+
out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out
113114
return torch.stack(out0) if isinstance(out0, list) else out0
114115

115116

0 commit comments

Comments
 (0)