Skip to content

Commit b010a8c

Browse files
[Modular] Add single file support to Modular (#12383)
* update * update * update * update * Apply style fixes * update * update * update * update * update --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 1b91856 commit b010a8c

File tree

11 files changed

+101
-59
lines changed

11 files changed

+101
-59
lines changed

docs/source/en/modular_diffusers/guiders.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
159159
```py
160160
guider_spec = t2i_pipeline.get_component_spec("guider")
161161
guider_spec.default_creation_method="from_pretrained"
162-
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
162+
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
163163
guider_spec.subfolder="pag_guider"
164164
pag_guider = guider_spec.load()
165165
t2i_pipeline.update_components(guider=pag_guider)

docs/source/en/modular_diffusers/modular_pipeline.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,14 @@ unet_spec
313313
ComponentSpec(
314314
name='unet',
315315
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
316-
repo='RunDiffusion/Juggernaut-XL-v9',
316+
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
317317
subfolder='unet',
318318
variant='fp16',
319319
default_creation_method='from_pretrained'
320320
)
321321

322322
# modify to load from a different repository
323-
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
323+
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
324324

325325
# load component with modified spec
326326
unet = unet_spec.load(torch_dtype=torch.float16)

docs/source/zh/modular_diffusers/guiders.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
157157
```py
158158
guider_spec = t2i_pipeline.get_component_spec("guider")
159159
guider_spec.default_creation_method="from_pretrained"
160-
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
160+
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
161161
guider_spec.subfolder="pag_guider"
162162
pag_guider = guider_spec.load()
163163
t2i_pipeline.update_components(guider=pag_guider)

docs/source/zh/modular_diffusers/modular_pipeline.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,14 @@ unet_spec
313313
ComponentSpec(
314314
name='unet',
315315
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
316-
repo='RunDiffusion/Juggernaut-XL-v9',
316+
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
317317
subfolder='unet',
318318
variant='fp16',
319319
default_creation_method='from_pretrained'
320320
)
321321

322322
# 修改以从不同的仓库加载
323-
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
323+
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
324324

325325
# 使用修改后的规范加载组件
326326
unet = unet_spec.load(torch_dtype=torch.float16)

src/diffusers/loaders/single_file_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,14 @@ def is_valid_url(url):
389389
return False
390390

391391

392+
def _is_single_file_path_or_url(pretrained_model_name_or_path):
393+
if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path):
394+
return False
395+
396+
repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
397+
return bool(repo_id and weight_name)
398+
399+
392400
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
393401
if not is_valid_url(pretrained_model_name_or_path):
394402
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
@@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
400408
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
401409
match = re.match(pattern, pretrained_model_name_or_path)
402410
if not match:
403-
logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
404411
return repo_id, weights_name
405412

