Skip to content

Commit 45d14b6

Browse files
committed
Introduce AutoPipelineForText2Video (simple)
1 parent 8d415a6 commit 45d14b6

File tree

5 files changed

+157
-0
lines changed

5 files changed

+157
-0
lines changed

auto_pipeline_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
2+
from diffusers import AutoPipelineForText2Video
3+
from diffusers.utils import export_to_video
4+
5+
pipe = AutoPipelineForText2Video.from_pretrained(
6+
"THUDM/CogVideoX-5b",
7+
torch_dtype=torch.bfloat16,
8+
)

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@
303303
"AutoPipelineForImage2Image",
304304
"AutoPipelineForInpainting",
305305
"AutoPipelineForText2Image",
306+
"AutoPipelineForText2Video",
306307
"ConsistencyModelPipeline",
307308
"DanceDiffusionPipeline",
308309
"DDIMPipeline",

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"AutoPipelineForImage2Image",
4747
"AutoPipelineForInpainting",
4848
"AutoPipelineForText2Image",
49+
"AutoPipelineForText2Video",
4950
]
5051
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
5152
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@
118118
StableDiffusionXLPipeline,
119119
)
120120
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
121+
from .hunyuan_video import HunyuanVideoPipeline
122+
from .cogvideo import CogVideoXPipeline
121123
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
122124

123125

@@ -218,6 +220,8 @@
218220
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
219221
[
220222
("wan", WanPipeline),
223+
("hunyuan", HunyuanVideoPipeline),
224+
("cogvideox", CogVideoXPipeline),
221225
]
222226
)
223227

@@ -1203,3 +1207,39 @@ def from_pipe(cls, pipeline, **kwargs):
12031207
model.register_to_config(**unused_original_config)
12041208

12051209
return model
1210+
1211+
class AutoPipelineForText2Video(ConfigMixin):
1212+
1213+
config_name = "model_index.json"
1214+
1215+
def __init__(self, *args, **kwargs):
1216+
raise EnvironmentError(
1217+
f"{self.__class__.__name__} is designed to be instantiated "
1218+
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
1219+
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
1220+
)
1221+
1222+
@classmethod
1223+
@validate_hf_hub_args
1224+
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
1225+
cache_dir = kwargs.pop("cache_dir", None)
1226+
force_download = kwargs.pop("force_download", False)
1227+
proxies = kwargs.pop("proxies", None)
1228+
token = kwargs.pop("token", None)
1229+
local_files_only = kwargs.pop("local_files_only", False)
1230+
revision = kwargs.pop("revision", None)
1231+
1232+
load_config_kwargs = {
1233+
"cache_dir": cache_dir,
1234+
"force_download": force_download,
1235+
"proxies": proxies,
1236+
"token": token,
1237+
"local_files_only": local_files_only,
1238+
"revision": revision,
1239+
}
1240+
1241+
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
1242+
orig_class_name = config["_class_name"]
1243+
text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name)
1244+
kwargs = {**load_config_kwargs, **kwargs}
1245+
return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs)

src/diffusers/pipelines/test.ipynb

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "168c7d5f-bdb8-48e3-b696-29848f3f5205",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"id": "2339eb62-6bc2-478d-803b-e56e5fb22844",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"!pip install diffusers\n",
21+
"!pip install transformers"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"id": "3bc31594-7a94-4255-85da-07a5c484e2b4",
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"from diffusers import HunyuanVideoPipeline, PipelineQuantizationConfig\n",
32+
"import torch\n",
33+
"\n",
34+
"pipeline = HunyuanVideoPipeline.from_pretrained(\n",
35+
" \"hunyuanvideo-community/HunyuanVideo\",\n",
36+
" torch_dtype=torch.bfloat16,\n",
37+
")\n",
38+
"print(pipeline.config)"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"id": "2bfdeca9-ab44-4380-ad3e-c3ec2c2b0d0e",
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"import torch\n",
49+
"from diffusers import TextToVideoZeroPipeline\n",
50+
"\n",
51+
"model_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
52+
"pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n",
53+
"print(pipe.config)"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"id": "e7ea86aa-75d0-4dda-8f4a-12a666c34fb2",
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"import torch\n",
64+
"from diffusers import CogVideoXPipeline\n",
65+
"pipe = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-2b\", torch_dtype=torch.float16).to(\"cuda\")\n",
66+
"print(pipe.config)"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": null,
72+
"id": "e3e6c662-7fea-4474-9ba0-cbaaf5a5cca7",
73+
"metadata": {},
74+
"outputs": [],
75+
"source": []
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"id": "af6f3aae-3298-41aa-a9ed-37da79675ab3",
81+
"metadata": {},
82+
"outputs": [],
83+
"source": []
84+
}
85+
],
86+
"metadata": {
87+
"kernelspec": {
88+
"display_name": "Python3 (main venv)",
89+
"language": "python",
90+
"name": "main"
91+
},
92+
"language_info": {
93+
"codemirror_mode": {
94+
"name": "ipython",
95+
"version": 3
96+
},
97+
"file_extension": ".py",
98+
"mimetype": "text/x-python",
99+
"name": "python",
100+
"nbconvert_exporter": "python",
101+
"pygments_lexer": "ipython3",
102+
"version": "3.10.12"
103+
}
104+
},
105+
"nbformat": 4,
106+
"nbformat_minor": 5
107+
}

0 commit comments

Comments
 (0)