|
52 | 52 | _convert_non_diffusers_lumina2_lora_to_diffusers, |
53 | 53 | _convert_non_diffusers_qwen_lora_to_diffusers, |
54 | 54 | _convert_non_diffusers_wan_lora_to_diffusers, |
| 55 | + _convert_non_diffusers_z_image_lora_to_diffusers, |
55 | 56 | _convert_xlabs_flux_lora_to_diffusers, |
56 | 57 | _maybe_map_sgm_blocks_to_diffusers, |
57 | 58 | ) |
@@ -5085,6 +5086,212 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): |
5085 | 5086 | super().unfuse_lora(components=components, **kwargs) |
5086 | 5087 |
|
5087 | 5088 |
|
class ZImageLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`ZImageTransformer2DModel`]. Specific to [`ZImagePipeline`].
    """

    # Only the transformer can take LoRA weights in this pipeline: there is no
    # text-encoder LoRA path, so every method below targets `transformer` alone.
    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Fetch a Z-Image LoRA state dict from the Hub, a local path, or an in-memory
        dict, converting recognized non-diffusers key layouts to the diffusers format.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        # When True, the return value becomes a (state_dict, metadata) tuple.
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        # If the caller did not explicitly request safetensors, try safetensors
        # first but allow falling back to a pickled checkpoint.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        # DoRA checkpoints are not supported: strip the `dora_scale` tensors and
        # warn instead of failing outright.
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # Key-layout heuristics for non-diffusers checkpoints: `.alpha` scaling
        # tensors, `lora_unet_`/`diffusion_model.` prefixes, or peft "default."
        # adapter names. Any match routes the dict through the Z-Image converter.
        has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
        has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
        has_default = any("default." in k for k in state_dict)
        if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
            state_dict = _convert_non_diffusers_z_image_lora_to_diffusers(state_dict)

        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        Load the LoRA weights into the pipeline's transformer.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        # Force metadata through so it can be forwarded to the adapter loader below.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        # Every key of a valid (possibly converted) LoRA checkpoint contains "lora".
        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            # Fall back to the attribute named by `transformer_name` for pipelines
            # that do not expose a plain `transformer` attribute.
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ZImageTransformer2DModel
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        Load the transformer-targeted entries of `state_dict` into `transformer`
        as a peft LoRA adapter.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        r"""
        Serialize the transformer LoRA layers (and optional adapter metadata) to disk.

        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            # NOTE(review): the message mentions `text_encoder_lora_layers`, which this
            # mixin does not accept — kept verbatim for consistency with the copied source.
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuse the loaded LoRA weights into the base model weights of `components`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverse [`~ZImageLoraLoaderMixin.fuse_lora`] for `components`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
| 5294 | + |
5088 | 5295 | class Flux2LoraLoaderMixin(LoraBaseMixin): |
5089 | 5296 | r""" |
5090 | 5297 | Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`]. |
|
0 commit comments