
Commit cf34435

reinitialize the padding tokens to ones to prevent NaN problems.
1 parent e4702dc

File tree

2 files changed: 14 additions, 4 deletions


tests/lora/test_lora_layers_z_image.py

Lines changed: 9 additions & 4 deletions
@@ -37,10 +37,10 @@
 from .utils import PeftLoraLoaderMixinTests  # noqa: E402


-@unittest.skip(
-    "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
-    "and torch.empty padding tokens. LoRA functionality works correctly with real models."
-)
+# @unittest.skip(
+#     "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
+#     "and torch.empty padding tokens. LoRA functionality works correctly with real models."
+# )
 @require_peft_backend
 class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = ZImagePipeline
@@ -127,6 +127,11 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)

         transformer = self.transformer_cls(**self.transformer_kwargs)
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN values in our testing environment. Fixing them
+        # helps prevent that issue.
+        transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+        transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
         vae = self.vae_cls(**self.vae_kwargs)

         if scheduler_cls is None:

tests/pipelines/z_image/test_z_image.py

Lines changed: 5 additions & 0 deletions
@@ -101,6 +101,11 @@ def get_dummy_components(self):
             axes_dims=[8, 4, 4],
             axes_lens=[256, 32, 32],
         )
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN values in our testing environment. Fixing them
+        # helps prevent that issue.
+        transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+        transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))

         torch.manual_seed(0)
         vae = AutoencoderKL(
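
For context, here is a minimal, self-contained sketch of why this change helps. It does not use the real ZImage transformer; `PadTokenModule`, its dimension, and the surrounding test harness are hypothetical stand-ins that only mimic the `x_pad_token`/`cap_pad_token` pattern from the diff. The key point: `torch.empty` allocates uninitialized memory, so parameters created with it can start out with arbitrary (possibly NaN) values that differ between runs; overwriting them with ones makes the dummy components finite and deterministic.

import torch
from torch import nn


class PadTokenModule(nn.Module):
    """Hypothetical stand-in for a transformer that keeps learnable padding tokens."""

    def __init__(self, dim: int = 16):
        super().__init__()
        # `torch.empty` allocates without initializing, so these parameters hold
        # whatever bytes happen to be in memory -- possibly NaN or Inf.
        self.x_pad_token = nn.Parameter(torch.empty(1, dim))
        self.cap_pad_token = nn.Parameter(torch.empty(1, dim))


module = PadTokenModule()

# Uninitialized values vary from run to run, which is what made the tests flaky.
print("x_pad_token finite before fix:", torch.isfinite(module.x_pad_token).all().item())

# Same idea as the diff: overwrite the uninitialized parameters in place with ones
# so every test run sees identical, finite padding tokens. `no_grad` is used here
# because `copy_` on a leaf parameter that requires grad would otherwise error.
with torch.no_grad():
    module.x_pad_token.copy_(torch.ones_like(module.x_pad_token))
    module.cap_pad_token.copy_(torch.ones_like(module.cap_pad_token))

assert torch.isfinite(module.x_pad_token).all()
assert torch.isfinite(module.cap_pad_token).all()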
