We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 66b6922 commit 2c367f8Copy full SHA for 2c367f8
tests/models/test_modeling_common.py
@@ -47,6 +47,7 @@
47
XFormersAttnProcessor,
48
)
49
from diffusers.models.auto_model import AutoModel
50
+from diffusers.models.modeling_outputs import BaseOutput
51
from diffusers.training_utils import EMAModel
52
from diffusers.utils import (
53
SAFE_WEIGHTS_INDEX_NAME,
@@ -109,7 +110,7 @@ def check_if_lora_correctly_set(model) -> bool:
109
110
111
112
def normalize_output(out):
- out0 = out[0] if isinstance(out, tuple) else out
113
+ out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out
114
return torch.stack(out0) if isinstance(out0, list) else out0
115
116
0 commit comments