406413
repo_id = f"{match.group(1)}/{match.group(2)}"

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def init_pipeline(
360360
collection: Optional[str] = None,
361361
) -> "ModularPipeline":
362362
"""
363-
create a ModularPipeline, optionally accept modular_repo to load from hub.
363+
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
364364
"""
365365
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
366366
diffusers_module = importlib.import_module("diffusers")
@@ -1645,8 +1645,8 @@ def from_pretrained(
16451645
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
16461646
Path to a pretrained pipeline configuration. It will first try to load config from
16471647
`modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
1648-
non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
1649-
during initialization.
1648+
non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
1649+
will be set to None during initialization.
16501650
trust_remote_code (`bool`, optional):
16511651
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
16521652
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
@@ -1807,7 +1807,7 @@ def register_components(self, **kwargs):
18071807
library, class_name = None, None
18081808

18091809
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
1810-
# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
1810+
# e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
18111811
# "type_hint": ("diffusers", "UNet2DConditionModel"),
18121812
# "subfolder": "unet",
18131813
# "variant": None,
@@ -2111,8 +2111,10 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwarg
21112111
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
21122112
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
21132113
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
2114-
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
2115-
`variant`, `revision`, etc.
2114+
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
2115+
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
2116+
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
2117+
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
21162118
"""
21172119

21182120
if names is None:
@@ -2378,10 +2380,10 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
23782380
- "type_hint": Tuple[str, str]
23792381
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
23802382
- All loading fields defined by `component_spec.loading_fields()`, typically:
2381-
- "repo": Optional[str]
2382-
The model repository (e.g., "stabilityai/stable-diffusion-xl").
2383+
- "pretrained_model_name_or_path": Optional[str]
2384+
The model pretrained_model_name_or_pathsitory (e.g., "stabilityai/stable-diffusion-xl").
23832385
- "subfolder": Optional[str]
2384-
A subfolder within the repo where this component lives.
2386+
A subfolder within the pretrained_model_name_or_path where this component lives.
23852387
- "variant": Optional[str]
23862388
An optional variant identifier for the model.
23872389
- "revision": Optional[str]
@@ -2398,11 +2400,13 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
23982400
Example:
23992401
>>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
24002402
UNet2DConditionModel >>> spec = ComponentSpec(
2401-
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
2402-
subfolder="subfolder", ... variant=None, ... revision=None, ...
2403-
default_creation_method="from_pretrained",
2403+
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
2404+
pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
2405+
variant=None, ... revision=None, ... default_creation_method="from_pretrained",
24042406
... ) >>> ModularPipeline._component_spec_to_dict(spec) {
2405-
"type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
2407+
"type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
2408+
"subfolder": "subfolder", "variant": None, "revision": None, "type_hint": ("diffusers",
2409+
"UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder",
24062410
"variant": None, "revision": None,
24072411
}
24082412
"""
@@ -2432,10 +2436,10 @@ def _dict_to_component_spec(
24322436
- "type_hint": Tuple[str, str]
24332437
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
24342438
- All loading fields defined by `component_spec.loading_fields()`, typically:
2435-
- "repo": Optional[str]
2439+
- "pretrained_model_name_or_path": Optional[str]
24362440
The model repository (e.g., "stabilityai/stable-diffusion-xl").
24372441
- "subfolder": Optional[str]
2438-
A subfolder within the repo where this component lives.
2442+
A subfolder within the pretrained_model_name_or_path where this component lives.
24392443
- "variant": Optional[str]
24402444
An optional variant identifier for the model.
24412445
- "revision": Optional[str]
@@ -2452,11 +2456,20 @@ def _dict_to_component_spec(
24522456
ComponentSpec: A reconstructed ComponentSpec object.
24532457
24542458
Example:
2455-
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
2456-
"stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
2457-
} >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
2458-
name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
2459-
subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
2459+
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
2460+
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
2461+
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
2462+
ComponentSpec(
2463+
name="unet", type_hint=UNet2DConditionModel, config=None,
2464+
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
2465+
revision=None, default_creation_method="from_pretrained"
2466+
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
2467+
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
2468+
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
2469+
ComponentSpec(
2470+
name="unet", type_hint=UNet2DConditionModel, config=None,
2471+
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
2472+
revision=None, default_creation_method="from_pretrained"
24602473
)
24612474
"""
24622475
# make a shallow copy so we can pop() safely

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222

2323
from ..configuration_utils import ConfigMixin, FrozenDict
24+
from ..loaders.single_file_utils import _is_single_file_path_or_url
2425
from ..utils import is_torch_available, logging
2526

2627

@@ -80,24 +81,31 @@ class ComponentSpec:
8081
type_hint: Type of the component (e.g. UNet2DConditionModel)
8182
description: Optional description of the component
8283
config: Optional config dict for __init__ creation
83-
repo: Optional repo path for from_pretrained creation
84-
subfolder: Optional subfolder in repo
85-
variant: Optional variant in repo
86-
revision: Optional revision in repo
84+
pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
85+
subfolder: Optional subfolder in pretrained_model_name_or_path
86+
variant: Optional variant in pretrained_model_name_or_path
87+
revision: Optional revision in pretrained_model_name_or_path
8788
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
8889
"""
8990

9091
name: Optional[str] = None
9192
type_hint: Optional[Type] = None
9293
description: Optional[str] = None
9394
config: Optional[FrozenDict] = None
94-
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
95-
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
95+
pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
9696
subfolder: Optional[str] = field(default="", metadata={"loading": True})
9797
variant: Optional[str] = field(default=None, metadata={"loading": True})
9898
revision: Optional[str] = field(default=None, metadata={"loading": True})
9999
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
100100

101+
# Deprecated
102+
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
103+
104+
def __post_init__(self):
105+
repo_value = self.repo
106+
if repo_value is not None and self.pretrained_model_name_or_path is None:
107+
object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
108+
101109
def __hash__(self):
102110
"""Make ComponentSpec hashable, using load_id as the hash value."""
103111
return hash((self.name, self.load_id, self.default_creation_method))
@@ -182,8 +190,8 @@ def loading_fields(cls) -> List[str]:
182190
@property
183191
def load_id(self) -> str:
184192
"""
185-
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
186-
segments).
193+
Unique identifier for this spec's pretrained load, composed of
194+
pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
187195
"""
188196
if self.default_creation_method == "from_config":
189197
return "null"
@@ -197,12 +205,13 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
197205
Decode a load_id string back into a dictionary of loading fields and values.
198206
199207
Args:
200-
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
208+
load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
201209
where None values are represented as "null"
202210
203211
Returns:
204212
Dict mapping loading field names to their values. e.g. {
205-
"repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
213+
"pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
214+
"revision": "revision"
206215
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
207216
component not created with `load` method).
208217
"""
@@ -259,34 +268,45 @@ def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **k
259268
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
260269
def load(self, **kwargs) -> Any:
261270
"""Load component using from_pretrained."""
262-
263-
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
271+
# select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
264272
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
265273
# merge loading field value in the spec with user passed values to create load_kwargs
266274
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
267-
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
268-
repo = load_kwargs.pop("repo", None)
269-
if repo is None:
275+
276+
pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
277+
if pretrained_model_name_or_path is None:
270278
raise ValueError(
271-
"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
279+
"`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
280+
)
281+
is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
282+
if is_single_file and self.type_hint is None:
283+
raise ValueError(
284+
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
272285
)
273286

