|
15 | 15 | import sys |
16 | 16 | import unittest |
17 | 17 |
|
| 18 | +import numpy as np |
18 | 19 | import torch |
19 | 20 | from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model |
20 | 21 |
|
21 | 22 | from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel |
22 | 23 |
|
23 | | -from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend |
| 24 | +from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device |
24 | 25 |
|
25 | 26 |
|
26 | 27 | if is_peft_available(): |
|
29 | 30 |
|
30 | 31 | sys.path.append(".") |
31 | 32 |
|
32 | | -from .utils import PeftLoraLoaderMixinTests # noqa: E402 |
| 33 | +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 |
33 | 34 |
|
34 | 35 |
|
35 | | -# @unittest.skip( |
36 | | -# "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations " |
37 | | -# "and torch.empty padding tokens. LoRA functionality works correctly with real models." |
38 | | -# ) |
39 | 36 | @require_peft_backend |
40 | 37 | class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): |
41 | 38 | pipeline_class = ZImagePipeline |
@@ -163,34 +160,128 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No |
163 | 160 |
|
164 | 161 | return pipeline_components, text_lora_config, denoiser_lora_config |
165 | 162 |
|
166 | | - @unittest.skip("Not supported in Flux2.") |
| 163 | + def test_correct_lora_configs_with_different_ranks(self): |
| 164 | + components, _, denoiser_lora_config = self.get_dummy_components() |
| 165 | + pipe = self.pipeline_class(**components) |
| 166 | + pipe = pipe.to(torch_device) |
| 167 | + pipe.set_progress_bar_config(disable=None) |
| 168 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 169 | + |
| 170 | + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 171 | + |
| 172 | + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") |
| 173 | + |
| 174 | + lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 175 | + |
| 176 | + pipe.transformer.delete_adapters("adapter-1") |
| 177 | + |
| 178 | + denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer |
| 179 | + for name, _ in denoiser.named_modules(): |
| 180 | + if "to_k" in name and "attention" in name and "lora" not in name: |
| 181 | + module_name_to_rank_update = name.replace(".base_layer.", ".") |
| 182 | + break |
| 183 | + |
| 184 | + # change the rank_pattern |
| 185 | + updated_rank = denoiser_lora_config.r * 2 |
| 186 | + denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} |
| 187 | + |
| 188 | + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") |
| 189 | + updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern |
| 190 | + |
| 191 | + self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) |
| 192 | + |
| 193 | + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 194 | + self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) |
| 195 | + self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) |
| 196 | + |
| 197 | + pipe.transformer.delete_adapters("adapter-1") |
| 198 | + |
| 199 | + # similarly change the alpha_pattern |
| 200 | + updated_alpha = denoiser_lora_config.lora_alpha * 2 |
| 201 | + denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} |
| 202 | + |
| 203 | + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") |
| 204 | + self.assertTrue( |
| 205 | + pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} |
| 206 | + ) |
| 207 | + |
| 208 | + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 209 | + self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) |
| 210 | + self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) |
| 211 | + |
| 212 | + @skip_mps |
| 213 | + def test_lora_fuse_nan(self): |
| 214 | + components, _, denoiser_lora_config = self.get_dummy_components() |
| 215 | + pipe = self.pipeline_class(**components) |
| 216 | + pipe = pipe.to(torch_device) |
| 217 | + pipe.set_progress_bar_config(disable=None) |
| 218 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 219 | + |
| 220 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 221 | + denoiser.add_adapter(denoiser_lora_config, "adapter-1") |
| 222 | + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 223 | + |
| 224 | + # corrupt one LoRA weight with `inf` values |
| 225 | + with torch.no_grad(): |
| 226 | + possible_tower_names = ["noise_refiner"] |
| 227 | + filtered_tower_names = [ |
| 228 | + tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) |
| 229 | + ] |
| 230 | + for tower_name in filtered_tower_names: |
| 231 | + transformer_tower = getattr(pipe.transformer, tower_name) |
| 232 | + transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf") |
| 233 | + |
| 234 | + # with `safe_fusing=True` we should see an Error |
| 235 | + with self.assertRaises(ValueError): |
| 236 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) |
| 237 | + |
| 238 | + # without we should not see an error, but every image will be black |
| 239 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) |
| 240 | + out = pipe(**inputs)[0] |
| 241 | + |
| 242 | + self.assertTrue(np.isnan(out).all()) |
| 243 | + |
| 244 | + def test_lora_scale_kwargs_match_fusion(self): |
| 245 | + super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2) |
| 246 | + |
| 247 | + unittest.skip("Needs to be debugged.") |
| 248 | + |
| 249 | + def test_set_adapters_match_attention_kwargs(self): |
| 250 | + super().test_set_adapters_match_attention_kwargs() |
| 251 | + |
| 252 | + unittest.skip("Needs to be debugged.") |
| 253 | + |
| 254 | + def test_simple_inference_with_text_denoiser_lora_and_scale(self): |
| 255 | + super().test_simple_inference_with_text_denoiser_lora_and_scale() |
| 256 | + |
| 257 | + @unittest.skip("Not supported in ZImage.") |
167 | 258 | def test_simple_inference_with_text_denoiser_block_scale(self): |
168 | 259 | pass |
169 | 260 |
|
170 | | - @unittest.skip("Not supported in Flux2.") |
| 261 | + @unittest.skip("Not supported in ZImage.") |
171 | 262 | def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): |
172 | 263 | pass |
173 | 264 |
|
174 | | - @unittest.skip("Not supported in Flux2.") |
| 265 | + @unittest.skip("Not supported in ZImage.") |
175 | 266 | def test_modify_padding_mode(self): |
176 | 267 | pass |
177 | 268 |
|
178 | | - @unittest.skip("Text encoder LoRA is not supported in Flux2.") |
| 269 | + @unittest.skip("Text encoder LoRA is not supported in ZImage.") |
179 | 270 | def test_simple_inference_with_partial_text_lora(self): |
180 | 271 | pass |
181 | 272 |
|
182 | | - @unittest.skip("Text encoder LoRA is not supported in Flux2.") |
| 273 | + @unittest.skip("Text encoder LoRA is not supported in ZImage.") |
183 | 274 | def test_simple_inference_with_text_lora(self): |
184 | 275 | pass |
185 | 276 |
|
186 | | - @unittest.skip("Text encoder LoRA is not supported in Flux2.") |
| 277 | + @unittest.skip("Text encoder LoRA is not supported in ZImage.") |
187 | 278 | def test_simple_inference_with_text_lora_and_scale(self): |
188 | 279 | pass |
189 | 280 |
|
190 | | - @unittest.skip("Text encoder LoRA is not supported in Flux2.") |
| 281 | + @unittest.skip("Text encoder LoRA is not supported in ZImage.") |
191 | 282 | def test_simple_inference_with_text_lora_fused(self): |
192 | 283 | pass |
193 | 284 |
|
194 | | - @unittest.skip("Text encoder LoRA is not supported in Flux2.") |
| 285 | + @unittest.skip("Text encoder LoRA is not supported in ZImage.") |
195 | 286 | def test_simple_inference_with_text_lora_save_load(self): |
196 | 287 | pass |
0 commit comments