Skip to content

Commit 1d9dfd2

Browse files
committed
update
1 parent 1f217a5 commit 1d9dfd2

File tree

1 file changed

+36
-43
lines changed

1 file changed

+36
-43
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,6 @@
4848
from diffusers.utils.import_utils import is_xformers_available
4949
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
5050

51-
from ..models.autoencoders.vae import (
52-
get_asym_autoencoder_kl_config,
53-
get_autoencoder_kl_config,
54-
get_autoencoder_tiny_config,
55-
get_consistency_vae_config,
56-
)
5751
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
5852
from ..models.unets.test_models_unet_2d_condition import (
5953
create_ip_adapter_faceid_state_dict,
@@ -70,7 +64,6 @@
7064
require_torch,
7165
require_torch_accelerator,
7266
require_transformers_version_greater,
73-
skip_mps,
7467
torch_device,
7568
)
7669

@@ -193,12 +186,12 @@ def test_fused_qkv_projections(self):
193186
and hasattr(component, "original_attn_processors")
194187
and component.original_attn_processors is not None
195188
):
196-
assert check_qkv_fusion_processors_exist(
197-
component
198-
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
199-
assert check_qkv_fusion_matches_attn_procs_length(
200-
component, component.original_attn_processors
201-
), "Something wrong with the attention processors concerning the fused QKV projections."
189+
assert check_qkv_fusion_processors_exist(component), (
190+
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
191+
)
192+
assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
193+
"Something wrong with the attention processors concerning the fused QKV projections."
194+
)
202195

203196
inputs = self.get_dummy_inputs(device)
204197
inputs["return_dict"] = False
@@ -211,15 +204,15 @@ def test_fused_qkv_projections(self):
211204
image_disabled = pipe(**inputs)[0]
212205
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
213206

214-
assert np.allclose(
215-
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
216-
), "Fusion of QKV projections shouldn't affect the outputs."
217-
assert np.allclose(
218-
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
219-
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
220-
assert np.allclose(
221-
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
222-
), "Original outputs should match when fused QKV projections are disabled."
207+
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
208+
"Fusion of QKV projections shouldn't affect the outputs."
209+
)
210+
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
211+
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
212+
)
213+
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
214+
"Original outputs should match when fused QKV projections are disabled."
215+
)
223216

224217

225218
class IPAdapterTesterMixin:
@@ -862,9 +855,9 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
862855

863856
for component in pipe_original.components.values():
864857
if hasattr(component, "attn_processors"):
865-
assert all(
866-
type(proc) == AttnProcessor for proc in component.attn_processors.values()
867-
), "`from_pipe` changed the attention processor in original pipeline."
858+
assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
859+
"`from_pipe` changed the attention processor in original pipeline."
860+
)
868861

869862
@require_accelerator
870863
@require_accelerate_version_greater("0.14.0")
@@ -2632,12 +2625,12 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2)
26322625
image_slice_pab_disabled = output.flatten()
26332626
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
26342627

2635-
assert np.allclose(
2636-
original_image_slice, image_slice_pab_enabled, atol=expected_atol
2637-
), "PAB outputs should not differ much in specified timestep range."
2638-
assert np.allclose(
2639-
original_image_slice, image_slice_pab_disabled, atol=1e-4
2640-
), "Outputs from normal inference and after disabling cache should not differ."
2628+
assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), (
2629+
"PAB outputs should not differ much in specified timestep range."
2630+
)
2631+
assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), (
2632+
"Outputs from normal inference and after disabling cache should not differ."
2633+
)
26412634

26422635

26432636
class FasterCacheTesterMixin:
@@ -2702,12 +2695,12 @@ def run_forward(pipe):
27022695
output = run_forward(pipe).flatten()
27032696
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
27042697

2705-
assert np.allclose(
2706-
original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
2707-
), "FasterCache outputs should not differ much in specified timestep range."
2708-
assert np.allclose(
2709-
original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
2710-
), "Outputs from normal inference and after disabling cache should not differ."
2698+
assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
2699+
"FasterCache outputs should not differ much in specified timestep range."
2700+
)
2701+
assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
2702+
"Outputs from normal inference and after disabling cache should not differ."
2703+
)
27112704

27122705
def test_faster_cache_state(self):
27132706
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
@@ -2842,12 +2835,12 @@ def run_forward(pipe):
28422835
output = run_forward(pipe).flatten()
28432836
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
28442837

2845-
assert np.allclose(
2846-
original_image_slice, image_slice_fbc_enabled, atol=expected_atol
2847-
), "FirstBlockCache outputs should not differ much."
2848-
assert np.allclose(
2849-
original_image_slice, image_slice_fbc_disabled, atol=1e-4
2850-
), "Outputs from normal inference and after disabling cache should not differ."
2838+
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
2839+
"FirstBlockCache outputs should not differ much."
2840+
)
2841+
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
2842+
"Outputs from normal inference and after disabling cache should not differ."
2843+
)
28512844

28522845

28532846
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.

0 commit comments

Comments
 (0)