274287
if self.type_hint is None:
275288
try:
276289
from diffusers import AutoModel
277290

278-
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
291+
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
279292
except Exception as e:
280293
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
281294
# update type_hint if AutoModel load successfully
282295
self.type_hint = component.__class__
283296
else:
297+
# determine load method
298+
load_method = (
299+
getattr(self.type_hint, "from_single_file")
300+
if is_single_file
301+
else getattr(self.type_hint, "from_pretrained")
302+
)
303+
284304
try:
285-
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
305+
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
286306
except Exception as e:
287307
raise ValueError(f"Unable to load {self.name} using load method: {e}")
288308

289-
self.repo = repo
309+
self.pretrained_model_name_or_path = pretrained_model_name_or_path
290310
for k, v in load_kwargs.items():
291311
setattr(self, k, v)
292312
component._diffusers_load_id = self.load_id

tests/modular_pipelines/flux/test_modular_pipeline_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
3737
pipeline_class = FluxModularPipeline
3838
pipeline_blocks_class = FluxAutoBlocks
39-
repo = "hf-internal-testing/tiny-flux-modular"
39+
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
4040

4141
params = frozenset(["prompt", "height", "width", "guidance_scale"])
4242
batch_params = frozenset(["prompt"])
@@ -62,7 +62,7 @@ def test_float16_inference(self):
6262
class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
6363
pipeline_class = FluxModularPipeline
6464
pipeline_blocks_class = FluxAutoBlocks
65-
repo = "hf-internal-testing/tiny-flux-modular"
65+
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
6666

6767
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
6868
batch_params = frozenset(["prompt", "image"])
@@ -128,7 +128,7 @@ def test_float16_inference(self):
128128
class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
129129
pipeline_class = FluxKontextModularPipeline
130130
pipeline_blocks_class = FluxKontextAutoBlocks
131-
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
131+
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
132132

133133
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
134134
batch_params = frozenset(["prompt", "image"])

0 commit comments

Comments
 (0)