|
28 | 28 | AutoencoderKL, |
29 | 29 | UNet2DConditionModel, |
30 | 30 | ) |
| 31 | +from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading |
31 | 32 | from diffusers.utils import logging |
32 | 33 | from diffusers.utils.import_utils import is_peft_available |
33 | 34 |
|
@@ -2367,3 +2368,51 @@ def test_lora_loading_model_cpu_offload(self): |
2367 | 2368 |
|
2368 | 2369 | output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] |
2369 | 2370 | self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3)) |
| 2371 | + |
| 2372 | + @require_torch_accelerator |
| 2373 | + def test_lora_group_offloading_delete_adapters(self): |
| 2374 | + components, _, denoiser_lora_config = self.get_dummy_components() |
| 2375 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 2376 | + pipe = self.pipeline_class(**components) |
| 2377 | + pipe = pipe.to(torch_device) |
| 2378 | + pipe.set_progress_bar_config(disable=None) |
| 2379 | + |
| 2380 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 2381 | + denoiser.add_adapter(denoiser_lora_config) |
| 2382 | + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 2383 | + |
| 2384 | + try: |
| 2385 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 2386 | + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
| 2387 | + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
| 2388 | + self.pipeline_class.save_lora_weights( |
| 2389 | + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts |
| 2390 | + ) |
| 2391 | + |
| 2392 | + components, _, _ = self.get_dummy_components() |
| 2393 | + pipe = self.pipeline_class(**components) |
| 2394 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 2395 | + pipe.to(torch_device) |
| 2396 | + |
| 2397 | + # Enable Group Offloading (leaf_level for more granular testing) |
| 2398 | + apply_group_offloading( |
| 2399 | + denoiser, |
| 2400 | + onload_device=torch_device, |
| 2401 | + offload_device="cpu", |
| 2402 | + offload_type="leaf_level", |
| 2403 | + ) |
| 2404 | + |
| 2405 | + pipe.load_lora_weights(tmpdirname, adapter_name="default") |
| 2406 | + |
| 2407 | + out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2408 | + |
| 2409 | + # Delete the adapter |
| 2410 | + pipe.delete_adapters("default") |
| 2411 | + |
| 2412 | + out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2413 | + |
| 2414 | + self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3)) |
| 2415 | + finally: |
| 2416 | + # Clean up the hooks to prevent state leak |
| 2417 | + if hasattr(denoiser, "_diffusers_hook"): |
| 2418 | + denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True) |
0 commit comments