Skip to content

Commit 5613ff0

Browse files
committed
up
1 parent 4ca68f2 commit 5613ff0

File tree

2 files changed

+107
-17
lines changed

2 files changed

+107
-17
lines changed

tests/lora/test_lora_layers_z_image.py

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import sys
1616
import unittest
1717

18+
import numpy as np
1819
import torch
1920
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
2021

2122
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
2223

23-
from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
24+
from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device
2425

2526

2627
if is_peft_available():
@@ -29,13 +30,9 @@
2930

3031
sys.path.append(".")
3132

32-
from .utils import PeftLoraLoaderMixinTests # noqa: E402
33+
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
3334

3435

35-
# @unittest.skip(
36-
# "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
37-
# "and torch.empty padding tokens. LoRA functionality works correctly with real models."
38-
# )
3936
@require_peft_backend
4037
class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
4138
pipeline_class = ZImagePipeline
@@ -163,34 +160,128 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No
163160

164161
return pipeline_components, text_lora_config, denoiser_lora_config
165162

166-
@unittest.skip("Not supported in Flux2.")
163+
def test_correct_lora_configs_with_different_ranks(self):
164+
components, _, denoiser_lora_config = self.get_dummy_components()
165+
pipe = self.pipeline_class(**components)
166+
pipe = pipe.to(torch_device)
167+
pipe.set_progress_bar_config(disable=None)
168+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
169+
170+
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
171+
172+
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
173+
174+
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
175+
176+
pipe.transformer.delete_adapters("adapter-1")
177+
178+
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
179+
for name, _ in denoiser.named_modules():
180+
if "to_k" in name and "attention" in name and "lora" not in name:
181+
module_name_to_rank_update = name.replace(".base_layer.", ".")
182+
break
183+
184+
# change the rank_pattern
185+
updated_rank = denoiser_lora_config.r * 2
186+
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
187+
188+
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
189+
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
190+
191+
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
192+
193+
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
194+
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
195+
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
196+
197+
pipe.transformer.delete_adapters("adapter-1")
198+
199+
# similarly change the alpha_pattern
200+
updated_alpha = denoiser_lora_config.lora_alpha * 2
201+
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
202+
203+
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
204+
self.assertTrue(
205+
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
206+
)
207+
208+
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
209+
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
210+
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
211+
212+
@skip_mps
213+
def test_lora_fuse_nan(self):
214+
components, _, denoiser_lora_config = self.get_dummy_components()
215+
pipe = self.pipeline_class(**components)
216+
pipe = pipe.to(torch_device)
217+
pipe.set_progress_bar_config(disable=None)
218+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
219+
220+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
221+
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
222+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
223+
224+
# corrupt one LoRA weight with `inf` values
225+
with torch.no_grad():
226+
possible_tower_names = ["noise_refiner"]
227+
filtered_tower_names = [
228+
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
229+
]
230+
for tower_name in filtered_tower_names:
231+
transformer_tower = getattr(pipe.transformer, tower_name)
232+
transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
233+
234+
# with `safe_fusing=True` we should see an Error
235+
with self.assertRaises(ValueError):
236+
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
237+
238+
# without we should not see an error, but every image will be black
239+
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
240+
out = pipe(**inputs)[0]
241+
242+
self.assertTrue(np.isnan(out).all())
243+
244+
def test_lora_scale_kwargs_match_fusion(self):
245+
super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
246+
247+
unittest.skip("Needs to be debugged.")
248+
249+
def test_set_adapters_match_attention_kwargs(self):
250+
super().test_set_adapters_match_attention_kwargs()
251+
252+
unittest.skip("Needs to be debugged.")
253+
254+
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
255+
super().test_simple_inference_with_text_denoiser_lora_and_scale()
256+
257+
@unittest.skip("Not supported in ZImage.")
167258
def test_simple_inference_with_text_denoiser_block_scale(self):
168259
pass
169260

170-
@unittest.skip("Not supported in Flux2.")
261+
@unittest.skip("Not supported in ZImage.")
171262
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
172263
pass
173264

174-
@unittest.skip("Not supported in Flux2.")
265+
@unittest.skip("Not supported in ZImage.")
175266
def test_modify_padding_mode(self):
176267
pass
177268

178-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
269+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
179270
def test_simple_inference_with_partial_text_lora(self):
180271
pass
181272

182-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
273+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
183274
def test_simple_inference_with_text_lora(self):
184275
pass
185276

186-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
277+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
187278
def test_simple_inference_with_text_lora_and_scale(self):
188279
pass
189280

190-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
281+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
191282
def test_simple_inference_with_text_lora_fused(self):
192283
pass
193284

194-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
285+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
195286
def test_simple_inference_with_text_lora_save_load(self):
196287
pass

tests/pipelines/z_image/test_z_image.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
2424

25-
from ...testing_utils import is_flaky, torch_device
25+
from ...testing_utils import torch_device
2626
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
2727
from ..test_pipelines_common import PipelineTesterMixin, to_np
2828

@@ -170,7 +170,6 @@ def get_dummy_inputs(self, device, seed=0):
170170

171171
return inputs
172172

173-
@is_flaky(max_attempts=10)
174173
def test_inference(self):
175174
device = "cpu"
176175

@@ -185,7 +184,7 @@ def test_inference(self):
185184
self.assertEqual(generated_image.shape, (3, 32, 32))
186185

187186
# fmt: off
188-
expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
187+
expected_slice = torch.tensor([0.4622, 0.4532, 0.4714, 0.5087, 0.5371, 0.5405, 0.4492, 0.4479, 0.2984, 0.2783, 0.5409, 0.6577, 0.3952, 0.5524, 0.5262, 0.453])
189188
# fmt: on
190189

191190
generated_slice = generated_image.flatten()

0 commit comments

Comments
 (0)