ThorAILabs committed
Commit b720398 · verified · 1 Parent(s): 352d290

Upload 25 files

requirements.txt ADDED
@@ -0,0 +1,16 @@
+torch>=2.4.0
+torchvision>=0.19.0
+opencv-python>=4.9.0.80
+diffusers>=0.31.0
+transformers>=4.49.0
+tokenizers>=0.20.3
+accelerate>=1.1.1
+tqdm
+imageio
+easydict
+ftfy
+imageio-ffmpeg
+flash_attn
+gradio>=5.0.0
+numpy>=1.23.5,<2
+xfuser
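For reference, a minimal sketch (not part of the uploaded files) that checks a few of the pins above against the active environment using the standard-library importlib.metadata; the package names and lower bounds are copied from requirements.txt.

from importlib.metadata import PackageNotFoundError, version

# Lower bounds taken from requirements.txt above; extend as needed.
pins = {"torch": "2.4.0", "diffusers": "0.31.0", "transformers": "4.49.0", "gradio": "5.0.0"}

for pkg, minimum in pins.items():
    try:
        print(f"{pkg}: installed {version(pkg)} (requires >= {minimum})")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")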
t2v_14B_singleGPU.py ADDED
@@ -0,0 +1,205 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import argparse
+import os.path as osp
+import os
+import sys
+import warnings
+
+import gradio as gr
+
+warnings.filterwarnings('ignore')
+
+# Model
+sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+import wan
+from wan.configs import WAN_CONFIGS
+from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
+from wan.utils.utils import cache_video
+
+# Global Var
+prompt_expander = None
+wan_t2v = None
+
+
+# Button Func
+def prompt_enc(prompt, tar_lang):
+    global prompt_expander
+    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
+    if prompt_output.status == False:
+        return prompt
+    else:
+        return prompt_output.prompt
+
+
+def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
+                   shift_scale, seed, n_prompt):
+    global wan_t2v
+    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
+
+    W = int(resolution.split("*")[0])
+    H = int(resolution.split("*")[1])
+    video = wan_t2v.generate(
+        txt2vid_prompt,
+        size=(W, H),
+        shift=shift_scale,
+        sampling_steps=sd_steps,
+        guide_scale=guide_scale,
+        n_prompt=n_prompt,
+        seed=seed,
+        offload_model=True)
+
+    cache_video(
+        tensor=video[None],
+        save_file="example.mp4",
+        fps=16,
+        nrow=1,
+        normalize=True,
+        value_range=(-1, 1))
+
+    return "example.mp4"
+
+
+# Interface
+def gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("""
+                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
+                        Wan2.1 (T2V-14B)
+                    </div>
+                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
+                        Wan: Open and Advanced Large-Scale Video Generative Models.
+                    </div>
+                    """)
+
+        with gr.Row():
+            with gr.Column():
+                txt2vid_prompt = gr.Textbox(
+                    label="Prompt",
+                    placeholder="Describe the video you want to generate",
+                )
+                tar_lang = gr.Radio(
+                    choices=["CH", "EN"],
+                    label="Target language of prompt enhance",
+                    value="CH")
+                run_p_button = gr.Button(value="Prompt Enhance")
+
+                with gr.Accordion("Advanced Options", open=True):
+                    resolution = gr.Dropdown(
+                        label='Resolution(Width*Height)',
+                        choices=[
+                            '720*1280', '1280*720', '960*960', '1088*832',
+                            '832*1088', '480*832', '832*480', '624*624',
+                            '704*544', '544*704'
+                        ],
+                        value='720*1280')
+
+                    with gr.Row():
+                        sd_steps = gr.Slider(
+                            label="Diffusion steps",
+                            minimum=1,
+                            maximum=1000,
+                            value=50,
+                            step=1)
+                        guide_scale = gr.Slider(
+                            label="Guide scale",
+                            minimum=0,
+                            maximum=20,
+                            value=5.0,
+                            step=1)
+                    with gr.Row():
+                        shift_scale = gr.Slider(
+                            label="Shift scale",
+                            minimum=0,
+                            maximum=10,
+                            value=5.0,
+                            step=1)
+                        seed = gr.Slider(
+                            label="Seed",
+                            minimum=-1,
+                            maximum=2147483647,
+                            step=1,
+                            value=-1)
+                    n_prompt = gr.Textbox(
+                        label="Negative Prompt",
+                        placeholder="Describe the negative prompt you want to add"
+                    )
+
+                run_t2v_button = gr.Button("Generate Video")
+
+            with gr.Column():
+                result_gallery = gr.Video(
+                    label='Generated Video', interactive=False, height=600)
+
+        run_p_button.click(
+            fn=prompt_enc,
+            inputs=[txt2vid_prompt, tar_lang],
+            outputs=[txt2vid_prompt])
+
+        run_t2v_button.click(
+            fn=t2v_generation,
+            inputs=[
+                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
+                seed, n_prompt
+            ],
+            outputs=[result_gallery],
+        )
+
+    return demo
+
+
+# Main
+def _parse_args():
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt or image using Gradio")
+    parser.add_argument(
+        "--ckpt_dir",
+        type=str,
+        default="cache",
+        help="The path to the checkpoint directory.")
+    parser.add_argument(
+        "--prompt_extend_method",
+        type=str,
+        default="local_qwen",
+        choices=["dashscope", "local_qwen"],
+        help="The prompt extend method to use.")
+    parser.add_argument(
+        "--prompt_extend_model",
+        type=str,
+        default=None,
+        help="The prompt extend model to use.")
+
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == '__main__':
+    args = _parse_args()
+
+    print("Step1: Init prompt_expander...", end='', flush=True)
+    if args.prompt_extend_method == "dashscope":
+        prompt_expander = DashScopePromptExpander(
+            model_name=args.prompt_extend_model, is_vl=False)
+    elif args.prompt_extend_method == "local_qwen":
+        prompt_expander = QwenPromptExpander(
+            model_name=args.prompt_extend_model, is_vl=False, device=0)
+    else:
+        raise NotImplementedError(
+            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
+    print("done", flush=True)
+
+    print("Step2: Init 14B t2v model...", end='', flush=True)
+    cfg = WAN_CONFIGS['t2v-14B']
+    wan_t2v = wan.WanT2V(
+        config=cfg,
+        checkpoint_dir=args.ckpt_dir,
+        device_id=0,
+        rank=0,
+        t5_fsdp=False,
+        dit_fsdp=False,
+        use_usp=False,
+    )
+    print("done", flush=True)
+
+    demo = gradio_interface()
+    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
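The script above wires these calls into a Gradio UI. For reference, a minimal headless sketch of the same generation path (not part of the uploaded files; the checkpoint directory, prompt, and resolution are placeholder assumptions, and prompt expansion is omitted):

import wan
from wan.configs import WAN_CONFIGS
from wan.utils.utils import cache_video

# Placeholder checkpoint directory, mirroring the --ckpt_dir default above.
wan_t2v = wan.WanT2V(
    config=WAN_CONFIGS['t2v-14B'],
    checkpoint_dir="cache",
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
)

video = wan_t2v.generate(
    "A corgi surfing a wave at sunset",   # example prompt, purely illustrative
    size=(1280, 720),
    shift=5.0,
    sampling_steps=50,
    guide_scale=5.0,
    n_prompt="",
    seed=-1,
    offload_model=True,
)
cache_video(tensor=video[None], save_file="example.mp4", fps=16,
            nrow=1, normalize=True, value_range=(-1, 1))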
wan/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import configs, distributed, modules
+from .text2video import WanT2V
wan/configs/__init__.py ADDED
@@ -0,0 +1,43 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import copy
+import os
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+from .wan_t2v_14B import t2v_14B
+
+# the config of t2i_14B is the same as t2v_14B
+t2i_14B = copy.deepcopy(t2v_14B)
+t2i_14B.__name__ = 'Config: Wan T2I 14B'
+
+WAN_CONFIGS = {
+    't2v-14B': t2v_14B,
+    't2i-14B': t2i_14B,
+}
+
+SIZE_CONFIGS = {
+    "1920*1056": (1920, 1056),
+    "1920*1072": (1920, 1072),
+    "1920*832": (1920, 832),
+    "1280*560": (1280, 560),
+    "560*1280": (560, 1280),
+    "1056*1920": (1056, 1920),
+    "832*1920": (832, 1920),
+    '720*1280': (720, 1280),
+    '1280*720': (1280, 720),
+    '480*832': (480, 832),
+    '832*480': (832, 480),
+    '1024*1024': (1024, 1024),
+}
+
+MAX_AREA_CONFIGS = {
+    '720*1280': 720 * 1280,
+    '1280*720': 1280 * 720,
+    '480*832': 480 * 832,
+    '832*480': 832 * 480,
+}
+
+SUPPORTED_SIZES = {
+    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480', "1920*1056", "1056*1920", "1920*832", "832*1920", "1920*1072", "1072*1920", "1280*560", "560*1280"),
+    't2i-14B': tuple(SIZE_CONFIGS.keys()),
+}
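As a small illustration of how these tables can be combined (a sketch, not part of the uploaded files): a requested "W*H" string can be validated against SUPPORTED_SIZES and resolved through SIZE_CONFIGS before generation.

from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS

def resolve_size(task, size):
    # Reject resolutions the chosen task does not list as supported.
    if size not in SUPPORTED_SIZES[task]:
        raise ValueError(f"{size} is not supported for {task}; "
                         f"choose one of {SUPPORTED_SIZES[task]}")
    return SIZE_CONFIGS[size]

cfg = WAN_CONFIGS['t2v-14B']
width, height = resolve_size('t2v-14B', '1280*720')   # -> (1280, 720)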
wan/configs/shared_config.py ADDED
@@ -0,0 +1,20 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+from easydict import EasyDict
+
+#------------------------ Wan shared config ------------------------#
+wan_shared_cfg = EasyDict()
+
+# t5
+wan_shared_cfg.t5_model = 'umt5_xxl'
+wan_shared_cfg.t5_dtype = torch.bfloat16
+wan_shared_cfg.text_len = 512
+
+# transformer
+wan_shared_cfg.param_dtype = torch.bfloat16
+
+# inference
+wan_shared_cfg.num_train_timesteps = 1000
+wan_shared_cfg.sample_fps = 16
+wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
+# wan_shared_cfg.sample_neg_prompt = "Vibrant colors, overexposed, static, blurry details, subtitles, stylized, artwork, painting, still image, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, deformed limbs, merged fingers, motionless frame, cluttered background, three legs, crowded background, walking backwards"
wan/configs/wan_t2v_14B.py ADDED
@@ -0,0 +1,29 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from easydict import EasyDict
+
+from .shared_config import wan_shared_cfg
+
+#------------------------ Wan T2V 14B ------------------------#
+
+t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
+t2v_14B.update(wan_shared_cfg)
+
+# t5
+t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+t2v_14B.t5_tokenizer = 'google/umt5-xxl'
+
+# vae
+t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
+t2v_14B.vae_stride = (4, 8, 8)
+
+# transformer
+t2v_14B.patch_size = (1, 2, 2)
+t2v_14B.dim = 5120
+t2v_14B.ffn_dim = 13824
+t2v_14B.freq_dim = 256
+t2v_14B.num_heads = 40
+t2v_14B.num_layers = 40
+t2v_14B.window_size = (-1, -1)
+t2v_14B.qk_norm = True
+t2v_14B.cross_attn_norm = True
+t2v_14B.eps = 1e-6
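Because the config is an EasyDict, fields can be read either as attributes or as dictionary keys; a quick sketch (not part of the uploaded files):

from wan.configs import WAN_CONFIGS

cfg = WAN_CONFIGS['t2v-14B']
print(cfg.dim, cfg.num_layers, cfg.num_heads)    # 5120 40 40
print(cfg['patch_size'], cfg['vae_stride'])      # (1, 2, 2) (4, 8, 8)
print(cfg.param_dtype, cfg.sample_fps)           # inherited from wan_shared_cfg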
wan/distributed/__init__.py ADDED
File without changes
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,32 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from functools import partial
+
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
+from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+
+
+def shard_model(
+    model,
+    device_id,
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+    process_group=None,
+    sharding_strategy=ShardingStrategy.FULL_SHARD,
+    sync_module_states=True,
+):
+    model = FSDP(
+        module=model,
+        process_group=process_group,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+        mixed_precision=MixedPrecision(
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype),
+        device_id=device_id,
+        sync_module_states=sync_module_states)
+    return model
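A hedged sketch of how shard_model might be invoked (not part of the uploaded files). It assumes a CUDA node, a launch via torchrun so that LOCAL_RANK is set, and a wrapped model that exposes a .blocks ModuleList, as WanModel does; everything other than shard_model and WanModel is illustrative.

import os

import torch
import torch.distributed as dist

from wan.distributed.fsdp import shard_model
from wan.modules.model import WanModel

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Default (smaller-than-14B) WanModel config, purely for illustration.
    model = WanModel()
    model = shard_model(model, device_id=local_rank)   # FULL_SHARD, bf16 params

    # ... run inference or training with the sharded model ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()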
wan/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,198 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.amp as amp
+
+from xfuser.core.distributed import get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+
+from ..modules.model import sinusoidal_embedding_1d
+
+
+def pad_freqs(original_tensor, target_len):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.ones(
+        pad_size,
+        s1,
+        s2,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+
+@amp.autocast("cuda", enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+    """
+    x: [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs: [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+            s, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+        freqs_i = pad_freqs(freqs_i, s * sp_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+                                                       s_per_rank), :, :]
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+
+def usp_dit_forward(
+    self,
+    x,
+    t,
+    context,
+    seq_len,
+    clip_fea=None,
+    y=None,
+    guidance=None
+):
+    """
+    x: A list of videos each with shape [C, T, H, W].
+    t: [B].
+    context: A list of text embeddings each with shape [L, C].
+    """
+    if self.model_type == 'i2v':
+        assert clip_fea is not None and y is not None
+    # params
+    device = self.patch_embedding.weight.device
+    if self.freqs.device != device:
+        self.freqs = self.freqs.to(device)
+
+    if y is not None:
+        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+    # embeddings
+    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+    grid_sizes = torch.stack(
+        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+    x = [u.flatten(2).transpose(1, 2) for u in x]
+    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+    assert seq_lens.max() <= seq_len
+    x = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+        for u in x
+    ])
+
+    # time embeddings
+    with amp.autocast("cuda", dtype=torch.float32):
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).float())
+
+        if guidance is not None and self.guidance_embedding is not None:
+            guidance_input = sinusoidal_embedding_1d(self.freq_dim, guidance).float()
+            guidance_emb = self.guidance_embedding(guidance_input)
+            e = e + guidance_emb
+
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+        assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+    # context
+    context_lens = None
+    context = self.text_embedding(
+        torch.stack([
+            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+            for u in context
+        ]))
+
+    if clip_fea is not None:
+        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+        context = torch.concat([context_clip, context], dim=1)
+
+    # arguments
+    kwargs = dict(
+        e=e0,
+        seq_lens=seq_lens,
+        grid_sizes=grid_sizes,
+        freqs=self.freqs,
+        context=context,
+        context_lens=context_lens)
+
+    # Context Parallel
+    x = torch.chunk(
+        x, get_sequence_parallel_world_size(),
+        dim=1)[get_sequence_parallel_rank()]
+
+    for block in self.blocks:
+        x = block(x, **kwargs)
+
+    # head
+    x = self.head(x, e)
+
+    # Context Parallel
+    x = get_sp_group().all_gather(x, dim=1)
+
+    # unpatchify
+    x = self.unpatchify(x, grid_sizes)
+    return [u.float() for u in x]
+
+
+def usp_attn_forward(self,
+                     x,
+                     seq_lens,
+                     grid_sizes,
+                     freqs,
+                     dtype=torch.bfloat16):
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+    half_dtypes = (torch.float16, torch.bfloat16)
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n, d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+    q = rope_apply(q, grid_sizes, freqs)
+    k = rope_apply(k, grid_sizes, freqs)
+
+    # TODO: We should use unpadded q,k,v for attention.
+    # k_lens = seq_lens // get_sequence_parallel_world_size()
+    # if k_lens is not None:
+    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+    x = xFuserLongContextAttention()(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
+
+    # TODO: padding after attention.
+    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
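These are free functions that take a `self` parameter. A plausible way to activate them (a sketch, not necessarily how wan.text2video wires it in this commit) is to bind them onto an existing WanModel with types.MethodType once xfuser's sequence-parallel groups are initialised:

import types

from wan.distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward

def enable_sequence_parallel(model):
    # Replace each block's self-attention forward and the top-level DiT forward
    # with the sequence-parallel variants defined above.
    for block in model.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
    model.forward = types.MethodType(usp_dit_forward, model)
    return model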
wan/distributed/xdit_context_parallel_bk.py ADDED
@@ -0,0 +1,192 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.cuda.amp as amp
+from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                     get_sequence_parallel_world_size,
+                                     get_sp_group)
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+
+from ..modules.model import sinusoidal_embedding_1d
+
+
+def pad_freqs(original_tensor, target_len):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.ones(
+        pad_size,
+        s1,
+        s2,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+
+@amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+    """
+    x: [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs: [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+            s, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+        freqs_i = pad_freqs(freqs_i, s * sp_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+                                                       s_per_rank), :, :]
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+
+def usp_dit_forward(
+    self,
+    x,
+    t,
+    context,
+    seq_len,
+    clip_fea=None,
+    y=None,
+):
+    """
+    x: A list of videos each with shape [C, T, H, W].
+    t: [B].
+    context: A list of text embeddings each with shape [L, C].
+    """
+    if self.model_type == 'i2v':
+        assert clip_fea is not None and y is not None
+    # params
+    device = self.patch_embedding.weight.device
+    if self.freqs.device != device:
+        self.freqs = self.freqs.to(device)
+
+    if y is not None:
+        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+    # embeddings
+    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+    grid_sizes = torch.stack(
+        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+    x = [u.flatten(2).transpose(1, 2) for u in x]
+    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+    assert seq_lens.max() <= seq_len
+    x = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+        for u in x
+    ])
+
+    # time embeddings
+    with amp.autocast(dtype=torch.float32):
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).float())
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+        assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+    # context
+    context_lens = None
+    context = self.text_embedding(
+        torch.stack([
+            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+            for u in context
+        ]))
+
+    if clip_fea is not None:
+        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+        context = torch.concat([context_clip, context], dim=1)
+
+    # arguments
+    kwargs = dict(
+        e=e0,
+        seq_lens=seq_lens,
+        grid_sizes=grid_sizes,
+        freqs=self.freqs,
+        context=context,
+        context_lens=context_lens)
+
+    # Context Parallel
+    x = torch.chunk(
+        x, get_sequence_parallel_world_size(),
+        dim=1)[get_sequence_parallel_rank()]
+
+    for block in self.blocks:
+        x = block(x, **kwargs)
+
+    # head
+    x = self.head(x, e)
+
+    # Context Parallel
+    x = get_sp_group().all_gather(x, dim=1)
+
+    # unpatchify
+    x = self.unpatchify(x, grid_sizes)
+    return [u.float() for u in x]
+
+
+def usp_attn_forward(self,
+                     x,
+                     seq_lens,
+                     grid_sizes,
+                     freqs,
+                     dtype=torch.bfloat16):
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+    half_dtypes = (torch.float16, torch.bfloat16)
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n, d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+    q = rope_apply(q, grid_sizes, freqs)
+    k = rope_apply(k, grid_sizes, freqs)
+
+    # TODO: We should use unpadded q,k,v for attention.
+    # k_lens = seq_lens // get_sequence_parallel_world_size()
+    # if k_lens is not None:
+    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+    x = xFuserLongContextAttention()(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
+
+    # TODO: padding after attention.
+    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
wan/modules/__init__.py ADDED
@@ -0,0 +1,16 @@
+from .attention import flash_attention
+from .model import WanModel
+from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
+from .tokenizers import HuggingfaceTokenizer
+from .vae import WanVAE
+
+__all__ = [
+    'WanVAE',
+    'WanModel',
+    'T5Model',
+    'T5Encoder',
+    'T5Decoder',
+    'T5EncoderModel',
+    'HuggingfaceTokenizer',
+    'flash_attention',
+]
wan/modules/attention.py ADDED
@@ -0,0 +1,179 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+
+import warnings
+
+__all__ = [
+    'flash_attention',
+    'attention',
+]
+
+
+def flash_attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    version=None,
+):
+    """
+    q: [B, Lq, Nq, C1].
+    k: [B, Lk, Nk, C1].
+    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+    q_lens: [B].
+    k_lens: [B].
+    dropout_p: float. Dropout probability.
+    softmax_scale: float. The scaling of QK^T before applying softmax.
+    causal: bool. Whether to apply causal attention mask.
+    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
+    deterministic: bool. If True, slightly slower and uses more memory.
+    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+    """
+    half_dtypes = (torch.float16, torch.bfloat16)
+    assert dtype in half_dtypes
+    assert q.device.type == 'cuda' and q.size(-1) <= 256
+
+    # params
+    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # preprocess query
+    if q_lens is None:
+        q = half(q.flatten(0, 1))
+        q_lens = torch.tensor(
+            [lq] * b, dtype=torch.int32).to(
+                device=q.device, non_blocking=True)
+    else:
+        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+    # preprocess key, value
+    if k_lens is None:
+        k = half(k.flatten(0, 1))
+        v = half(v.flatten(0, 1))
+        k_lens = torch.tensor(
+            [lk] * b, dtype=torch.int32).to(
+                device=k.device, non_blocking=True)
+    else:
+        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+    q = q.to(v.dtype)
+    k = k.to(v.dtype)
+
+    if q_scale is not None:
+        q = q * q_scale
+
+    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+        warnings.warn(
+            'Flash attention 3 is not available, use flash attention 2 instead.'
+        )
+
+    # apply attention
+    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
+        # Note: dropout_p, window_size are not supported in FA3 now.
+        x = flash_attn_interface.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            seqused_q=None,
+            seqused_k=None,
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            deterministic=deterministic)[0].unflatten(0, (b, lq))
+    else:
+        assert FLASH_ATTN_2_AVAILABLE
+        x = flash_attn.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic).unflatten(0, (b, lq))
+
+    # output
+    return x.type(out_dtype)
+
+
+def attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    fa_version=None,
+):
+    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+        return flash_attention(
+            q=q,
+            k=k,
+            v=v,
+            q_lens=q_lens,
+            k_lens=k_lens,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            q_scale=q_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic,
+            dtype=dtype,
+            version=fa_version,
+        )
+    else:
+        if q_lens is not None or k_lens is not None:
+            warnings.warn(
+                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
+            )
+        attn_mask = None
+
+        q = q.transpose(1, 2).to(dtype)
+        k = k.transpose(1, 2).to(dtype)
+        v = v.transpose(1, 2).to(dtype)
+
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
+
+        out = out.transpose(1, 2).contiguous()
+        return out
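A small usage sketch for the attention() wrapper (not part of the uploaded files; it assumes a CUDA device, and it falls back to scaled_dot_product_attention when neither flash-attention build is importable):

import torch

from wan.modules.attention import attention

# [B, L, num_heads, head_dim] layout, as documented in flash_attention() above.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

out = attention(q, k, v, causal=False)
print(out.shape)   # torch.Size([1, 128, 8, 64])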
wan/modules/clip.py ADDED
@@ -0,0 +1,542 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+
11
+ from .attention import flash_attention
12
+ from .tokenizers import HuggingfaceTokenizer
13
+ from .xlm_roberta import XLMRoberta
14
+
15
+ __all__ = [
16
+ 'XLMRobertaCLIP',
17
+ 'clip_xlm_roberta_vit_h_14',
18
+ 'CLIPModel',
19
+ ]
20
+
21
+
22
+ def pos_interpolate(pos, seq_len):
23
+ if pos.size(1) == seq_len:
24
+ return pos
25
+ else:
26
+ src_grid = int(math.sqrt(pos.size(1)))
27
+ tar_grid = int(math.sqrt(seq_len))
28
+ n = pos.size(1) - src_grid * src_grid
29
+ return torch.cat([
30
+ pos[:, :n],
31
+ F.interpolate(
32
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
33
+ 0, 3, 1, 2),
34
+ size=(tar_grid, tar_grid),
35
+ mode='bicubic',
36
+ align_corners=False).flatten(2).transpose(1, 2)
37
+ ],
38
+ dim=1)
39
+
40
+
41
+ class QuickGELU(nn.Module):
42
+
43
+ def forward(self, x):
44
+ return x * torch.sigmoid(1.702 * x)
45
+
46
+
47
+ class LayerNorm(nn.LayerNorm):
48
+
49
+ def forward(self, x):
50
+ return super().forward(x.float()).type_as(x)
51
+
52
+
53
+ class SelfAttention(nn.Module):
54
+
55
+ def __init__(self,
56
+ dim,
57
+ num_heads,
58
+ causal=False,
59
+ attn_dropout=0.0,
60
+ proj_dropout=0.0):
61
+ assert dim % num_heads == 0
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = dim // num_heads
66
+ self.causal = causal
67
+ self.attn_dropout = attn_dropout
68
+ self.proj_dropout = proj_dropout
69
+
70
+ # layers
71
+ self.to_qkv = nn.Linear(dim, dim * 3)
72
+ self.proj = nn.Linear(dim, dim)
73
+
74
+ def forward(self, x):
75
+ """
76
+ x: [B, L, C].
77
+ """
78
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
79
+
80
+ # compute query, key, value
81
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
82
+
83
+ # compute attention
84
+ p = self.attn_dropout if self.training else 0.0
85
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86
+ x = x.reshape(b, s, c)
87
+
88
+ # output
89
+ x = self.proj(x)
90
+ x = F.dropout(x, self.proj_dropout, self.training)
91
+ return x
92
+
93
+
94
+ class SwiGLU(nn.Module):
95
+
96
+ def __init__(self, dim, mid_dim):
97
+ super().__init__()
98
+ self.dim = dim
99
+ self.mid_dim = mid_dim
100
+
101
+ # layers
102
+ self.fc1 = nn.Linear(dim, mid_dim)
103
+ self.fc2 = nn.Linear(dim, mid_dim)
104
+ self.fc3 = nn.Linear(mid_dim, dim)
105
+
106
+ def forward(self, x):
107
+ x = F.silu(self.fc1(x)) * self.fc2(x)
108
+ x = self.fc3(x)
109
+ return x
110
+
111
+
112
+ class AttentionBlock(nn.Module):
113
+
114
+ def __init__(self,
115
+ dim,
116
+ mlp_ratio,
117
+ num_heads,
118
+ post_norm=False,
119
+ causal=False,
120
+ activation='quick_gelu',
121
+ attn_dropout=0.0,
122
+ proj_dropout=0.0,
123
+ norm_eps=1e-5):
124
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
125
+ super().__init__()
126
+ self.dim = dim
127
+ self.mlp_ratio = mlp_ratio
128
+ self.num_heads = num_heads
129
+ self.post_norm = post_norm
130
+ self.causal = causal
131
+ self.norm_eps = norm_eps
132
+
133
+ # layers
134
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
135
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
136
+ proj_dropout)
137
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
138
+ if activation == 'swi_glu':
139
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
140
+ else:
141
+ self.mlp = nn.Sequential(
142
+ nn.Linear(dim, int(dim * mlp_ratio)),
143
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
144
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
145
+
146
+ def forward(self, x):
147
+ if self.post_norm:
148
+ x = x + self.norm1(self.attn(x))
149
+ x = x + self.norm2(self.mlp(x))
150
+ else:
151
+ x = x + self.attn(self.norm1(x))
152
+ x = x + self.mlp(self.norm2(x))
153
+ return x
154
+
155
+
156
+ class AttentionPool(nn.Module):
157
+
158
+ def __init__(self,
159
+ dim,
160
+ mlp_ratio,
161
+ num_heads,
162
+ activation='gelu',
163
+ proj_dropout=0.0,
164
+ norm_eps=1e-5):
165
+ assert dim % num_heads == 0
166
+ super().__init__()
167
+ self.dim = dim
168
+ self.mlp_ratio = mlp_ratio
169
+ self.num_heads = num_heads
170
+ self.head_dim = dim // num_heads
171
+ self.proj_dropout = proj_dropout
172
+ self.norm_eps = norm_eps
173
+
174
+ # layers
175
+ gain = 1.0 / math.sqrt(dim)
176
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
177
+ self.to_q = nn.Linear(dim, dim)
178
+ self.to_kv = nn.Linear(dim, dim * 2)
179
+ self.proj = nn.Linear(dim, dim)
180
+ self.norm = LayerNorm(dim, eps=norm_eps)
181
+ self.mlp = nn.Sequential(
182
+ nn.Linear(dim, int(dim * mlp_ratio)),
183
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
184
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
185
+
186
+ def forward(self, x):
187
+ """
188
+ x: [B, L, C].
189
+ """
190
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
191
+
192
+ # compute query, key, value
193
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
194
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195
+
196
+ # compute attention
197
+ x = flash_attention(q, k, v, version=2)
198
+ x = x.reshape(b, 1, c)
199
+
200
+ # output
201
+ x = self.proj(x)
202
+ x = F.dropout(x, self.proj_dropout, self.training)
203
+
204
+ # mlp
205
+ x = x + self.mlp(self.norm(x))
206
+ return x[:, 0]
207
+
208
+
209
+ class VisionTransformer(nn.Module):
210
+
211
+ def __init__(self,
212
+ image_size=224,
213
+ patch_size=16,
214
+ dim=768,
215
+ mlp_ratio=4,
216
+ out_dim=512,
217
+ num_heads=12,
218
+ num_layers=12,
219
+ pool_type='token',
220
+ pre_norm=True,
221
+ post_norm=False,
222
+ activation='quick_gelu',
223
+ attn_dropout=0.0,
224
+ proj_dropout=0.0,
225
+ embedding_dropout=0.0,
226
+ norm_eps=1e-5):
227
+ if image_size % patch_size != 0:
228
+ print(
229
+ '[WARNING] image_size is not divisible by patch_size',
230
+ flush=True)
231
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
232
+ out_dim = out_dim or dim
233
+ super().__init__()
234
+ self.image_size = image_size
235
+ self.patch_size = patch_size
236
+ self.num_patches = (image_size // patch_size)**2
237
+ self.dim = dim
238
+ self.mlp_ratio = mlp_ratio
239
+ self.out_dim = out_dim
240
+ self.num_heads = num_heads
241
+ self.num_layers = num_layers
242
+ self.pool_type = pool_type
243
+ self.post_norm = post_norm
244
+ self.norm_eps = norm_eps
245
+
246
+ # embeddings
247
+ gain = 1.0 / math.sqrt(dim)
248
+ self.patch_embedding = nn.Conv2d(
249
+ 3,
250
+ dim,
251
+ kernel_size=patch_size,
252
+ stride=patch_size,
253
+ bias=not pre_norm)
254
+ if pool_type in ('token', 'token_fc'):
255
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
256
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
257
+ 1, self.num_patches +
258
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
259
+ self.dropout = nn.Dropout(embedding_dropout)
260
+
261
+ # transformer
262
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
263
+ self.transformer = nn.Sequential(*[
264
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
265
+ activation, attn_dropout, proj_dropout, norm_eps)
266
+ for _ in range(num_layers)
267
+ ])
268
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
269
+
270
+ # head
271
+ if pool_type == 'token':
272
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
273
+ elif pool_type == 'token_fc':
274
+ self.head = nn.Linear(dim, out_dim)
275
+ elif pool_type == 'attn_pool':
276
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
277
+ proj_dropout, norm_eps)
278
+
279
+ def forward(self, x, interpolation=False, use_31_block=False):
280
+ b = x.size(0)
281
+
282
+ # embeddings
283
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
284
+ if self.pool_type in ('token', 'token_fc'):
285
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
286
+ if interpolation:
287
+ e = pos_interpolate(self.pos_embedding, x.size(1))
288
+ else:
289
+ e = self.pos_embedding
290
+ x = self.dropout(x + e)
291
+ if self.pre_norm is not None:
292
+ x = self.pre_norm(x)
293
+
294
+ # transformer
295
+ if use_31_block:
296
+ x = self.transformer[:-1](x)
297
+ return x
298
+ else:
299
+ x = self.transformer(x)
300
+ return x
301
+
302
+
303
+ class XLMRobertaWithHead(XLMRoberta):
304
+
305
+ def __init__(self, **kwargs):
306
+ self.out_dim = kwargs.pop('out_dim')
307
+ super().__init__(**kwargs)
308
+
309
+ # head
310
+ mid_dim = (self.dim + self.out_dim) // 2
311
+ self.head = nn.Sequential(
312
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
313
+ nn.Linear(mid_dim, self.out_dim, bias=False))
314
+
315
+ def forward(self, ids):
316
+ # xlm-roberta
317
+ x = super().forward(ids)
318
+
319
+ # average pooling
320
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
321
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
322
+
323
+ # head
324
+ x = self.head(x)
325
+ return x
326
+
327
+
328
+ class XLMRobertaCLIP(nn.Module):
329
+
330
+ def __init__(self,
331
+ embed_dim=1024,
332
+ image_size=224,
333
+ patch_size=14,
334
+ vision_dim=1280,
335
+ vision_mlp_ratio=4,
336
+ vision_heads=16,
337
+ vision_layers=32,
338
+ vision_pool='token',
339
+ vision_pre_norm=True,
340
+ vision_post_norm=False,
341
+ activation='gelu',
342
+ vocab_size=250002,
343
+ max_text_len=514,
344
+ type_size=1,
345
+ pad_id=1,
346
+ text_dim=1024,
347
+ text_heads=16,
348
+ text_layers=24,
349
+ text_post_norm=True,
350
+ text_dropout=0.1,
351
+ attn_dropout=0.0,
352
+ proj_dropout=0.0,
353
+ embedding_dropout=0.0,
354
+ norm_eps=1e-5):
355
+ super().__init__()
356
+ self.embed_dim = embed_dim
357
+ self.image_size = image_size
358
+ self.patch_size = patch_size
359
+ self.vision_dim = vision_dim
360
+ self.vision_mlp_ratio = vision_mlp_ratio
361
+ self.vision_heads = vision_heads
362
+ self.vision_layers = vision_layers
363
+ self.vision_pre_norm = vision_pre_norm
364
+ self.vision_post_norm = vision_post_norm
365
+ self.activation = activation
366
+ self.vocab_size = vocab_size
367
+ self.max_text_len = max_text_len
368
+ self.type_size = type_size
369
+ self.pad_id = pad_id
370
+ self.text_dim = text_dim
371
+ self.text_heads = text_heads
372
+ self.text_layers = text_layers
373
+ self.text_post_norm = text_post_norm
374
+ self.norm_eps = norm_eps
375
+
376
+ # models
377
+ self.visual = VisionTransformer(
378
+ image_size=image_size,
379
+ patch_size=patch_size,
380
+ dim=vision_dim,
381
+ mlp_ratio=vision_mlp_ratio,
382
+ out_dim=embed_dim,
383
+ num_heads=vision_heads,
384
+ num_layers=vision_layers,
385
+ pool_type=vision_pool,
386
+ pre_norm=vision_pre_norm,
387
+ post_norm=vision_post_norm,
388
+ activation=activation,
389
+ attn_dropout=attn_dropout,
390
+ proj_dropout=proj_dropout,
391
+ embedding_dropout=embedding_dropout,
392
+ norm_eps=norm_eps)
393
+ self.textual = XLMRobertaWithHead(
394
+ vocab_size=vocab_size,
395
+ max_seq_len=max_text_len,
396
+ type_size=type_size,
397
+ pad_id=pad_id,
398
+ dim=text_dim,
399
+ out_dim=embed_dim,
400
+ num_heads=text_heads,
401
+ num_layers=text_layers,
402
+ post_norm=text_post_norm,
403
+ dropout=text_dropout)
404
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
405
+
406
+ def forward(self, imgs, txt_ids):
407
+ """
408
+ imgs: [B, 3, H, W] of torch.float32.
409
+ - mean: [0.48145466, 0.4578275, 0.40821073]
410
+ - std: [0.26862954, 0.26130258, 0.27577711]
411
+ txt_ids: [B, L] of torch.long.
412
+ Encoded by data.CLIPTokenizer.
413
+ """
414
+ xi = self.visual(imgs)
415
+ xt = self.textual(txt_ids)
416
+ return xi, xt
417
+
418
+ def param_groups(self):
419
+ groups = [{
420
+ 'params': [
421
+ p for n, p in self.named_parameters()
422
+ if 'norm' in n or n.endswith('bias')
423
+ ],
424
+ 'weight_decay': 0.0
425
+ }, {
426
+ 'params': [
427
+ p for n, p in self.named_parameters()
428
+ if not ('norm' in n or n.endswith('bias'))
429
+ ]
430
+ }]
431
+ return groups
432
+
433
+
434
+ def _clip(pretrained=False,
435
+ pretrained_name=None,
436
+ model_cls=XLMRobertaCLIP,
437
+ return_transforms=False,
438
+ return_tokenizer=False,
439
+ tokenizer_padding='eos',
440
+ dtype=torch.float32,
441
+ device='cpu',
442
+ **kwargs):
443
+ # init a model on device
444
+ with torch.device(device):
445
+ model = model_cls(**kwargs)
446
+
447
+ # set device
448
+ model = model.to(dtype=dtype, device=device)
449
+ output = (model,)
450
+
451
+ # init transforms
452
+ if return_transforms:
453
+ # mean and std
454
+ if 'siglip' in pretrained_name.lower():
455
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
456
+ else:
457
+ mean = [0.48145466, 0.4578275, 0.40821073]
458
+ std = [0.26862954, 0.26130258, 0.27577711]
459
+
460
+ # transforms
461
+ transforms = T.Compose([
462
+ T.Resize((model.image_size, model.image_size),
463
+ interpolation=T.InterpolationMode.BICUBIC),
464
+ T.ToTensor(),
465
+ T.Normalize(mean=mean, std=std)
466
+ ])
467
+ output += (transforms,)
468
+ return output[0] if len(output) == 1 else output
469
+
470
+
471
+ def clip_xlm_roberta_vit_h_14(
472
+ pretrained=False,
473
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
474
+ **kwargs):
475
+ cfg = dict(
476
+ embed_dim=1024,
477
+ image_size=224,
478
+ patch_size=14,
479
+ vision_dim=1280,
480
+ vision_mlp_ratio=4,
481
+ vision_heads=16,
482
+ vision_layers=32,
483
+ vision_pool='token',
484
+ activation='gelu',
485
+ vocab_size=250002,
486
+ max_text_len=514,
487
+ type_size=1,
488
+ pad_id=1,
489
+ text_dim=1024,
490
+ text_heads=16,
491
+ text_layers=24,
492
+ text_post_norm=True,
493
+ text_dropout=0.1,
494
+ attn_dropout=0.0,
495
+ proj_dropout=0.0,
496
+ embedding_dropout=0.0)
497
+ cfg.update(**kwargs)
498
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
499
+
500
+
501
+ class CLIPModel:
502
+
503
+ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
504
+ self.dtype = dtype
505
+ self.device = device
506
+ self.checkpoint_path = checkpoint_path
507
+ self.tokenizer_path = tokenizer_path
508
+
509
+ # init model
510
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
511
+ pretrained=False,
512
+ return_transforms=True,
513
+ return_tokenizer=False,
514
+ dtype=dtype,
515
+ device=device)
516
+ self.model = self.model.eval().requires_grad_(False)
517
+ logging.info(f'loading {checkpoint_path}')
518
+ self.model.load_state_dict(
519
+ torch.load(checkpoint_path, map_location='cpu'))
520
+
521
+ # init tokenizer
522
+ self.tokenizer = HuggingfaceTokenizer(
523
+ name=tokenizer_path,
524
+ seq_len=self.model.max_text_len - 2,
525
+ clean='whitespace')
526
+
527
+ def visual(self, videos):
528
+ # preprocess
529
+ size = (self.model.image_size,) * 2
530
+ videos = torch.cat([
531
+ F.interpolate(
532
+ u.transpose(0, 1),
533
+ size=size,
534
+ mode='bicubic',
535
+ align_corners=False) for u in videos
536
+ ])
537
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
538
+
539
+ # forward
540
+ with torch.amp.autocast("cuda", dtype=self.dtype):
541
+ out = self.model.visual(videos, use_31_block=True)
542
+ return out
wan/modules/model.py ADDED
@@ -0,0 +1,633 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.amp as amp
6
+ import torch.nn as nn
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+
10
+ from .attention import flash_attention
11
+
12
+ __all__ = ['WanModel']
13
+
14
+
15
+ def sinusoidal_embedding_1d(dim, position):
16
+ # preprocess
17
+ assert dim % 2 == 0
18
+ half = dim // 2
19
+ position = position.type(torch.float64)
20
+
21
+ # calculation
22
+ sinusoid = torch.outer(
23
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
24
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
25
+ return x
26
+
27
+
28
+ @amp.autocast("cuda", enabled=False)
29
+ def rope_params(max_seq_len, dim, theta=10000):
30
+ assert dim % 2 == 0
31
+ freqs = torch.outer(
32
+ torch.arange(max_seq_len),
33
+ 1.0 / torch.pow(theta,
34
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
35
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
36
+ return freqs
37
+
38
+
39
+ @amp.autocast("cuda", enabled=False)
40
+ def rope_apply(x, grid_sizes, freqs):
41
+ n, c = x.size(2), x.size(3) // 2
42
+
43
+ # split freqs
44
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
45
+
46
+ # loop over samples
47
+ output = []
48
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
49
+ seq_len = f * h * w
50
+
51
+ # precompute multipliers
52
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
53
+ seq_len, n, -1, 2))
54
+ freqs_i = torch.cat([
55
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
56
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
57
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
58
+ ],
59
+ dim=-1).reshape(seq_len, 1, -1)
60
+
61
+ # apply rotary embedding
62
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
63
+ x_i = torch.cat([x_i, x[i, seq_len:]])
64
+
65
+ # append to collection
66
+ output.append(x_i)
67
+ return torch.stack(output).float()
68
+
69
+
70
+ class WanRMSNorm(nn.Module):
71
+
72
+ def __init__(self, dim, eps=1e-5):
73
+ super().__init__()
74
+ self.dim = dim
75
+ self.eps = eps
76
+ self.weight = nn.Parameter(torch.ones(dim))
77
+
78
+ def forward(self, x):
79
+ r"""
80
+ Args:
81
+ x(Tensor): Shape [B, L, C]
82
+ """
83
+ return self._norm(x.float()).type_as(x) * self.weight
84
+
85
+ def _norm(self, x):
86
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
87
+
88
+
89
+ class WanLayerNorm(nn.LayerNorm):
90
+
91
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
92
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
93
+
94
+ def forward(self, x):
95
+ r"""
96
+ Args:
97
+ x(Tensor): Shape [B, L, C]
98
+ """
99
+ return super().forward(x.float()).type_as(x)
100
+
101
+
102
+ class WanSelfAttention(nn.Module):
103
+
104
+ def __init__(self,
105
+ dim,
106
+ num_heads,
107
+ window_size=(-1, -1),
108
+ qk_norm=True,
109
+ eps=1e-6):
110
+ assert dim % num_heads == 0
111
+ super().__init__()
112
+ self.dim = dim
113
+ self.num_heads = num_heads
114
+ self.head_dim = dim // num_heads
115
+ self.window_size = window_size
116
+ self.qk_norm = qk_norm
117
+ self.eps = eps
118
+
119
+ # layers
120
+ self.q = nn.Linear(dim, dim)
121
+ self.k = nn.Linear(dim, dim)
122
+ self.v = nn.Linear(dim, dim)
123
+ self.o = nn.Linear(dim, dim)
124
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
+
127
+ def forward(self, x, seq_lens, grid_sizes, freqs):
128
+ r"""
129
+ Args:
130
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
131
+ seq_lens(Tensor): Shape [B]
132
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
133
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
134
+ """
135
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
136
+
137
+ # query, key, value function
138
+ def qkv_fn(x):
139
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
140
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
141
+ v = self.v(x).view(b, s, n, d)
142
+ return q, k, v
143
+
144
+ q, k, v = qkv_fn(x)
145
+
146
+ x = flash_attention(
147
+ q=rope_apply(q, grid_sizes, freqs),
148
+ k=rope_apply(k, grid_sizes, freqs),
149
+ v=v,
150
+ k_lens=seq_lens,
151
+ window_size=self.window_size)
152
+
153
+ # output
154
+ x = x.flatten(2)
155
+ x = self.o(x)
156
+ return x
157
+
158
+
159
+ class WanT2VCrossAttention(WanSelfAttention):
160
+
161
+ def forward(self, x, context, context_lens):
162
+ r"""
163
+ Args:
164
+ x(Tensor): Shape [B, L1, C]
165
+ context(Tensor): Shape [B, L2, C]
166
+ context_lens(Tensor): Shape [B]
167
+ """
168
+ b, n, d = x.size(0), self.num_heads, self.head_dim
169
+
170
+ # compute query, key, value
171
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
172
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
173
+ v = self.v(context).view(b, -1, n, d)
174
+
175
+ # compute attention
176
+ x = flash_attention(q, k, v, k_lens=context_lens)
177
+
178
+ # output
179
+ x = x.flatten(2)
180
+ x = self.o(x)
181
+ return x
182
+
183
+
184
+ class WanI2VCrossAttention(WanSelfAttention):
185
+
186
+ def __init__(self,
187
+ dim,
188
+ num_heads,
189
+ window_size=(-1, -1),
190
+ qk_norm=True,
191
+ eps=1e-6):
192
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
193
+
194
+ self.k_img = nn.Linear(dim, dim)
195
+ self.v_img = nn.Linear(dim, dim)
196
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
197
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
198
+
199
+ def forward(self, x, context, context_lens):
200
+ r"""
201
+ Args:
202
+ x(Tensor): Shape [B, L1, C]
203
+ context(Tensor): Shape [B, L2, C]
204
+ context_lens(Tensor): Shape [B]
205
+ """
206
+ context_img = context[:, :257]
207
+ context = context[:, 257:]
208
+ b, n, d = x.size(0), self.num_heads, self.head_dim
209
+
210
+ # compute query, key, value
211
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
212
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
213
+ v = self.v(context).view(b, -1, n, d)
214
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
215
+ v_img = self.v_img(context_img).view(b, -1, n, d)
216
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
217
+ # compute attention
218
+ x = flash_attention(q, k, v, k_lens=context_lens)
219
+
220
+ # output
221
+ x = x.flatten(2)
222
+ img_x = img_x.flatten(2)
223
+ x = x + img_x
224
+ x = self.o(x)
225
+ return x
226
+
227
+
228
+ WAN_CROSSATTENTION_CLASSES = {
229
+ 't2v_cross_attn': WanT2VCrossAttention,
230
+ 'i2v_cross_attn': WanI2VCrossAttention,
231
+ }
232
+
233
+
234
+ class WanAttentionBlock(nn.Module):
235
+
236
+ def __init__(self,
237
+ cross_attn_type,
238
+ dim,
239
+ ffn_dim,
240
+ num_heads,
241
+ window_size=(-1, -1),
242
+ qk_norm=True,
243
+ cross_attn_norm=False,
244
+ eps=1e-6):
245
+ super().__init__()
246
+ self.dim = dim
247
+ self.ffn_dim = ffn_dim
248
+ self.num_heads = num_heads
249
+ self.window_size = window_size
250
+ self.qk_norm = qk_norm
251
+ self.cross_attn_norm = cross_attn_norm
252
+ self.eps = eps
253
+
254
+ # layers
255
+ self.norm1 = WanLayerNorm(dim, eps)
256
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
257
+ eps)
258
+ self.norm3 = WanLayerNorm(
259
+ dim, eps,
260
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
261
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
262
+ num_heads,
263
+ (-1, -1),
264
+ qk_norm,
265
+ eps)
266
+ self.norm2 = WanLayerNorm(dim, eps)
267
+ self.ffn = nn.Sequential(
268
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
269
+ nn.Linear(ffn_dim, dim))
270
+
271
+ # modulation
272
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
273
+
274
+ def forward(
275
+ self,
276
+ x,
277
+ e,
278
+ seq_lens,
279
+ grid_sizes,
280
+ freqs,
281
+ context,
282
+ context_lens,
283
+ ):
284
+ r"""
285
+ Args:
286
+ x(Tensor): Shape [B, L, C]
287
+ e(Tensor): Shape [B, 6, C]
288
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
289
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
290
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
291
+ """
292
+ assert e.dtype == torch.float32
293
+ with amp.autocast("cuda", dtype=torch.float32):
294
+ e = (self.modulation + e).chunk(6, dim=1)
295
+ assert e[0].dtype == torch.float32
296
+
297
+ # self-attention
298
+ y = self.self_attn(
299
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
300
+ freqs)
301
+ with amp.autocast("cuda", dtype=torch.float32):
302
+ x = x + y * e[2]
303
+
304
+ # cross-attention & ffn function
305
+ def cross_attn_ffn(x, context, context_lens, e):
306
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
307
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
308
+ with amp.autocast("cuda", dtype=torch.float32):
309
+ x = x + y * e[5]
310
+ return x
311
+
312
+ x = cross_attn_ffn(x, context, context_lens, e)
313
+ return x
314
+
315
+
316
+ class Head(nn.Module):
317
+
318
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
319
+ super().__init__()
320
+ self.dim = dim
321
+ self.out_dim = out_dim
322
+ self.patch_size = patch_size
323
+ self.eps = eps
324
+
325
+ # layers
326
+ out_dim = math.prod(patch_size) * out_dim
327
+ self.norm = WanLayerNorm(dim, eps)
328
+ self.head = nn.Linear(dim, out_dim)
329
+
330
+ # modulation
331
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
332
+
333
+ def forward(self, x, e):
334
+ r"""
335
+ Args:
336
+ x(Tensor): Shape [B, L1, C]
337
+ e(Tensor): Shape [B, C]
338
+ """
339
+ assert e.dtype == torch.float32
340
+ with amp.autocast("cuda", dtype=torch.float32):
341
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
342
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
343
+ return x
344
+
345
+
346
+ class MLPProj(torch.nn.Module):
347
+
348
+ def __init__(self, in_dim, out_dim):
349
+ super().__init__()
350
+
351
+ self.proj = torch.nn.Sequential(
352
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
353
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
354
+ torch.nn.LayerNorm(out_dim))
355
+
356
+ def forward(self, image_embeds):
357
+ clip_extra_context_tokens = self.proj(image_embeds)
358
+ return clip_extra_context_tokens
359
+
360
+
361
+ class WanModel(ModelMixin, ConfigMixin):
362
+ r"""
363
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
364
+ """
365
+
366
+ ignore_for_config = [
367
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
368
+ ]
369
+ _no_split_modules = ['WanAttentionBlock']
370
+
371
+ @register_to_config
372
+ def __init__(self,
373
+ model_type='t2v',
374
+ patch_size=(1, 2, 2),
375
+ text_len=512,
376
+ in_dim=16,
377
+ dim=2048,
378
+ ffn_dim=8192,
379
+ freq_dim=256,
380
+ text_dim=4096,
381
+ out_dim=16,
382
+ num_heads=16,
383
+ num_layers=32,
384
+ window_size=(-1, -1),
385
+ qk_norm=True,
386
+ cross_attn_norm=True,
387
+ eps=1e-6):
388
+ r"""
389
+ Initialize the diffusion model backbone.
390
+
391
+ Args:
392
+ model_type (`str`, *optional*, defaults to 't2v'):
393
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
394
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
395
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
396
+ text_len (`int`, *optional*, defaults to 512):
397
+ Fixed length for text embeddings
398
+ in_dim (`int`, *optional*, defaults to 16):
399
+ Input video channels (C_in)
400
+ dim (`int`, *optional*, defaults to 2048):
401
+ Hidden dimension of the transformer
402
+ ffn_dim (`int`, *optional*, defaults to 8192):
403
+ Intermediate dimension in feed-forward network
404
+ freq_dim (`int`, *optional*, defaults to 256):
405
+ Dimension for sinusoidal time embeddings
406
+ text_dim (`int`, *optional*, defaults to 4096):
407
+ Input dimension for text embeddings
408
+ out_dim (`int`, *optional*, defaults to 16):
409
+ Output video channels (C_out)
410
+ num_heads (`int`, *optional*, defaults to 16):
411
+ Number of attention heads
412
+ num_layers (`int`, *optional*, defaults to 32):
413
+ Number of transformer blocks
414
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
415
+ Window size for local attention (-1 indicates global attention)
416
+ qk_norm (`bool`, *optional*, defaults to True):
417
+ Enable query/key normalization
418
+ cross_attn_norm (`bool`, *optional*, defaults to True):
419
+ Enable cross-attention normalization
420
+ eps (`float`, *optional*, defaults to 1e-6):
421
+ Epsilon value for normalization layers
422
+ """
423
+
424
+ super().__init__()
425
+
426
+ assert model_type in ['t2v', 'i2v']
427
+ self.model_type = model_type
428
+
429
+ self.patch_size = patch_size
430
+ self.text_len = text_len
431
+ self.in_dim = in_dim
432
+ self.dim = dim
433
+ self.ffn_dim = ffn_dim
434
+ self.freq_dim = freq_dim
435
+ self.text_dim = text_dim
436
+ self.out_dim = out_dim
437
+ self.num_heads = num_heads
438
+ self.num_layers = num_layers
439
+ self.window_size = window_size
440
+ self.qk_norm = qk_norm
441
+ self.cross_attn_norm = cross_attn_norm
442
+ self.eps = eps
443
+
444
+ # embeddings
445
+ self.patch_embedding = nn.Conv3d(
446
+ in_dim,
447
+ dim,
448
+ kernel_size=patch_size,
449
+ stride=patch_size,
450
+ )
451
+ self.text_embedding = nn.Sequential(
452
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
453
+ nn.Linear(dim, dim))
454
+
455
+ self.time_embedding = nn.Sequential(
456
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
457
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
458
+
459
+ # blocks
460
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
461
+ self.blocks = nn.ModuleList([
462
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
463
+ window_size, qk_norm, cross_attn_norm, eps)
464
+ for _ in range(num_layers)
465
+ ])
466
+
467
+ # head
468
+ self.head = Head(dim, out_dim, patch_size, eps)
469
+
470
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
471
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
472
+ d = dim // num_heads
473
+ self.freqs = torch.cat([
474
+ rope_params(1024, d - 4 * (d // 6)),
475
+ rope_params(1024, 2 * (d // 6)),
476
+ rope_params(1024, 2 * (d // 6))
477
+ ],
478
+ dim=1)
479
+
480
+ if model_type == 'i2v':
481
+ self.img_emb = MLPProj(1280, dim)
482
+
483
+ # initialize weights
484
+ self.init_weights()
485
+
486
+ def forward(
487
+ self,
488
+ x,
489
+ t,
490
+ context,
491
+ seq_len,
492
+ clip_fea=None,
493
+ y=None,
494
+ ):
495
+ r"""
496
+ Forward pass through the diffusion model
497
+
498
+ Args:
499
+ x (List[Tensor]):
500
+ List of input video tensors, each with shape [C_in, F, H, W]
501
+ t (Tensor):
502
+ Diffusion timesteps tensor of shape [B]
503
+ context (List[Tensor]):
504
+ List of text embeddings each with shape [L, C]
505
+ seq_len (`int`):
506
+ Maximum sequence length for positional encoding
507
+ clip_fea (Tensor, *optional*):
508
+ CLIP image features for image-to-video mode
509
+ y (List[Tensor], *optional*):
510
+ Conditional video inputs for image-to-video mode, same shape as x
511
+
512
+ Returns:
513
+ List[Tensor]:
514
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
515
+ """
516
+ if self.model_type == 'i2v':
517
+ assert clip_fea is not None and y is not None
518
+ # params
519
+ device = self.patch_embedding.weight.device
520
+ if self.freqs.device != device:
521
+ self.freqs = self.freqs.to(device)
522
+
523
+ if y is not None:
524
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
525
+
526
+ # embeddings
527
+ original_shapes = [u.shape[1:] for u in x] # Store F, H, W
528
+
529
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
530
+ grid_sizes = torch.stack(
531
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
532
+ x = [u.flatten(2).transpose(1, 2) for u in x]
533
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
534
+ assert seq_lens.max() <= seq_len
535
+ x = torch.cat([
536
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
537
+ dim=1) for u in x
538
+ ])
539
+
540
+ # time embeddings
541
+ with amp.autocast("cuda", dtype=torch.float32):
542
+ e = self.time_embedding(
543
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
544
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
545
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
546
+
547
+ # context
548
+ context_lens = None
549
+ context = self.text_embedding(
550
+ torch.stack([
551
+ torch.cat(
552
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
553
+ for u in context
554
+ ]))
555
+
556
+ if clip_fea is not None:
557
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
558
+ context = torch.concat([context_clip, context], dim=1)
559
+
560
+ # arguments
561
+ kwargs = dict(
562
+ e=e0,
563
+ seq_lens=seq_lens,
564
+ grid_sizes=grid_sizes,
565
+ freqs=self.freqs,
566
+ context=context,
567
+ context_lens=context_lens)
568
+
569
+ for block in self.blocks:
570
+ x = block(x, **kwargs)
571
+
572
+ # head
573
+ x = self.head(x, e)
574
+
575
+ # unpatchify
576
+ # x = self.unpatchify(x, grid_sizes, original_shapes=original_shapes)
577
+ x = self.unpatchify(x, grid_sizes)
578
+
579
+ return [u.float() for u in x]
580
+
581
+ def unpatchify(self, x, grid_sizes, original_shapes=None):
582
+ r"""
583
+ Reconstruct video tensors from patch embeddings.
584
+
585
+ Args:
586
+ x (List[Tensor]):
587
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
588
+ grid_sizes (Tensor):
589
+ Original spatial-temporal grid dimensions before patching,
590
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
591
+
592
+ Returns:
593
+ List[Tensor]:
594
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
595
+ """
596
+
597
+ c = self.out_dim
598
+ out = []
599
+ for idx, (u, v) in enumerate(zip(x, grid_sizes.tolist())):
600
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
601
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
602
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
603
+
604
+ if original_shapes is not None:
605
+ original_H = original_shapes[idx][1]
606
+ u = u[:, :, :original_H, :]
607
+ out.append(u)
608
+
609
+ return out
610
+
611
+ def init_weights(self):
612
+ r"""
613
+ Initialize model parameters using Xavier initialization.
614
+ """
615
+
616
+ # basic init
617
+ for m in self.modules():
618
+ if isinstance(m, nn.Linear):
619
+ nn.init.xavier_uniform_(m.weight)
620
+ if m.bias is not None:
621
+ nn.init.zeros_(m.bias)
622
+
623
+ # init embeddings
624
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
625
+ for m in self.text_embedding.modules():
626
+ if isinstance(m, nn.Linear):
627
+ nn.init.normal_(m.weight, std=.02)
628
+ for m in self.time_embedding.modules():
629
+ if isinstance(m, nn.Linear):
630
+ nn.init.normal_(m.weight, std=.02)
631
+
632
+ # init output layer
633
+ nn.init.zeros_(self.head.head.weight)
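
A minimal, self-contained sketch of the adaLN-style modulation that WanAttentionBlock applies around its attention and FFN sub-layers, so the six chunks of `e` (shift/scale/gate for self-attention, then shift/scale/gate for the FFN) are easier to follow. The norm, attention, and FFN below are toy stand-ins (plain LayerNorm, identity, small MLP), not the real WanLayerNorm / WanSelfAttention layers, and the dimensions are illustrative only.

    import torch
    import torch.nn as nn

    dim, B, L = 64, 2, 16
    modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)   # learned table, as in the block
    e = torch.randn(B, 6, dim)                                     # per-sample time conditioning
    shift_sa, scale_sa, gate_sa, shift_ff, scale_ff, gate_ff = (modulation + e).chunk(6, dim=1)

    x = torch.randn(B, L, dim)
    norm = nn.LayerNorm(dim, elementwise_affine=False)
    attn = lambda t: t                                             # stand-in for self-attention
    ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    # the pre-norm activations are shifted/scaled by the conditioning; the residual branch is gated
    x = x + attn(norm(x) * (1 + scale_sa) + shift_sa) * gate_sa
    x = x + ffn(norm(x) * (1 + scale_ff) + shift_ff) * gate_ff
    print(x.shape)   # torch.Size([2, 16, 64])
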
wan/modules/t5.py ADDED
@@ -0,0 +1,518 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+
48
+ def forward(self, x):
49
+ return 0.5 * x * (1.0 + torch.tanh(
50
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51
+
52
+
53
+ class T5LayerNorm(nn.Module):
54
+
55
+ def __init__(self, dim, eps=1e-6):
56
+ super(T5LayerNorm, self).__init__()
57
+ self.dim = dim
58
+ self.eps = eps
59
+ self.weight = nn.Parameter(torch.ones(dim))
60
+
61
+ def forward(self, x):
62
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63
+ self.eps)
64
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
65
+ x = x.type_as(self.weight)
66
+ return self.weight * x
67
+
68
+
69
+ class T5Attention(nn.Module):
70
+
71
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72
+ assert dim_attn % num_heads == 0
73
+ super(T5Attention, self).__init__()
74
+ self.dim = dim
75
+ self.dim_attn = dim_attn
76
+ self.num_heads = num_heads
77
+ self.head_dim = dim_attn // num_heads
78
+
79
+ # layers
80
+ self.q = nn.Linear(dim, dim_attn, bias=False)
81
+ self.k = nn.Linear(dim, dim_attn, bias=False)
82
+ self.v = nn.Linear(dim, dim_attn, bias=False)
83
+ self.o = nn.Linear(dim_attn, dim, bias=False)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ def forward(self, x, context=None, mask=None, pos_bias=None):
87
+ """
88
+ x: [B, L1, C].
89
+ context: [B, L2, C] or None.
90
+ mask: [B, L2] or [B, L1, L2] or None.
91
+ """
92
+ # check inputs
93
+ context = x if context is None else context
94
+ b, n, c = x.size(0), self.num_heads, self.head_dim
95
+
96
+ # compute query, key, value
97
+ q = self.q(x).view(b, -1, n, c)
98
+ k = self.k(context).view(b, -1, n, c)
99
+ v = self.v(context).view(b, -1, n, c)
100
+
101
+ # attention bias
102
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103
+ if pos_bias is not None:
104
+ attn_bias += pos_bias
105
+ if mask is not None:
106
+ assert mask.ndim in [2, 3]
107
+ mask = mask.view(b, 1, 1,
108
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
109
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110
+
111
+ # compute attention (T5 does not use scaling)
112
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
115
+
116
+ # output
117
+ x = x.reshape(b, -1, n * c)
118
+ x = self.o(x)
119
+ x = self.dropout(x)
120
+ return x
121
+
122
+
123
+ class T5FeedForward(nn.Module):
124
+
125
+ def __init__(self, dim, dim_ffn, dropout=0.1):
126
+ super(T5FeedForward, self).__init__()
127
+ self.dim = dim
128
+ self.dim_ffn = dim_ffn
129
+
130
+ # layers
131
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134
+ self.dropout = nn.Dropout(dropout)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x) * self.gate(x)
138
+ x = self.dropout(x)
139
+ x = self.fc2(x)
140
+ x = self.dropout(x)
141
+ return x
142
+
143
+
144
+ class T5SelfAttention(nn.Module):
145
+
146
+ def __init__(self,
147
+ dim,
148
+ dim_attn,
149
+ dim_ffn,
150
+ num_heads,
151
+ num_buckets,
152
+ shared_pos=True,
153
+ dropout=0.1):
154
+ super(T5SelfAttention, self).__init__()
155
+ self.dim = dim
156
+ self.dim_attn = dim_attn
157
+ self.dim_ffn = dim_ffn
158
+ self.num_heads = num_heads
159
+ self.num_buckets = num_buckets
160
+ self.shared_pos = shared_pos
161
+
162
+ # layers
163
+ self.norm1 = T5LayerNorm(dim)
164
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165
+ self.norm2 = T5LayerNorm(dim)
166
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168
+ num_buckets, num_heads, bidirectional=True)
169
+
170
+ def forward(self, x, mask=None, pos_bias=None):
171
+ e = pos_bias if self.shared_pos else self.pos_embedding(
172
+ x.size(1), x.size(1))
173
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
175
+ return x
176
+
177
+
178
+ class T5CrossAttention(nn.Module):
179
+
180
+ def __init__(self,
181
+ dim,
182
+ dim_attn,
183
+ dim_ffn,
184
+ num_heads,
185
+ num_buckets,
186
+ shared_pos=True,
187
+ dropout=0.1):
188
+ super(T5CrossAttention, self).__init__()
189
+ self.dim = dim
190
+ self.dim_attn = dim_attn
191
+ self.dim_ffn = dim_ffn
192
+ self.num_heads = num_heads
193
+ self.num_buckets = num_buckets
194
+ self.shared_pos = shared_pos
195
+
196
+ # layers
197
+ self.norm1 = T5LayerNorm(dim)
198
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199
+ self.norm2 = T5LayerNorm(dim)
200
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201
+ self.norm3 = T5LayerNorm(dim)
202
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204
+ num_buckets, num_heads, bidirectional=False)
205
+
206
+ def forward(self,
207
+ x,
208
+ mask=None,
209
+ encoder_states=None,
210
+ encoder_mask=None,
211
+ pos_bias=None):
212
+ e = pos_bias if self.shared_pos else self.pos_embedding(
213
+ x.size(1), x.size(1))
214
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215
+ x = fp16_clamp(x + self.cross_attn(
216
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
217
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
218
+ return x
219
+
220
+
221
+ class T5RelativeEmbedding(nn.Module):
222
+
223
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224
+ super(T5RelativeEmbedding, self).__init__()
225
+ self.num_buckets = num_buckets
226
+ self.num_heads = num_heads
227
+ self.bidirectional = bidirectional
228
+ self.max_dist = max_dist
229
+
230
+ # layers
231
+ self.embedding = nn.Embedding(num_buckets, num_heads)
232
+
233
+ def forward(self, lq, lk):
234
+ device = self.embedding.weight.device
235
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236
+ # torch.arange(lq).unsqueeze(1).to(device)
237
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238
+ torch.arange(lq, device=device).unsqueeze(1)
239
+ rel_pos = self._relative_position_bucket(rel_pos)
240
+ rel_pos_embeds = self.embedding(rel_pos)
241
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242
+ 0) # [1, N, Lq, Lk]
243
+ return rel_pos_embeds.contiguous()
244
+
245
+ def _relative_position_bucket(self, rel_pos):
246
+ # preprocess
247
+ if self.bidirectional:
248
+ num_buckets = self.num_buckets // 2
249
+ rel_buckets = (rel_pos > 0).long() * num_buckets
250
+ rel_pos = torch.abs(rel_pos)
251
+ else:
252
+ num_buckets = self.num_buckets
253
+ rel_buckets = 0
254
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255
+
256
+ # embeddings for small and large positions
257
+ max_exact = num_buckets // 2
258
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259
+ math.log(self.max_dist / max_exact) *
260
+ (num_buckets - max_exact)).long()
261
+ rel_pos_large = torch.min(
262
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264
+ return rel_buckets
265
+
266
+
267
+ class T5Encoder(nn.Module):
268
+
269
+ def __init__(self,
270
+ vocab,
271
+ dim,
272
+ dim_attn,
273
+ dim_ffn,
274
+ num_heads,
275
+ num_layers,
276
+ num_buckets,
277
+ shared_pos=True,
278
+ dropout=0.1):
279
+ super(T5Encoder, self).__init__()
280
+ self.dim = dim
281
+ self.dim_attn = dim_attn
282
+ self.dim_ffn = dim_ffn
283
+ self.num_heads = num_heads
284
+ self.num_layers = num_layers
285
+ self.num_buckets = num_buckets
286
+ self.shared_pos = shared_pos
287
+
288
+ # layers
289
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290
+ else nn.Embedding(vocab, dim)
291
+ self.pos_embedding = T5RelativeEmbedding(
292
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
293
+ self.dropout = nn.Dropout(dropout)
294
+ self.blocks = nn.ModuleList([
295
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296
+ shared_pos, dropout) for _ in range(num_layers)
297
+ ])
298
+ self.norm = T5LayerNorm(dim)
299
+
300
+ # initialize weights
301
+ self.apply(init_weights)
302
+
303
+ def forward(self, ids, mask=None):
304
+ x = self.token_embedding(ids)
305
+ x = self.dropout(x)
306
+ e = self.pos_embedding(x.size(1),
307
+ x.size(1)) if self.shared_pos else None
308
+ for block in self.blocks:
309
+ x = block(x, mask, pos_bias=e)
310
+ x = self.norm(x)
311
+ x = self.dropout(x)
312
+ return x
313
+
314
+
315
+ class T5Decoder(nn.Module):
316
+
317
+ def __init__(self,
318
+ vocab,
319
+ dim,
320
+ dim_attn,
321
+ dim_ffn,
322
+ num_heads,
323
+ num_layers,
324
+ num_buckets,
325
+ shared_pos=True,
326
+ dropout=0.1):
327
+ super(T5Decoder, self).__init__()
328
+ self.dim = dim
329
+ self.dim_attn = dim_attn
330
+ self.dim_ffn = dim_ffn
331
+ self.num_heads = num_heads
332
+ self.num_layers = num_layers
333
+ self.num_buckets = num_buckets
334
+ self.shared_pos = shared_pos
335
+
336
+ # layers
337
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338
+ else nn.Embedding(vocab, dim)
339
+ self.pos_embedding = T5RelativeEmbedding(
340
+ num_buckets, num_heads, bidirectional=False) if shared_pos else None
341
+ self.dropout = nn.Dropout(dropout)
342
+ self.blocks = nn.ModuleList([
343
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344
+ shared_pos, dropout) for _ in range(num_layers)
345
+ ])
346
+ self.norm = T5LayerNorm(dim)
347
+
348
+ # initialize weights
349
+ self.apply(init_weights)
350
+
351
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352
+ b, s = ids.size()
353
+
354
+ # causal mask
355
+ if mask is None:
356
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357
+ elif mask.ndim == 2:
358
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359
+
360
+ # layers
361
+ x = self.token_embedding(ids)
362
+ x = self.dropout(x)
363
+ e = self.pos_embedding(x.size(1),
364
+ x.size(1)) if self.shared_pos else None
365
+ for block in self.blocks:
366
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367
+ x = self.norm(x)
368
+ x = self.dropout(x)
369
+ return x
370
+
371
+
372
+ class T5Model(nn.Module):
373
+
374
+ def __init__(self,
375
+ vocab_size,
376
+ dim,
377
+ dim_attn,
378
+ dim_ffn,
379
+ num_heads,
380
+ encoder_layers,
381
+ decoder_layers,
382
+ num_buckets,
383
+ shared_pos=True,
384
+ dropout=0.1):
385
+ super(T5Model, self).__init__()
386
+ self.vocab_size = vocab_size
387
+ self.dim = dim
388
+ self.dim_attn = dim_attn
389
+ self.dim_ffn = dim_ffn
390
+ self.num_heads = num_heads
391
+ self.encoder_layers = encoder_layers
392
+ self.decoder_layers = decoder_layers
393
+ self.num_buckets = num_buckets
394
+
395
+ # layers
396
+ self.token_embedding = nn.Embedding(vocab_size, dim)
397
+ self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398
+ num_heads, encoder_layers, num_buckets,
399
+ shared_pos, dropout)
400
+ self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401
+ num_heads, decoder_layers, num_buckets,
402
+ shared_pos, dropout)
403
+ self.head = nn.Linear(dim, vocab_size, bias=False)
404
+
405
+ # initialize weights
406
+ self.apply(init_weights)
407
+
408
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409
+ x = self.encoder(encoder_ids, encoder_mask)
410
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411
+ x = self.head(x)
412
+ return x
413
+
414
+
415
+ def _t5(name,
416
+ encoder_only=False,
417
+ decoder_only=False,
418
+ return_tokenizer=False,
419
+ tokenizer_kwargs={},
420
+ dtype=torch.float32,
421
+ device='cpu',
422
+ **kwargs):
423
+ # sanity check
424
+ assert not (encoder_only and decoder_only)
425
+
426
+ # params
427
+ if encoder_only:
428
+ model_cls = T5Encoder
429
+ kwargs['vocab'] = kwargs.pop('vocab_size')
430
+ kwargs['num_layers'] = kwargs.pop('encoder_layers')
431
+ _ = kwargs.pop('decoder_layers')
432
+ elif decoder_only:
433
+ model_cls = T5Decoder
434
+ kwargs['vocab'] = kwargs.pop('vocab_size')
435
+ kwargs['num_layers'] = kwargs.pop('decoder_layers')
436
+ _ = kwargs.pop('encoder_layers')
437
+ else:
438
+ model_cls = T5Model
439
+
440
+ # init model
441
+ with torch.device(device):
442
+ model = model_cls(**kwargs)
443
+
444
+ # set device
445
+ model = model.to(dtype=dtype, device=device)
446
+
447
+ # init tokenizer
448
+ if return_tokenizer:
449
+ from .tokenizers import HuggingfaceTokenizer
450
+ tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451
+ return model, tokenizer
452
+ else:
453
+ return model
454
+
455
+
456
+ def umt5_xxl(**kwargs):
457
+ cfg = dict(
458
+ vocab_size=256384,
459
+ dim=4096,
460
+ dim_attn=4096,
461
+ dim_ffn=10240,
462
+ num_heads=64,
463
+ encoder_layers=24,
464
+ decoder_layers=24,
465
+ num_buckets=32,
466
+ shared_pos=False,
467
+ dropout=0.1)
468
+ cfg.update(**kwargs)
469
+ return _t5('umt5-xxl', **cfg)
470
+
471
+
472
+ class T5EncoderModel:
473
+
474
+ def __init__(
475
+ self,
476
+ text_len,
477
+ dtype=torch.bfloat16,
478
+ device=torch.cuda.current_device(),
479
+ checkpoint_path=None,
480
+ tokenizer_path=None,
481
+ shard_fn=None,
482
+ ):
483
+ self.text_len = text_len
484
+ self.dtype = dtype
485
+ self.device = device
486
+ self.checkpoint_path = checkpoint_path
487
+ self.tokenizer_path = tokenizer_path
488
+
489
+ # init model
490
+ model = umt5_xxl(
491
+ encoder_only=True,
492
+ return_tokenizer=False,
493
+ dtype=dtype,
494
+ device=device).eval().requires_grad_(False)
495
+ logging.info(f'loading {checkpoint_path}')
496
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
497
+ self.model = model
498
+ if shard_fn is not None:
499
+ self.model = shard_fn(self.model, sync_module_states=False)
500
+ else:
501
+ self.model.to(self.device)
502
+ # init tokenizer
503
+ self.tokenizer = HuggingfaceTokenizer(
504
+ name=tokenizer_path, seq_len=text_len, clean='whitespace')
505
+
506
+ def __call__(self, texts, device):
507
+ ids, mask = self.tokenizer(
508
+ texts, return_mask=True, add_special_tokens=True)
509
+ try:
510
+ ids = ids.to(device)
511
+ except Exception as e:
512
+ print(texts)
513
+ print(e)
514
+
515
+ mask = mask.to(device)
516
+ seq_lens = mask.gt(0).sum(dim=1).long()
517
+ context = self.model(ids, mask)
518
+ return [u[:v] for u, v in zip(context, seq_lens)]
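
A standalone sketch of the bidirectional relative-position bucketing that T5RelativeEmbedding applies before its embedding lookup. It mirrors `_relative_position_bucket` above rather than replacing it; `num_buckets=32` and `max_dist=128` match the umt5_xxl config and the module default.

    import math
    import torch

    def relative_position_bucket(rel_pos, num_buckets=32, max_dist=128):
        # bidirectional: half the buckets for keys after the query, half for keys before
        num_buckets //= 2
        buckets = (rel_pos > 0).long() * num_buckets
        rel_pos = rel_pos.abs()
        # near offsets get exact buckets, far offsets get log-spaced buckets
        max_exact = num_buckets // 2
        large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                             math.log(max_dist / max_exact) *
                             (num_buckets - max_exact)).long()
        large = torch.minimum(large, torch.full_like(large, num_buckets - 1))
        return buckets + torch.where(rel_pos < max_exact, rel_pos, large)

    lq = lk = 8
    rel = torch.arange(lk).unsqueeze(0) - torch.arange(lq).unsqueeze(1)  # [lq, lk]
    print(relative_position_bucket(rel))
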
wan/modules/tokenizers.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ['HuggingfaceTokenizer']
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r'\s+', ' ', text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace('_', ' ')
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans('', '', string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string))
30
+ else:
31
+ text = text.translate(str.maketrans('', '', string.punctuation))
32
+ text = text.lower()
33
+ text = re.sub(r'\s+', ' ', text)
34
+ return text.strip()
35
+
36
+
37
+ class HuggingfaceTokenizer:
38
+
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop('return_mask', False)
51
+
52
+ # arguments
53
+ _kwargs = {'return_tensors': 'pt'}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({
56
+ 'padding': 'max_length',
57
+ 'truncation': True,
58
+ 'max_length': self.seq_len
59
+ })
60
+ _kwargs.update(**kwargs)
61
+
62
+ # tokenization
63
+ if isinstance(sequence, str):
64
+ sequence = [sequence]
65
+ if self.clean:
66
+ sequence = [self._clean(u) for u in sequence]
67
+ ids = self.tokenizer(sequence, **_kwargs)
68
+
69
+ # output
70
+ if return_mask:
71
+ return ids.input_ids, ids.attention_mask
72
+ else:
73
+ return ids.input_ids
74
+
75
+ def _clean(self, text):
76
+ if self.clean == 'whitespace':
77
+ text = whitespace_clean(basic_clean(text))
78
+ elif self.clean == 'lower':
79
+ text = whitespace_clean(basic_clean(text)).lower()
80
+ elif self.clean == 'canonicalize':
81
+ text = canonicalize(basic_clean(text))
82
+ return text
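
A usage sketch for HuggingfaceTokenizer, assuming the repo root is on PYTHONPATH and the tokenizer can be fetched from the Hub; the 'google/umt5-xxl' name here is what `_t5('umt5-xxl', ...)` in t5.py would resolve to, while the pipeline itself normally builds the tokenizer from a local `tokenizer_path`.

    from wan.modules.tokenizers import HuggingfaceTokenizer

    tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
    ids, mask = tok(["A  cinematic  shot of a cat surfing a wave "], return_mask=True)
    print(ids.shape, mask.shape)   # both torch.Size([1, 512]) after padding/truncation
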
wan/modules/vae.py ADDED
@@ -0,0 +1,663 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ 'WanVAE',
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
25
+ self.padding[1], 2 * self.padding[0], 0)
26
+ self.padding = (0, 0, 0)
27
+
28
+ def forward(self, x, cache_x=None):
29
+ padding = list(self._padding)
30
+ if cache_x is not None and self._padding[4] > 0:
31
+ cache_x = cache_x.to(x.device)
32
+ x = torch.cat([cache_x, x], dim=2)
33
+ padding[4] -= cache_x.shape[2]
34
+ x = F.pad(x, padding)
35
+
36
+ return super().forward(x)
37
+
38
+
39
+ class RMS_norm(nn.Module):
40
+
41
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
42
+ super().__init__()
43
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
44
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
45
+
46
+ self.channel_first = channel_first
47
+ self.scale = dim**0.5
48
+ self.gamma = nn.Parameter(torch.ones(shape))
49
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
+
51
+ def forward(self, x):
52
+ return F.normalize(
53
+ x, dim=(1 if self.channel_first else
54
+ -1)) * self.scale * self.gamma + self.bias
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+
59
+ def forward(self, x):
60
+ """
61
+ Fix bfloat16 support for nearest neighbor interpolation.
62
+ """
63
+ return super().forward(x.float()).type_as(x)
64
+
65
+
66
+ class Resample(nn.Module):
67
+
68
+ def __init__(self, dim, mode):
69
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70
+ 'downsample3d')
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == 'upsample2d':
77
+ self.resample = nn.Sequential(
78
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
80
+ elif mode == 'upsample3d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ self.time_conv = CausalConv3d(
85
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
+
87
+ elif mode == 'downsample2d':
88
+ self.resample = nn.Sequential(
89
+ nn.ZeroPad2d((0, 1, 0, 1)),
90
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91
+ elif mode == 'downsample3d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ self.time_conv = CausalConv3d(
96
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
+
98
+ else:
99
+ self.resample = nn.Identity()
100
+
101
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
102
+ b, c, t, h, w = x.size()
103
+ if self.mode == 'upsample3d':
104
+ if feat_cache is not None:
105
+ idx = feat_idx[0]
106
+ if feat_cache[idx] is None:
107
+ feat_cache[idx] = 'Rep'
108
+ feat_idx[0] += 1
109
+ else:
110
+
111
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
+ if cache_x.shape[2] < 2 and feat_cache[
113
+ idx] is not None and feat_cache[idx] != 'Rep':
114
+ # cache last frame of last two chunk
115
+ cache_x = torch.cat([
116
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
+ cache_x.device), cache_x
118
+ ],
119
+ dim=2)
120
+ if cache_x.shape[2] < 2 and feat_cache[
121
+ idx] is not None and feat_cache[idx] == 'Rep':
122
+ cache_x = torch.cat([
123
+ torch.zeros_like(cache_x).to(cache_x.device),
124
+ cache_x
125
+ ],
126
+ dim=2)
127
+ if feat_cache[idx] == 'Rep':
128
+ x = self.time_conv(x)
129
+ else:
130
+ x = self.time_conv(x, feat_cache[idx])
131
+ feat_cache[idx] = cache_x
132
+ feat_idx[0] += 1
133
+
134
+ x = x.reshape(b, 2, c, t, h, w)
135
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
136
+ 3)
137
+ x = x.reshape(b, c, t * 2, h, w)
138
+ t = x.shape[2]
139
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
140
+ x = self.resample(x)
141
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
142
+
143
+ if self.mode == 'downsample3d':
144
+ if feat_cache is not None:
145
+ idx = feat_idx[0]
146
+ if feat_cache[idx] is None:
147
+ feat_cache[idx] = x.clone()
148
+ feat_idx[0] += 1
149
+ else:
150
+
151
+ cache_x = x[:, :, -1:, :, :].clone()
152
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
153
+ # # cache last frame of last two chunk
154
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
155
+
156
+ x = self.time_conv(
157
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
+ feat_cache[idx] = cache_x
159
+ feat_idx[0] += 1
160
+ return x
161
+
162
+ def init_weight(self, conv):
163
+ conv_weight = conv.weight
164
+ nn.init.zeros_(conv_weight)
165
+ c1, c2, t, h, w = conv_weight.size()
166
+ one_matrix = torch.eye(c1, c2)
167
+ init_matrix = one_matrix
168
+ nn.init.zeros_(conv_weight)
169
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
170
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
171
+ conv.weight.data.copy_(conv_weight)
172
+ nn.init.zeros_(conv.bias.data)
173
+
174
+ def init_weight2(self, conv):
175
+ conv_weight = conv.weight.data
176
+ nn.init.zeros_(conv_weight)
177
+ c1, c2, t, h, w = conv_weight.size()
178
+ init_matrix = torch.eye(c1 // 2, c2)
179
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
180
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
181
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
182
+ conv.weight.data.copy_(conv_weight)
183
+ nn.init.zeros_(conv.bias.data)
184
+
185
+
186
+ class ResidualBlock(nn.Module):
187
+
188
+ def __init__(self, in_dim, out_dim, dropout=0.0):
189
+ super().__init__()
190
+ self.in_dim = in_dim
191
+ self.out_dim = out_dim
192
+
193
+ # layers
194
+ self.residual = nn.Sequential(
195
+ RMS_norm(in_dim, images=False), nn.SiLU(),
196
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
197
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
198
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
199
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
200
+ if in_dim != out_dim else nn.Identity()
201
+
202
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
203
+ h = self.shortcut(x)
204
+ for layer in self.residual:
205
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
206
+ idx = feat_idx[0]
207
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
208
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
209
+ # cache last frame of last two chunk
210
+ cache_x = torch.cat([
211
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
212
+ cache_x.device), cache_x
213
+ ],
214
+ dim=2)
215
+ x = layer(x, feat_cache[idx])
216
+ feat_cache[idx] = cache_x
217
+ feat_idx[0] += 1
218
+ else:
219
+ x = layer(x)
220
+ return x + h
221
+
222
+
223
+ class AttentionBlock(nn.Module):
224
+ """
225
+ Single-head self-attention applied per frame over spatial positions.
226
+ """
227
+
228
+ def __init__(self, dim):
229
+ super().__init__()
230
+ self.dim = dim
231
+
232
+ # layers
233
+ self.norm = RMS_norm(dim)
234
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
235
+ self.proj = nn.Conv2d(dim, dim, 1)
236
+
237
+ # zero out the last layer params
238
+ nn.init.zeros_(self.proj.weight)
239
+
240
+ def forward(self, x):
241
+ identity = x
242
+ b, c, t, h, w = x.size()
243
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
244
+ x = self.norm(x)
245
+ # compute query, key, value
246
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
247
+ -1).permute(0, 1, 3,
248
+ 2).contiguous().chunk(
249
+ 3, dim=-1)
250
+
251
+ # apply attention
252
+ x = F.scaled_dot_product_attention(
253
+ q,
254
+ k,
255
+ v,
256
+ )
257
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
258
+
259
+ # output
260
+ x = self.proj(x)
261
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
262
+ return x + identity
263
+
264
+
265
+ class Encoder3d(nn.Module):
266
+
267
+ def __init__(self,
268
+ dim=128,
269
+ z_dim=4,
270
+ dim_mult=[1, 2, 4, 4],
271
+ num_res_blocks=2,
272
+ attn_scales=[],
273
+ temperal_downsample=[True, True, False],
274
+ dropout=0.0):
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.z_dim = z_dim
278
+ self.dim_mult = dim_mult
279
+ self.num_res_blocks = num_res_blocks
280
+ self.attn_scales = attn_scales
281
+ self.temperal_downsample = temperal_downsample
282
+
283
+ # dimensions
284
+ dims = [dim * u for u in [1] + dim_mult]
285
+ scale = 1.0
286
+
287
+ # init block
288
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
289
+
290
+ # downsample blocks
291
+ downsamples = []
292
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
293
+ # residual (+attention) blocks
294
+ for _ in range(num_res_blocks):
295
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
296
+ if scale in attn_scales:
297
+ downsamples.append(AttentionBlock(out_dim))
298
+ in_dim = out_dim
299
+
300
+ # downsample block
301
+ if i != len(dim_mult) - 1:
302
+ mode = 'downsample3d' if temperal_downsample[
303
+ i] else 'downsample2d'
304
+ downsamples.append(Resample(out_dim, mode=mode))
305
+ scale /= 2.0
306
+ self.downsamples = nn.Sequential(*downsamples)
307
+
308
+ # middle blocks
309
+ self.middle = nn.Sequential(
310
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
311
+ ResidualBlock(out_dim, out_dim, dropout))
312
+
313
+ # output blocks
314
+ self.head = nn.Sequential(
315
+ RMS_norm(out_dim, images=False), nn.SiLU(),
316
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
317
+
318
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
319
+ if feat_cache is not None:
320
+ idx = feat_idx[0]
321
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
322
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
323
+ # cache last frame of last two chunk
324
+ cache_x = torch.cat([
325
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
326
+ cache_x.device), cache_x
327
+ ],
328
+ dim=2)
329
+ x = self.conv1(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = self.conv1(x)
334
+
335
+ # downsamples
336
+ for layer in self.downsamples:
337
+ if feat_cache is not None:
338
+ x = layer(x, feat_cache, feat_idx)
339
+ else:
340
+ x = layer(x)
341
+
342
+ # middle
343
+ for layer in self.middle:
344
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
345
+ x = layer(x, feat_cache, feat_idx)
346
+ else:
347
+ x = layer(x)
348
+
349
+ # head
350
+ for layer in self.head:
351
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
352
+ idx = feat_idx[0]
353
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
354
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
355
+ # cache last frame of last two chunk
356
+ cache_x = torch.cat([
357
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
358
+ cache_x.device), cache_x
359
+ ],
360
+ dim=2)
361
+ x = layer(x, feat_cache[idx])
362
+ feat_cache[idx] = cache_x
363
+ feat_idx[0] += 1
364
+ else:
365
+ x = layer(x)
366
+ return x
367
+
368
+
369
+ class Decoder3d(nn.Module):
370
+
371
+ def __init__(self,
372
+ dim=128,
373
+ z_dim=4,
374
+ dim_mult=[1, 2, 4, 4],
375
+ num_res_blocks=2,
376
+ attn_scales=[],
377
+ temperal_upsample=[False, True, True],
378
+ dropout=0.0):
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.z_dim = z_dim
382
+ self.dim_mult = dim_mult
383
+ self.num_res_blocks = num_res_blocks
384
+ self.attn_scales = attn_scales
385
+ self.temperal_upsample = temperal_upsample
386
+
387
+ # dimensions
388
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
389
+ scale = 1.0 / 2**(len(dim_mult) - 2)
390
+
391
+ # init block
392
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
393
+
394
+ # middle blocks
395
+ self.middle = nn.Sequential(
396
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
397
+ ResidualBlock(dims[0], dims[0], dropout))
398
+
399
+ # upsample blocks
400
+ upsamples = []
401
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
402
+ # residual (+attention) blocks
403
+ if i == 1 or i == 2 or i == 3:
404
+ in_dim = in_dim // 2
405
+ for _ in range(num_res_blocks + 1):
406
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
407
+ if scale in attn_scales:
408
+ upsamples.append(AttentionBlock(out_dim))
409
+ in_dim = out_dim
410
+
411
+ # upsample block
412
+ if i != len(dim_mult) - 1:
413
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
414
+ upsamples.append(Resample(out_dim, mode=mode))
415
+ scale *= 2.0
416
+ self.upsamples = nn.Sequential(*upsamples)
417
+
418
+ # output blocks
419
+ self.head = nn.Sequential(
420
+ RMS_norm(out_dim, images=False), nn.SiLU(),
421
+ CausalConv3d(out_dim, 3, 3, padding=1))
422
+
423
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
424
+ # conv1
425
+ if feat_cache is not None:
426
+ idx = feat_idx[0]
427
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
428
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
429
+ # cache last frame of last two chunk
430
+ cache_x = torch.cat([
431
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
432
+ cache_x.device), cache_x
433
+ ],
434
+ dim=2)
435
+ x = self.conv1(x, feat_cache[idx])
436
+ feat_cache[idx] = cache_x
437
+ feat_idx[0] += 1
438
+ else:
439
+ x = self.conv1(x)
440
+
441
+ # middle
442
+ for layer in self.middle:
443
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
444
+ x = layer(x, feat_cache, feat_idx)
445
+ else:
446
+ x = layer(x)
447
+
448
+ # upsamples
449
+ for layer in self.upsamples:
450
+ if feat_cache is not None:
451
+ x = layer(x, feat_cache, feat_idx)
452
+ else:
453
+ x = layer(x)
454
+
455
+ # head
456
+ for layer in self.head:
457
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
+ idx = feat_idx[0]
459
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
+ # cache last frame of last two chunk
462
+ cache_x = torch.cat([
463
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
464
+ cache_x.device), cache_x
465
+ ],
466
+ dim=2)
467
+ x = layer(x, feat_cache[idx])
468
+ feat_cache[idx] = cache_x
469
+ feat_idx[0] += 1
470
+ else:
471
+ x = layer(x)
472
+ return x
473
+
474
+
475
+ def count_conv3d(model):
476
+ count = 0
477
+ for m in model.modules():
478
+ if isinstance(m, CausalConv3d):
479
+ count += 1
480
+ return count
481
+
482
+
483
+ class WanVAE_(nn.Module):
484
+
485
+ def __init__(self,
486
+ dim=128,
487
+ z_dim=4,
488
+ dim_mult=[1, 2, 4, 4],
489
+ num_res_blocks=2,
490
+ attn_scales=[],
491
+ temperal_downsample=[True, True, False],
492
+ dropout=0.0):
493
+ super().__init__()
494
+ self.dim = dim
495
+ self.z_dim = z_dim
496
+ self.dim_mult = dim_mult
497
+ self.num_res_blocks = num_res_blocks
498
+ self.attn_scales = attn_scales
499
+ self.temperal_downsample = temperal_downsample
500
+ self.temperal_upsample = temperal_downsample[::-1]
501
+
502
+ # modules
503
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
504
+ attn_scales, self.temperal_downsample, dropout)
505
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
506
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
507
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_upsample, dropout)
509
+
510
+ def forward(self, x):
511
+ mu, log_var = self.encode(x)
512
+ z = self.reparameterize(mu, log_var)
513
+ x_recon = self.decode(z)
514
+ return x_recon, mu, log_var
515
+
516
+ def encode(self, x, scale):
517
+ self.clear_cache()
518
+ # cache
519
+ t = x.shape[2]
520
+ iter_ = 1 + (t - 1) // 4
521
+ # split the input x of encode along the time axis into chunks of 1, 4, 4, 4, ...
522
+ for i in range(iter_):
523
+ self._enc_conv_idx = [0]
524
+ if i == 0:
525
+ out = self.encoder(
526
+ x[:, :, :1, :, :],
527
+ feat_cache=self._enc_feat_map,
528
+ feat_idx=self._enc_conv_idx)
529
+ else:
530
+ out_ = self.encoder(
531
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
532
+ feat_cache=self._enc_feat_map,
533
+ feat_idx=self._enc_conv_idx)
534
+ out = torch.cat([out, out_], 2)
535
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
536
+ if isinstance(scale[0], torch.Tensor):
537
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
538
+ 1, self.z_dim, 1, 1, 1)
539
+ else:
540
+ mu = (mu - scale[0]) * scale[1]
541
+ self.clear_cache()
542
+ return mu
543
+
544
+ def decode(self, z, scale):
545
+ self.clear_cache()
546
+ # z: [b,c,t,h,w]
547
+ if isinstance(scale[0], torch.Tensor):
548
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
549
+ 1, self.z_dim, 1, 1, 1)
550
+ else:
551
+ z = z / scale[1] + scale[0]
552
+ iter_ = z.shape[2]
553
+ x = self.conv2(z)
554
+ for i in range(iter_):
555
+ self._conv_idx = [0]
556
+ if i == 0:
557
+ out = self.decoder(
558
+ x[:, :, i:i + 1, :, :],
559
+ feat_cache=self._feat_map,
560
+ feat_idx=self._conv_idx)
561
+ else:
562
+ out_ = self.decoder(
563
+ x[:, :, i:i + 1, :, :],
564
+ feat_cache=self._feat_map,
565
+ feat_idx=self._conv_idx)
566
+ out = torch.cat([out, out_], 2)
567
+ self.clear_cache()
568
+ return out
569
+
570
+ def reparameterize(self, mu, log_var):
571
+ std = torch.exp(0.5 * log_var)
572
+ eps = torch.randn_like(std)
573
+ return eps * std + mu
574
+
575
+ def sample(self, imgs, deterministic=False):
576
+ mu, log_var = self.encode(imgs)
577
+ if deterministic:
578
+ return mu
579
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
580
+ return mu + std * torch.randn_like(std)
581
+
582
+ def clear_cache(self):
583
+ self._conv_num = count_conv3d(self.decoder)
584
+ self._conv_idx = [0]
585
+ self._feat_map = [None] * self._conv_num
586
+ # cache encode
587
+ self._enc_conv_num = count_conv3d(self.encoder)
588
+ self._enc_conv_idx = [0]
589
+ self._enc_feat_map = [None] * self._enc_conv_num
590
+
591
+
592
+ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
593
+ """
594
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
595
+ """
596
+ # params
597
+ cfg = dict(
598
+ dim=96,
599
+ z_dim=z_dim,
600
+ dim_mult=[1, 2, 4, 4],
601
+ num_res_blocks=2,
602
+ attn_scales=[],
603
+ temperal_downsample=[False, True, True],
604
+ dropout=0.0)
605
+ cfg.update(**kwargs)
606
+
607
+ # init model
608
+ with torch.device('meta'):
609
+ model = WanVAE_(**cfg)
610
+
611
+ # load checkpoint
612
+ logging.info(f'loading {pretrained_path}')
613
+ model.load_state_dict(
614
+ torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)
615
+
616
+ return model
617
+
618
+
619
+ class WanVAE:
620
+
621
+ def __init__(self,
622
+ z_dim=16,
623
+ vae_pth='cache/vae_step_411000.pth',
624
+ dtype=torch.float,
625
+ device="cuda"):
626
+ self.dtype = dtype
627
+ self.device = device
628
+
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
638
+ self.std = torch.tensor(std, dtype=dtype, device=device)
639
+ self.scale = [self.mean, 1.0 / self.std]
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ pretrained_path=vae_pth,
644
+ z_dim=z_dim,
645
+ ).eval().requires_grad_(False).to(device)
646
+
647
+ def encode(self, videos):
648
+ """
649
+ videos: A list of videos each with shape [C, T, H, W].
650
+ """
651
+ with amp.autocast("cuda", dtype=self.dtype):
652
+ return [
653
+ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
+ for u in videos
655
+ ]
656
+
657
+ def decode(self, zs):
658
+ with amp.autocast("cuda", dtype=self.dtype):
659
+ return [
660
+ self.model.decode(u.unsqueeze(0),
661
+ self.scale).float().clamp_(-1, 1).squeeze(0)
662
+ for u in zs
663
+ ]
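
A quick shape sketch for the VAE above: with three 2x spatial downsamples (dim_mult=[1, 2, 4, 4]) and two temporal downsamples (temperal_downsample=[False, True, True] in `_video_vae`), plus the 1/4/4/... frame chunking in `encode`, a video of shape [3, T, H, W] with T = 1 + 4k should map to a latent of shape [z_dim, 1 + (T - 1) // 4, H // 8, W // 8].

    def latent_shape(T, H, W, z_dim=16):
        # helper for illustration only; the real shapes come out of WanVAE.encode
        assert (T - 1) % 4 == 0 and H % 8 == 0 and W % 8 == 0
        return (z_dim, 1 + (T - 1) // 4, H // 8, W // 8)

    print(latent_shape(81, 480, 832))   # (16, 21, 60, 104)
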
wan/modules/xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
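
A minimal smoke test for the XLMRoberta backbone above, using randomly initialized weights (`xlm_roberta_large()` only builds the architecture; pretrained weights are presumably loaded elsewhere in the repo). The token ids and sequence length are arbitrary.

    import torch
    from wan.modules.xlm_roberta import xlm_roberta_large

    model = xlm_roberta_large(device='cpu').eval()
    ids = torch.full((2, 12), 1, dtype=torch.long)     # pad_id = 1, so start with all padding
    ids[:, :5] = torch.randint(5, 1000, (2, 5))        # a few arbitrary "real" token ids
    with torch.no_grad():
        out = model(ids)
    print(out.shape)   # torch.Size([2, 12, 1024])
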
wan/text2video.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.cuda.amp as amp
14
+ import torch.distributed as dist
15
+ from tqdm import tqdm
16
+
17
+ from .distributed.fsdp import shard_model
18
+ from .modules.model import WanModel
19
+ from .modules.t5 import T5EncoderModel
20
+ from .modules.vae import WanVAE
21
+ from .utils.fm_solvers import (
22
+ FlowDPMSolverMultistepScheduler,
23
+ get_sampling_sigmas,
24
+ retrieve_timesteps,
25
+ )
26
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+
29
+ class WanT2V:
30
+
31
+ def __init__(
32
+ self,
33
+ config,
34
+ checkpoint_dir,
35
+ device_id=0,
36
+ rank=0,
37
+ t5_fsdp=False,
38
+ dit_fsdp=False,
39
+ use_usp=False,
40
+ t5_cpu=False,
41
+ ):
42
+ r"""
43
+ Initializes the Wan text-to-video generation model components.
44
+
45
+ Args:
46
+ config (EasyDict):
47
+ Object containing model parameters initialized from config.py
48
+ checkpoint_dir (`str`):
49
+ Path to directory containing model checkpoints
50
+ device_id (`int`, *optional*, defaults to 0):
51
+ Id of target GPU device
52
+ rank (`int`, *optional*, defaults to 0):
53
+ Process rank for distributed training
54
+ t5_fsdp (`bool`, *optional*, defaults to False):
55
+ Enable FSDP sharding for T5 model
56
+ dit_fsdp (`bool`, *optional*, defaults to False):
57
+ Enable FSDP sharding for DiT model
58
+ use_usp (`bool`, *optional*, defaults to False):
59
+ Enable distribution strategy of USP.
60
+ t5_cpu (`bool`, *optional*, defaults to False):
61
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
62
+ """
63
+ self.device = torch.device(f"cuda:{device_id}")
64
+ self.config = config
65
+ self.rank = rank
66
+ self.t5_cpu = t5_cpu
67
+
68
+ self.num_train_timesteps = config.num_train_timesteps
69
+ self.param_dtype = config.param_dtype
70
+
71
+ shard_fn = partial(shard_model, device_id=device_id)
72
+ self.text_encoder = T5EncoderModel(
73
+ text_len=config.text_len,
74
+ dtype=config.t5_dtype,
75
+ device=torch.device('cpu'),
76
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
77
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
78
+ shard_fn=shard_fn if t5_fsdp else None)
79
+
80
+ self.vae_stride = config.vae_stride
81
+ self.patch_size = config.patch_size
82
+ self.vae = WanVAE(
83
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
84
+ device=self.device)
85
+
86
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
87
+ self.model = WanModel.from_pretrained(checkpoint_dir)
88
+ self.model.eval().requires_grad_(False)
89
+
90
+ if use_usp:
91
+ from xfuser.core.distributed import get_sequence_parallel_world_size
92
+
93
+ from .distributed.xdit_context_parallel import (
94
+ usp_attn_forward,
95
+ usp_dit_forward,
96
+ )
97
+ for block in self.model.blocks:
98
+ block.self_attn.forward = types.MethodType(
99
+ usp_attn_forward, block.self_attn)
100
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
101
+ self.sp_size = get_sequence_parallel_world_size()
102
+ else:
103
+ self.sp_size = 1
104
+
105
+ if dist.is_initialized():
106
+ dist.barrier()
107
+ if dit_fsdp:
108
+ self.model = shard_fn(self.model)
109
+ else:
110
+ self.model.to(self.device)
111
+
112
+ self.sample_neg_prompt = config.sample_neg_prompt
113
+
114
+ def generate(self,
115
+ input_prompt,
116
+ size=(1280, 720),
117
+ frame_num=81,
118
+ shift=5.0,
119
+ sample_solver='unipc',
120
+ sampling_steps=50,
121
+ guide_scale=5.0,
122
+ n_prompt="",
123
+ seed=-1,
124
+ offload_model=True):
125
+ r"""
126
+ Generates video frames from text prompt using diffusion process.
127
+
128
+ Args:
129
+ input_prompt (`str`):
130
+ Text prompt for content generation
131
+ size (tuple[`int`], *optional*, defaults to (1280, 720)):
132
+ Controls video resolution, (width,height).
133
+ frame_num (`int`, *optional*, defaults to 81):
134
+ How many frames to sample from a video. The number should be 4n+1
135
+ shift (`float`, *optional*, defaults to 5.0):
136
+ Noise schedule shift parameter. Affects temporal dynamics
137
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
138
+ Solver used to sample the video.
139
+ sampling_steps (`int`, *optional*, defaults to 50):
140
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
141
+ guide_scale (`float`, *optional*, defaults to 5.0):
142
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
143
+ n_prompt (`str`, *optional*, defaults to ""):
144
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
145
+ seed (`int`, *optional*, defaults to -1):
146
+ Random seed for noise generation. If -1, use random seed.
147
+ offload_model (`bool`, *optional*, defaults to True):
148
+ If True, offloads models to CPU during generation to save VRAM
149
+
150
+ Returns:
151
+ torch.Tensor:
152
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
153
+ - C: Color channels (3 for RGB)
154
+ - N: Number of frames (81)
155
+ - H: Frame height (from size)
156
+ - W: Frame width (from size)
157
+ """
158
+ # preprocess
159
+ F = frame_num
160
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
161
+ size[1] // self.vae_stride[1],
162
+ size[0] // self.vae_stride[2])
163
+
164
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
165
+ (self.patch_size[1] * self.patch_size[2]) *
166
+ target_shape[1] / self.sp_size) * self.sp_size
167
+
168
+ if n_prompt == "":
169
+ n_prompt = self.sample_neg_prompt
170
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
171
+ seed_g = torch.Generator(device=self.device)
172
+ seed_g.manual_seed(seed)
173
+
174
+ if not self.t5_cpu:
175
+ self.text_encoder.model.to(self.device)
176
+ context = self.text_encoder([input_prompt], self.device)
177
+ context_null = self.text_encoder([n_prompt], self.device)
178
+ if offload_model:
179
+ self.text_encoder.model.cpu()
180
+ else:
181
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
182
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
183
+ context = [t.to(self.device) for t in context]
184
+ context_null = [t.to(self.device) for t in context_null]
185
+
186
+ noise = [
187
+ torch.randn(
188
+ target_shape[0],
189
+ target_shape[1],
190
+ target_shape[2],
191
+ target_shape[3],
192
+ dtype=torch.float32,
193
+ device=self.device,
194
+ generator=seed_g)
195
+ ]
196
+
197
+ @contextmanager
198
+ def noop_no_sync():
199
+ yield
200
+
201
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
202
+
203
+ # evaluation mode
204
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
205
+
206
+ if sample_solver == 'unipc':
207
+ sample_scheduler = FlowUniPCMultistepScheduler(
208
+ num_train_timesteps=self.num_train_timesteps,
209
+ shift=1,
210
+ use_dynamic_shifting=False)
211
+ sample_scheduler.set_timesteps(
212
+ sampling_steps, device=self.device, shift=shift)
213
+ timesteps = sample_scheduler.timesteps
214
+ elif sample_solver == 'dpm++':
215
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
216
+ num_train_timesteps=self.num_train_timesteps,
217
+ shift=1,
218
+ use_dynamic_shifting=False)
219
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
220
+ timesteps, _ = retrieve_timesteps(
221
+ sample_scheduler,
222
+ device=self.device,
223
+ sigmas=sampling_sigmas)
224
+ else:
225
+ raise NotImplementedError("Unsupported solver.")
226
+
227
+ # sample videos
228
+ latents = noise
229
+
230
+ arg_c = {'context': context, 'seq_len': seq_len}
231
+ arg_null = {'context': context_null, 'seq_len': seq_len}
232
+
233
+ for _, t in enumerate(tqdm(timesteps)):
234
+ latent_model_input = latents
235
+ timestep = [t]
236
+
237
+ timestep = torch.stack(timestep)
238
+
239
+ self.model.to(self.device)
240
+ noise_pred_cond = self.model(
241
+ latent_model_input, t=timestep, **arg_c)[0]
242
+ noise_pred_uncond = self.model(
243
+ latent_model_input, t=timestep, **arg_null)[0]
244
+
245
+ noise_pred = noise_pred_uncond + guide_scale * (
246
+ noise_pred_cond - noise_pred_uncond)
247
+
248
+ temp_x0 = sample_scheduler.step(
249
+ noise_pred.unsqueeze(0),
250
+ t,
251
+ latents[0].unsqueeze(0),
252
+ return_dict=False,
253
+ generator=seed_g)[0]
254
+ latents = [temp_x0.squeeze(0)]
255
+
256
+ x0 = latents
257
+ if offload_model:
258
+ self.model.cpu()
259
+ torch.cuda.empty_cache()
260
+ if self.rank == 0:
261
+ videos = self.vae.decode(x0)
262
+
263
+ del noise, latents
264
+ del sample_scheduler
265
+ if offload_model:
266
+ gc.collect()
267
+ torch.cuda.synchronize()
268
+ if dist.is_initialized():
269
+ dist.barrier()
270
+
271
+ return videos[0] if self.rank == 0 else None
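A minimal single-GPU invocation of this class (a sketch only; it assumes the wan package exposes WanT2V, that WAN_CONFIGS contains a 't2v-14B' entry, and that a checkpoint directory such as ./Wan2.1-T2V-14B exists locally):

    import wan
    from wan.configs import WAN_CONFIGS

    cfg = WAN_CONFIGS['t2v-14B']                   # assumed config key
    pipe = wan.WanT2V(config=cfg, checkpoint_dir='./Wan2.1-T2V-14B', device_id=0)
    video = pipe.generate(
        "A corgi running along a beach at sunset",
        size=(1280, 720), sampling_steps=50, guide_scale=5.0, seed=42)
    # video: (C, N, H, W) tensor in [-1, 1], returned on rank 0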
wan/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
2
+ retrieve_timesteps)
3
+ from .fm_solvers_unipc import FlowUniPCMultistepScheduler
4
+
5
+ __all__ = [
6
+ 'get_sampling_sigmas', 'retrieve_timesteps',
7
+ 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
8
+ ]
wan/utils/fm_solvers.py ADDED
@@ -0,0 +1,934 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
2
+ # Convert dpm solver for flow matching
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+
5
+ import inspect
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
13
+ SchedulerMixin,
14
+ SchedulerOutput)
15
+ from diffusers.utils import deprecate, is_scipy_available
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ if is_scipy_available():
19
+ pass
20
+
21
+
22
+ def get_sampling_sigmas(sampling_steps, shift):
23
+ sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
24
+ sigma = (shift * sigma / (1 + (shift - 1) * sigma))
25
+
26
+ return sigma
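A quick numerical check of the shift transform above (illustrative values only): larger shift values keep the schedule at high noise for longer.

    # sampling_steps=4, shift=5.0:
    #   np.linspace(1, 0, 5)[:4]  ->  [1.0, 0.75, 0.5, 0.25]
    #   5*s / (1 + 4*s)           ->  [1.0, 0.9375, 0.8333..., 0.625]
    sigmas = get_sampling_sigmas(4, 5.0)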
27
+
28
+
29
+ def retrieve_timesteps(
30
+ scheduler,
31
+ num_inference_steps=None,
32
+ device=None,
33
+ timesteps=None,
34
+ sigmas=None,
35
+ **kwargs,
36
+ ):
37
+ if timesteps is not None and sigmas is not None:
38
+ raise ValueError(
39
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
40
+ )
41
+ if timesteps is not None:
42
+ accepts_timesteps = "timesteps" in set(
43
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
44
+ if not accepts_timesteps:
45
+ raise ValueError(
46
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
47
+ f" timestep schedules. Please check whether you are using the correct scheduler."
48
+ )
49
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
50
+ timesteps = scheduler.timesteps
51
+ num_inference_steps = len(timesteps)
52
+ elif sigmas is not None:
53
+ accept_sigmas = "sigmas" in set(
54
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
55
+ if not accept_sigmas:
56
+ raise ValueError(
57
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
58
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
59
+ )
60
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
61
+ timesteps = scheduler.timesteps
62
+ num_inference_steps = len(timesteps)
63
+ else:
64
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
65
+ timesteps = scheduler.timesteps
66
+ return timesteps, num_inference_steps
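The custom-sigmas branch above is the path taken by WanT2V.generate for the 'dpm++' solver; a standalone sketch with illustrative values:

    sched = FlowDPMSolverMultistepScheduler(
        num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
    sigmas = get_sampling_sigmas(sampling_steps=50, shift=5.0)
    timesteps, num_steps = retrieve_timesteps(sched, device="cpu", sigmas=sigmas)
    # num_steps == 50; sched.sigmas gains a trailing final sigma (0 for final_sigmas_type="zero")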
67
+
68
+
69
+ class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
70
+ """
71
+ `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
72
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
73
+ methods the library implements for all schedulers such as loading and saving.
74
+ Args:
75
+ num_train_timesteps (`int`, defaults to 1000):
76
+ The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
77
+ solver_order (`int`, defaults to 2):
78
+ The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
79
+ sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
80
+ and used in multistep updates.
81
+ prediction_type (`str`, defaults to "flow_prediction"):
82
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
83
+ the flow of the diffusion process.
84
+ shift (`float`, *optional*, defaults to 1.0):
85
+ A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
86
+ process.
87
+ use_dynamic_shifting (`bool`, defaults to `False`):
88
+ Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
89
+ applied on the fly.
90
+ thresholding (`bool`, defaults to `False`):
91
+ Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
92
+ saturation and improve photorealism.
93
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
94
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
95
+ sample_max_value (`float`, defaults to 1.0):
96
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
97
+ `algorithm_type="dpmsolver++"`.
98
+ algorithm_type (`str`, defaults to `dpmsolver++`):
99
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
100
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
101
+ paper, and the `dpmsolver++` type implements the algorithms in the
102
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
103
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
104
+ solver_type (`str`, defaults to `midpoint`):
105
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
106
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
107
+ lower_order_final (`bool`, defaults to `True`):
108
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
109
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
110
+ euler_at_final (`bool`, defaults to `False`):
111
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
112
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
113
+ steps, but sometimes may result in blurring.
114
+ final_sigmas_type (`str`, *optional*, defaults to "zero"):
115
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
116
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
117
+ lambda_min_clipped (`float`, defaults to `-inf`):
118
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
119
+ cosine (`squaredcos_cap_v2`) noise schedule.
120
+ variance_type (`str`, *optional*):
121
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
122
+ contains the predicted Gaussian variance.
123
+ """
124
+
125
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
126
+ order = 1
127
+
128
+ @register_to_config
129
+ def __init__(
130
+ self,
131
+ num_train_timesteps: int = 1000,
132
+ solver_order: int = 2,
133
+ prediction_type: str = "flow_prediction",
134
+ shift: Optional[float] = 1.0,
135
+ use_dynamic_shifting=False,
136
+ thresholding: bool = False,
137
+ dynamic_thresholding_ratio: float = 0.995,
138
+ sample_max_value: float = 1.0,
139
+ algorithm_type: str = "dpmsolver++",
140
+ solver_type: str = "midpoint",
141
+ lower_order_final: bool = True,
142
+ euler_at_final: bool = False,
143
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
144
+ lambda_min_clipped: float = -float("inf"),
145
+ variance_type: Optional[str] = None,
146
+ invert_sigmas: bool = False,
147
+ ):
148
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
149
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
150
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
151
+ deprecation_message)
152
+
153
+ # settings for DPM-Solver
154
+ if algorithm_type not in [
155
+ "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
156
+ ]:
157
+ if algorithm_type == "deis":
158
+ self.register_to_config(algorithm_type="dpmsolver++")
159
+ else:
160
+ raise NotImplementedError(
161
+ f"{algorithm_type} is not implemented for {self.__class__}")
162
+
163
+ if solver_type not in ["midpoint", "heun"]:
164
+ if solver_type in ["logrho", "bh1", "bh2"]:
165
+ self.register_to_config(solver_type="midpoint")
166
+ else:
167
+ raise NotImplementedError(
168
+ f"{solver_type} is not implemented for {self.__class__}")
169
+
170
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"
171
+ ] and final_sigmas_type == "zero":
172
+ raise ValueError(
173
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
174
+ )
175
+
176
+ # setable values
177
+ self.num_inference_steps = None
178
+ alphas = np.linspace(1, 1 / num_train_timesteps,
179
+ num_train_timesteps)[::-1].copy()
180
+ sigmas = 1.0 - alphas
181
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
182
+
183
+ if not use_dynamic_shifting:
184
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
185
+ sigmas = shift * sigmas / (1 +
186
+ (shift - 1) * sigmas) # pyright: ignore
187
+
188
+ self.sigmas = sigmas
189
+ self.timesteps = sigmas * num_train_timesteps
190
+
191
+ self.model_outputs = [None] * solver_order
192
+ self.lower_order_nums = 0
193
+ self._step_index = None
194
+ self._begin_index = None
195
+
196
+ # self.sigmas = self.sigmas.to(
197
+ # "cpu") # to avoid too much CPU/GPU communication
198
+ self.sigma_min = self.sigmas[-1].item()
199
+ self.sigma_max = self.sigmas[0].item()
200
+
201
+ @property
202
+ def step_index(self):
203
+ """
204
+ The index counter for current timestep. It will increase 1 after each scheduler step.
205
+ """
206
+ return self._step_index
207
+
208
+ @property
209
+ def begin_index(self):
210
+ """
211
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
212
+ """
213
+ return self._begin_index
214
+
215
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
216
+ def set_begin_index(self, begin_index: int = 0):
217
+ """
218
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
219
+ Args:
220
+ begin_index (`int`):
221
+ The begin index for the scheduler.
222
+ """
223
+ self._begin_index = begin_index
224
+
225
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
226
+ def set_timesteps(
227
+ self,
228
+ num_inference_steps: Union[int, None] = None,
229
+ device: Union[str, torch.device] = None,
230
+ sigmas: Optional[List[float]] = None,
231
+ mu: Optional[Union[float, None]] = None,
232
+ shift: Optional[Union[float, None]] = None,
233
+ ):
234
+ """
235
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
236
+ Args:
237
+ num_inference_steps (`int`):
238
+ The number of diffusion steps used when generating samples.
239
+ device (`str` or `torch.device`, *optional*):
240
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
241
+ """
242
+
243
+ if self.config.use_dynamic_shifting and mu is None:
244
+ raise ValueError(
245
+ "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
246
+ )
247
+
248
+ if sigmas is None:
249
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
250
+ num_inference_steps +
251
+ 1).copy()[:-1] # pyright: ignore
252
+
253
+ if self.config.use_dynamic_shifting:
254
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
255
+ else:
256
+ if shift is None:
257
+ shift = self.config.shift
258
+ sigmas = shift * sigmas / (1 +
259
+ (shift - 1) * sigmas) # pyright: ignore
260
+
261
+ if self.config.final_sigmas_type == "sigma_min":
262
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
263
+ self.alphas_cumprod[0])**0.5
264
+ elif self.config.final_sigmas_type == "zero":
265
+ sigma_last = 0
266
+ else:
267
+ raise ValueError(
268
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
269
+ )
270
+
271
+ timesteps = sigmas * self.config.num_train_timesteps
272
+ sigmas = np.concatenate([sigmas, [sigma_last]
273
+ ]).astype(np.float32) # pyright: ignore
274
+
275
+ self.sigmas = torch.from_numpy(sigmas)
276
+ self.timesteps = torch.from_numpy(timesteps).to(
277
+ device=device, dtype=torch.int64)
278
+
279
+ self.num_inference_steps = len(timesteps)
280
+
281
+ self.model_outputs = [
282
+ None,
283
+ ] * self.config.solver_order
284
+ self.lower_order_nums = 0
285
+
286
+ self._step_index = None
287
+ self._begin_index = None
288
+ # self.sigmas = self.sigmas.to(
289
+ # "cpu") # to avoid too much CPU/GPU communication
290
+
291
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
292
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
293
+ """
294
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
295
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
296
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
297
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
298
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
299
+ https://arxiv.org/abs/2205.11487
300
+ """
301
+ dtype = sample.dtype
302
+ batch_size, channels, *remaining_dims = sample.shape
303
+
304
+ if dtype not in (torch.float32, torch.float64):
305
+ sample = sample.float(
306
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
307
+
308
+ # Flatten sample for doing quantile calculation along each image
309
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
310
+
311
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
312
+
313
+ s = torch.quantile(
314
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
315
+ s = torch.clamp(
316
+ s, min=1, max=self.config.sample_max_value
317
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
318
+ s = s.unsqueeze(
319
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
320
+ sample = torch.clamp(
321
+ sample, -s, s
322
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
323
+
324
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
325
+ sample = sample.to(dtype)
326
+
327
+ return sample
328
+
329
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
330
+ def _sigma_to_t(self, sigma):
331
+ return sigma * self.config.num_train_timesteps
332
+
333
+ def _sigma_to_alpha_sigma_t(self, sigma):
334
+ return 1 - sigma, sigma
335
+
336
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
337
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
338
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
339
+
340
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
341
+ def convert_model_output(
342
+ self,
343
+ model_output: torch.Tensor,
344
+ *args,
345
+ sample: torch.Tensor = None,
346
+ **kwargs,
347
+ ) -> torch.Tensor:
348
+ """
349
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
350
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
351
+ integral of the data prediction model.
352
+ <Tip>
353
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
354
+ prediction and data prediction models.
355
+ </Tip>
356
+ Args:
357
+ model_output (`torch.Tensor`):
358
+ The direct output from the learned diffusion model.
359
+ sample (`torch.Tensor`):
360
+ A current instance of a sample created by the diffusion process.
361
+ Returns:
362
+ `torch.Tensor`:
363
+ The converted model output.
364
+ """
365
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
366
+ if sample is None:
367
+ if len(args) > 1:
368
+ sample = args[1]
369
+ else:
370
+ raise ValueError(
371
+ "missing `sample` as a required keyword argument")
372
+ if timestep is not None:
373
+ deprecate(
374
+ "timesteps",
375
+ "1.0.0",
376
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
377
+ )
378
+
379
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
380
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
381
+ if self.config.prediction_type == "flow_prediction":
382
+ sigma_t = self.sigmas[self.step_index]
383
+ x0_pred = sample - sigma_t * model_output
384
+ else:
385
+ raise ValueError(
386
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
387
+ " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
388
+ )
389
+
390
+ if self.config.thresholding:
391
+ x0_pred = self._threshold_sample(x0_pred)
392
+
393
+ return x0_pred
394
+
395
+ # DPM-Solver needs to solve an integral of the noise prediction model.
396
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
397
+ if self.config.prediction_type == "flow_prediction":
398
+ sigma_t = self.sigmas[self.step_index]
399
+ epsilon = sample - (1 - sigma_t) * model_output
400
+ else:
401
+ raise ValueError(
402
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
403
+ " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
404
+ )
405
+
406
+ if self.config.thresholding:
407
+ sigma_t = self.sigmas[self.step_index]
408
+ x0_pred = sample - sigma_t * model_output
409
+ x0_pred = self._threshold_sample(x0_pred)
410
+ epsilon = model_output + x0_pred
411
+
412
+ return epsilon
413
+
414
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
415
+ def dpm_solver_first_order_update(
416
+ self,
417
+ model_output: torch.Tensor,
418
+ *args,
419
+ sample: torch.Tensor = None,
420
+ noise: Optional[torch.Tensor] = None,
421
+ **kwargs,
422
+ ) -> torch.Tensor:
423
+ """
424
+ One step for the first-order DPMSolver (equivalent to DDIM).
425
+ Args:
426
+ model_output (`torch.Tensor`):
427
+ The direct output from the learned diffusion model.
428
+ sample (`torch.Tensor`):
429
+ A current instance of a sample created by the diffusion process.
430
+ Returns:
431
+ `torch.Tensor`:
432
+ The sample tensor at the previous timestep.
433
+ """
434
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
435
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
436
+ "prev_timestep", None)
437
+ if sample is None:
438
+ if len(args) > 2:
439
+ sample = args[2]
440
+ else:
441
+ raise ValueError(
442
+ "missing `sample` as a required keyword argument")
443
+ if timestep is not None:
444
+ deprecate(
445
+ "timesteps",
446
+ "1.0.0",
447
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
448
+ )
449
+
450
+ if prev_timestep is not None:
451
+ deprecate(
452
+ "prev_timestep",
453
+ "1.0.0",
454
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
455
+ )
456
+
457
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
458
+ self.step_index] # pyright: ignore
459
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
460
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
461
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
462
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
463
+
464
+ h = lambda_t - lambda_s
465
+ if self.config.algorithm_type == "dpmsolver++":
466
+ x_t = (sigma_t /
467
+ sigma_s) * sample - (alpha_t *
468
+ (torch.exp(-h) - 1.0)) * model_output
469
+ elif self.config.algorithm_type == "dpmsolver":
470
+ x_t = (alpha_t /
471
+ alpha_s) * sample - (sigma_t *
472
+ (torch.exp(h) - 1.0)) * model_output
473
+ elif self.config.algorithm_type == "sde-dpmsolver++":
474
+ assert noise is not None
475
+ x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
476
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
477
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
478
+ elif self.config.algorithm_type == "sde-dpmsolver":
479
+ assert noise is not None
480
+ x_t = ((alpha_t / alpha_s) * sample - 2.0 *
481
+ (sigma_t * (torch.exp(h) - 1.0)) * model_output +
482
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
483
+ return x_t # pyright: ignore
484
+
485
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
486
+ def multistep_dpm_solver_second_order_update(
487
+ self,
488
+ model_output_list: List[torch.Tensor],
489
+ *args,
490
+ sample: torch.Tensor = None,
491
+ noise: Optional[torch.Tensor] = None,
492
+ **kwargs,
493
+ ) -> torch.Tensor:
494
+ """
495
+ One step for the second-order multistep DPMSolver.
496
+ Args:
497
+ model_output_list (`List[torch.Tensor]`):
498
+ The direct outputs from learned diffusion model at current and latter timesteps.
499
+ sample (`torch.Tensor`):
500
+ A current instance of a sample created by the diffusion process.
501
+ Returns:
502
+ `torch.Tensor`:
503
+ The sample tensor at the previous timestep.
504
+ """
505
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
506
+ "timestep_list", None)
507
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
508
+ "prev_timestep", None)
509
+ if sample is None:
510
+ if len(args) > 2:
511
+ sample = args[2]
512
+ else:
513
+ raise ValueError(
514
+ "missing `sample` as a required keyword argument")
515
+ if timestep_list is not None:
516
+ deprecate(
517
+ "timestep_list",
518
+ "1.0.0",
519
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
520
+ )
521
+
522
+ if prev_timestep is not None:
523
+ deprecate(
524
+ "prev_timestep",
525
+ "1.0.0",
526
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
527
+ )
528
+
529
+ sigma_t, sigma_s0, sigma_s1 = (
530
+ self.sigmas[self.step_index + 1], # pyright: ignore
531
+ self.sigmas[self.step_index],
532
+ self.sigmas[self.step_index - 1], # pyright: ignore
533
+ )
534
+
535
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
536
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
537
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
538
+
539
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
540
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
541
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
542
+
543
+ m0, m1 = model_output_list[-1], model_output_list[-2]
544
+
545
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
546
+ r0 = h_0 / h
547
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
548
+ if self.config.algorithm_type == "dpmsolver++":
549
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
550
+ if self.config.solver_type == "midpoint":
551
+ x_t = ((sigma_t / sigma_s0) * sample -
552
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
553
+ (alpha_t * (torch.exp(-h) - 1.0)) * D1)
554
+ elif self.config.solver_type == "heun":
555
+ x_t = ((sigma_t / sigma_s0) * sample -
556
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
557
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
558
+ elif self.config.algorithm_type == "dpmsolver":
559
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
560
+ if self.config.solver_type == "midpoint":
561
+ x_t = ((alpha_t / alpha_s0) * sample -
562
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
563
+ (sigma_t * (torch.exp(h) - 1.0)) * D1)
564
+ elif self.config.solver_type == "heun":
565
+ x_t = ((alpha_t / alpha_s0) * sample -
566
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
567
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
568
+ elif self.config.algorithm_type == "sde-dpmsolver++":
569
+ assert noise is not None
570
+ if self.config.solver_type == "midpoint":
571
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
572
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
573
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
574
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
575
+ elif self.config.solver_type == "heun":
576
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
577
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
578
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
579
+ (-2.0 * h) + 1.0)) * D1 +
580
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
581
+ elif self.config.algorithm_type == "sde-dpmsolver":
582
+ assert noise is not None
583
+ if self.config.solver_type == "midpoint":
584
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
585
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
586
+ (sigma_t * (torch.exp(h) - 1.0)) * D1 +
587
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
588
+ elif self.config.solver_type == "heun":
589
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
590
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
591
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
592
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
593
+ return x_t # pyright: ignore
594
+
595
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
596
+ def multistep_dpm_solver_third_order_update(
597
+ self,
598
+ model_output_list: List[torch.Tensor],
599
+ *args,
600
+ sample: torch.Tensor = None,
601
+ **kwargs,
602
+ ) -> torch.Tensor:
603
+ """
604
+ One step for the third-order multistep DPMSolver.
605
+ Args:
606
+ model_output_list (`List[torch.Tensor]`):
607
+ The direct outputs from learned diffusion model at current and latter timesteps.
608
+ sample (`torch.Tensor`):
609
+ A current instance of a sample created by diffusion process.
610
+ Returns:
611
+ `torch.Tensor`:
612
+ The sample tensor at the previous timestep.
613
+ """
614
+
615
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
616
+ "timestep_list", None)
617
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
618
+ "prev_timestep", None)
619
+ if sample is None:
620
+ if len(args) > 2:
621
+ sample = args[2]
622
+ else:
623
+ raise ValueError(
624
+ "missing `sample` as a required keyword argument")
625
+ if timestep_list is not None:
626
+ deprecate(
627
+ "timestep_list",
628
+ "1.0.0",
629
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
630
+ )
631
+
632
+ if prev_timestep is not None:
633
+ deprecate(
634
+ "prev_timestep",
635
+ "1.0.0",
636
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
637
+ )
638
+
639
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
640
+ self.sigmas[self.step_index + 1], # pyright: ignore
641
+ self.sigmas[self.step_index],
642
+ self.sigmas[self.step_index - 1], # pyright: ignore
643
+ self.sigmas[self.step_index - 2], # pyright: ignore
644
+ )
645
+
646
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
647
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
648
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
649
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
650
+
651
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
652
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
653
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
654
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
655
+
656
+ m0, m1, m2 = model_output_list[-1], model_output_list[
657
+ -2], model_output_list[-3]
658
+
659
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
660
+ r0, r1 = h_0 / h, h_1 / h
661
+ D0 = m0
662
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
663
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
664
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
665
+ if self.config.algorithm_type == "dpmsolver++":
666
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
667
+ x_t = ((sigma_t / sigma_s0) * sample -
668
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
669
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
670
+ (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
671
+ elif self.config.algorithm_type == "dpmsolver":
672
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
673
+ x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *
674
+ (torch.exp(h) - 1.0)) * D0 -
675
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
676
+ (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
677
+ return x_t # pyright: ignore
678
+
679
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
680
+ if schedule_timesteps is None:
681
+ schedule_timesteps = self.timesteps
682
+
683
+ indices = (schedule_timesteps == timestep).nonzero()
684
+
685
+ # The sigma index that is taken for the **very** first `step`
686
+ # is always the second index (or the last index if there is only 1)
687
+ # This way we can ensure we don't accidentally skip a sigma in
688
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
689
+ pos = 1 if len(indices) > 1 else 0
690
+
691
+ return indices[pos].item()
692
+
693
+ def _init_step_index(self, timestep):
694
+ """
695
+ Initialize the step_index counter for the scheduler.
696
+ """
697
+
698
+ if self.begin_index is None:
699
+ if isinstance(timestep, torch.Tensor):
700
+ timestep = timestep.to(self.timesteps.device)
701
+ self._step_index = self.index_for_timestep(timestep)
702
+ else:
703
+ self._step_index = self._begin_index
704
+
705
+ # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
706
+ def step(
707
+ self,
708
+ model_output: torch.Tensor,
709
+ timestep: Union[int, torch.Tensor],
710
+ sample: torch.Tensor,
711
+ generator=None,
712
+ variance_noise: Optional[torch.Tensor] = None,
713
+ return_dict: bool = True,
714
+ ) -> Union[SchedulerOutput, Tuple]:
715
+ """
716
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
717
+ the multistep DPMSolver.
718
+ Args:
719
+ model_output (`torch.Tensor`):
720
+ The direct output from learned diffusion model.
721
+ timestep (`int`):
722
+ The current discrete timestep in the diffusion chain.
723
+ sample (`torch.Tensor`):
724
+ A current instance of a sample created by the diffusion process.
725
+ generator (`torch.Generator`, *optional*):
726
+ A random number generator.
727
+ variance_noise (`torch.Tensor`):
728
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
729
+ itself. Useful for methods such as [`LEdits++`].
730
+ return_dict (`bool`):
731
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
732
+ Returns:
733
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
734
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
735
+ tuple is returned where the first element is the sample tensor.
736
+ """
737
+ if self.num_inference_steps is None:
738
+ raise ValueError(
739
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
740
+ )
741
+
742
+ if self.step_index is None:
743
+ self._init_step_index(timestep)
744
+
745
+ # Improve numerical stability for small number of steps
746
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
747
+ self.config.euler_at_final or
748
+ (self.config.lower_order_final and len(self.timesteps) < 15) or
749
+ self.config.final_sigmas_type == "zero")
750
+ lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
751
+ self.config.lower_order_final and
752
+ len(self.timesteps) < 15)
753
+
754
+ model_output = self.convert_model_output(model_output, sample=sample)
755
+ for i in range(self.config.solver_order - 1):
756
+ self.model_outputs[i] = self.model_outputs[i + 1]
757
+ self.model_outputs[-1] = model_output
758
+
759
+ # Upcast to avoid precision issues when computing prev_sample
760
+ sample = sample.to(torch.float32)
761
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"
762
+ ] and variance_noise is None:
763
+ noise = randn_tensor(
764
+ model_output.shape,
765
+ generator=generator,
766
+ device=model_output.device,
767
+ dtype=torch.float32)
768
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
769
+ noise = variance_noise.to(
770
+ device=model_output.device,
771
+ dtype=torch.float32) # pyright: ignore
772
+ else:
773
+ noise = None
774
+
775
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
776
+ prev_sample = self.dpm_solver_first_order_update(
777
+ model_output, sample=sample, noise=noise)
778
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
779
+ prev_sample = self.multistep_dpm_solver_second_order_update(
780
+ self.model_outputs, sample=sample, noise=noise)
781
+ else:
782
+ prev_sample = self.multistep_dpm_solver_third_order_update(
783
+ self.model_outputs, sample=sample)
784
+
785
+ if self.lower_order_nums < self.config.solver_order:
786
+ self.lower_order_nums += 1
787
+
788
+ # Cast sample back to expected dtype
789
+ prev_sample = prev_sample.to(model_output.dtype)
790
+
791
+ # upon completion increase step index by one
792
+ self._step_index += 1 # pyright: ignore
793
+
794
+ if not return_dict:
795
+ return (prev_sample,)
796
+
797
+ return SchedulerOutput(prev_sample=prev_sample)
798
+
799
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
800
+ def scale_model_input(self, sample: torch.Tensor, *args,
801
+ **kwargs) -> torch.Tensor:
802
+ """
803
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
804
+ current timestep.
805
+ Args:
806
+ sample (`torch.Tensor`):
807
+ The input sample.
808
+ Returns:
809
+ `torch.Tensor`:
810
+ A scaled input sample.
811
+ """
812
+ return sample
813
+
814
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
815
+ def add_noise(
816
+ self,
817
+ original_samples: torch.Tensor,
818
+ noise: torch.Tensor,
819
+ timesteps: torch.IntTensor,
820
+ ) -> torch.Tensor:
821
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
822
+ sigmas = self.sigmas.to(
823
+ device=original_samples.device, dtype=original_samples.dtype)
824
+ if original_samples.device.type == "mps" and torch.is_floating_point(
825
+ timesteps):
826
+ # mps does not support float64
827
+ schedule_timesteps = self.timesteps.to(
828
+ original_samples.device, dtype=torch.float32)
829
+ timesteps = timesteps.to(
830
+ original_samples.device, dtype=torch.float32)
831
+ else:
832
+ schedule_timesteps = self.timesteps.to(original_samples.device)
833
+ timesteps = timesteps.to(original_samples.device)
834
+
835
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
836
+ if self.begin_index is None:
837
+ step_indices = [
838
+ self.index_for_timestep(t, schedule_timesteps)
839
+ for t in timesteps
840
+ ]
841
+ elif self.step_index is not None:
842
+ # add_noise is called after first denoising step (for inpainting)
843
+ step_indices = [self.step_index] * timesteps.shape[0]
844
+ else:
845
+ # add noise is called before first denoising step to create initial latent(img2img)
846
+ step_indices = [self.begin_index] * timesteps.shape[0]
847
+
848
+ sigma = sigmas[step_indices].flatten()
849
+ while len(sigma.shape) < len(original_samples.shape):
850
+ sigma = sigma.unsqueeze(-1)
851
+
852
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
853
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
854
+ return noisy_samples
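Because _sigma_to_alpha_sigma_t returns (1 - sigma, sigma), add_noise reduces to the straight-line interpolation used by flow matching; for example, at a timestep whose sigma is 0.25:

    # noisy_samples == 0.75 * original_samples + 0.25 * noise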
855
+
856
+ def __len__(self):
857
+ return self.config.num_train_timesteps
858
+
859
+
860
+ class FlowMatchScheduler():
861
+
862
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
863
+ self.num_train_timesteps = num_train_timesteps
864
+ self.shift = shift
865
+ self.sigma_max = sigma_max
866
+ self.sigma_min = sigma_min
867
+ self.inverse_timesteps = inverse_timesteps
868
+ self.extra_one_step = extra_one_step
869
+ self.reverse_sigmas = reverse_sigmas
870
+ self.set_timesteps(num_inference_steps)
871
+
872
+
873
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
874
+ if shift is not None:
875
+ self.shift = shift
876
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
877
+ if self.extra_one_step:
878
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
879
+ else:
880
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
881
+ if self.inverse_timesteps:
882
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
883
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
884
+ if self.reverse_sigmas:
885
+ self.sigmas = 1 - self.sigmas
886
+ self.timesteps = self.sigmas * self.num_train_timesteps
887
+ if training:
888
+ x = self.timesteps
889
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
890
+ y_shifted = y - y.min()
891
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
892
+ self.linear_timesteps_weights = bsmntw_weighing
893
+
894
+
895
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
896
+ if isinstance(timestep, torch.Tensor):
897
+ timestep = timestep.cpu()
898
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
899
+ sigma = self.sigmas[timestep_id]
900
+ if to_final or timestep_id + 1 >= len(self.timesteps):
901
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
902
+ else:
903
+ sigma_ = self.sigmas[timestep_id + 1]
904
+ prev_sample = sample + model_output * (sigma_ - sigma)
905
+ return prev_sample
906
+
907
+
908
+ def return_to_timestep(self, timestep, sample, sample_stablized):
909
+ if isinstance(timestep, torch.Tensor):
910
+ timestep = timestep.cpu()
911
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
912
+ sigma = self.sigmas[timestep_id]
913
+ model_output = (sample - sample_stablized) / sigma
914
+ return model_output
915
+
916
+
917
+ def add_noise(self, original_samples, noise, timestep):
918
+ if isinstance(timestep, torch.Tensor):
919
+ timestep = timestep.cpu()
920
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
921
+ sigma = self.sigmas[timestep_id]
922
+ sample = (1 - sigma) * original_samples + sigma * noise
923
+ return sample
924
+
925
+
926
+ def training_target(self, sample, noise, timestep):
927
+ target = noise - sample
928
+ return target
929
+
930
+
931
+ def training_weight(self, timestep):
932
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
933
+ weights = self.linear_timesteps_weights[timestep_id]
934
+ return weights
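FlowMatchScheduler is not referenced by wan/text2video.py in this upload; a minimal Euler-style loop with it might look like the sketch below, where the latent shape and the zero velocity are stand-ins for a real model call:

    import torch

    sched = FlowMatchScheduler(num_inference_steps=25, shift=5.0, extra_one_step=True)
    x = torch.randn(1, 16, 8, 8)               # stand-in latent
    for t in sched.timesteps:
        v = torch.zeros_like(x)                # stand-in velocity prediction
        x = sched.step(v, t, x)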
wan/utils/fm_solvers_unipc.py ADDED
@@ -0,0 +1,803 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
+ # Convert unipc for flow matching
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
12
+ SchedulerMixin,
13
+ SchedulerOutput)
14
+ from diffusers.utils import deprecate, is_scipy_available
15
+
16
+ if is_scipy_available():
17
+ import scipy.stats
18
+
19
+
20
+ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
21
+ """
22
+ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
23
+
24
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
25
+ methods the library implements for all schedulers such as loading and saving.
26
+
27
+ Args:
28
+ num_train_timesteps (`int`, defaults to 1000):
29
+ The number of diffusion steps to train the model.
30
+ solver_order (`int`, default `2`):
31
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
32
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
33
+ unconditional sampling.
34
+ prediction_type (`str`, defaults to "flow_prediction"):
35
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
36
+ the flow of the diffusion process.
37
+ thresholding (`bool`, defaults to `False`):
38
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
39
+ as Stable Diffusion.
40
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
41
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
42
+ sample_max_value (`float`, defaults to 1.0):
43
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
44
+ predict_x0 (`bool`, defaults to `True`):
45
+ Whether to use the updating algorithm on the predicted x0.
46
+ solver_type (`str`, default `bh2`):
47
+ Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
48
+ otherwise.
49
+ lower_order_final (`bool`, default `True`):
50
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
51
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
52
+ disable_corrector (`list`, default `[]`):
53
+ Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
54
+ and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
55
+ usually disabled during the first few steps.
56
+ solver_p (`SchedulerMixin`, default `None`):
57
+ Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
58
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
59
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
60
+ the sigmas are determined according to a sequence of noise levels {σi}.
61
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
62
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
63
+ timestep_spacing (`str`, defaults to `"linspace"`):
64
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
65
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
66
+ steps_offset (`int`, defaults to 0):
67
+ An offset added to the inference steps, as required by some model families.
68
+ final_sigmas_type (`str`, defaults to `"zero"`):
69
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
70
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
71
+ """
72
+
73
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
74
+ order = 1
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ num_train_timesteps: int = 1000,
80
+ solver_order: int = 2,
81
+ prediction_type: str = "flow_prediction",
82
+ shift: Optional[float] = 1.0,
83
+ use_dynamic_shifting=False,
84
+ thresholding: bool = False,
85
+ dynamic_thresholding_ratio: float = 0.995,
86
+ sample_max_value: float = 1.0,
87
+ predict_x0: bool = True,
88
+ solver_type: str = "bh2",
89
+ lower_order_final: bool = True,
90
+ disable_corrector: List[int] = [],
91
+ solver_p: SchedulerMixin = None,
92
+ timestep_spacing: str = "linspace",
93
+ steps_offset: int = 0,
94
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
95
+ ):
96
+
97
+ if solver_type not in ["bh1", "bh2"]:
98
+ if solver_type in ["midpoint", "heun", "logrho"]:
99
+ self.register_to_config(solver_type="bh2")
100
+ else:
101
+ raise NotImplementedError(
102
+ f"{solver_type} is not implemented for {self.__class__}")
103
+
104
+ self.predict_x0 = predict_x0
105
+ # setable values
106
+ self.num_inference_steps = None
107
+ alphas = np.linspace(1, 1 / num_train_timesteps,
108
+ num_train_timesteps)[::-1].copy()
109
+ sigmas = 1.0 - alphas
110
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
111
+
112
+ if not use_dynamic_shifting:
113
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
114
+ sigmas = shift * sigmas / (1 +
115
+ (shift - 1) * sigmas) # pyright: ignore
116
+
117
+ self.sigmas = sigmas
118
+ self.timesteps = sigmas * num_train_timesteps
119
+
120
+ self.model_outputs = [None] * solver_order
121
+ self.timestep_list = [None] * solver_order
122
+ self.lower_order_nums = 0
123
+ self.disable_corrector = disable_corrector
124
+ self.solver_p = solver_p
125
+ self.last_sample = None
126
+ self._step_index = None
127
+ self._begin_index = None
128
+
129
+ self.sigmas = self.sigmas.to(
130
+ "cpu") # to avoid too much CPU/GPU communication
131
+ self.sigma_min = self.sigmas[-1].item()
132
+ self.sigma_max = self.sigmas[0].item()
133
+
134
+ @property
135
+ def step_index(self):
136
+ """
137
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
138
+ """
139
+ return self._step_index
140
+
141
+ @property
142
+ def begin_index(self):
143
+ """
144
+ The index of the first timestep. It should be set from the pipeline with the `set_begin_index` method.
145
+ """
146
+ return self._begin_index
147
+
148
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
149
+ def set_begin_index(self, begin_index: int = 0):
150
+ """
151
+ Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
152
+
153
+ Args:
154
+ begin_index (`int`):
155
+ The begin index for the scheduler.
156
+ """
157
+ self._begin_index = begin_index
158
+
159
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
160
+ def set_timesteps(
161
+ self,
162
+ num_inference_steps: Union[int, None] = None,
163
+ device: Union[str, torch.device] = None,
164
+ sigmas: Optional[List[float]] = None,
165
+ mu: Optional[Union[float, None]] = None,
166
+ shift: Optional[Union[float, None]] = None,
167
+ ):
168
+ """
169
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
170
+ Args:
171
+ num_inference_steps (`int`):
172
+ The number of diffusion steps used when generating samples with a pre-trained model.
173
+ device (`str` or `torch.device`, *optional*):
174
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
175
+ """
176
+
177
+ if self.config.use_dynamic_shifting and mu is None:
178
+ raise ValueError(
179
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
180
+ )
181
+
182
+ if sigmas is None:
183
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
184
+ num_inference_steps +
185
+ 1).copy()[:-1] # pyright: ignore
186
+
187
+ if self.config.use_dynamic_shifting:
188
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
189
+ else:
190
+ if shift is None:
191
+ shift = self.config.shift
192
+ sigmas = shift * sigmas / (1 +
193
+ (shift - 1) * sigmas) # pyright: ignore
194
+
195
+ if self.config.final_sigmas_type == "sigma_min":
196
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
197
+ self.alphas_cumprod[0])**0.5
198
+ elif self.config.final_sigmas_type == "zero":
199
+ sigma_last = 0
200
+ else:
201
+ raise ValueError(
202
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
203
+ )
204
+
205
+ timesteps = sigmas * self.config.num_train_timesteps
206
+ sigmas = np.concatenate([sigmas, [sigma_last]
207
+ ]).astype(np.float32) # pyright: ignore
208
+
209
+ self.sigmas = torch.from_numpy(sigmas)
210
+ self.timesteps = torch.from_numpy(timesteps).to(
211
+ device=device, dtype=torch.int64)
212
+
213
+ self.num_inference_steps = len(timesteps)
214
+
215
+ self.model_outputs = [
216
+ None,
217
+ ] * self.config.solver_order
218
+ self.lower_order_nums = 0
219
+ self.last_sample = None
220
+ if self.solver_p:
221
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
222
+
223
+ # add an index counter for schedulers that allow duplicated timesteps
224
+ self._step_index = None
225
+ self._begin_index = None
226
+ self.sigmas = self.sigmas.to(
227
+ "cpu") # to avoid too much CPU/GPU communication
228
+
229
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
230
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
231
+ """
232
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
233
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
234
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
235
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
236
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
237
+
238
+ https://arxiv.org/abs/2205.11487
239
+ """
240
+ dtype = sample.dtype
241
+ batch_size, channels, *remaining_dims = sample.shape
242
+
243
+ if dtype not in (torch.float32, torch.float64):
244
+ sample = sample.float(
245
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
246
+
247
+ # Flatten sample for doing quantile calculation along each image
248
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
249
+
250
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
251
+
252
+ s = torch.quantile(
253
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
254
+ s = torch.clamp(
255
+ s, min=1, max=self.config.sample_max_value
256
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
257
+ s = s.unsqueeze(
258
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
259
+ sample = torch.clamp(
260
+ sample, -s, s
261
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
262
+
263
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
264
+ sample = sample.to(dtype)
265
+
266
+ return sample
267
+
268
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
269
+ def _sigma_to_t(self, sigma):
270
+ return sigma * self.config.num_train_timesteps
271
+
272
+ def _sigma_to_alpha_sigma_t(self, sigma):
273
+ return 1 - sigma, sigma
274
+
275
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
276
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
277
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
278
+
279
+ def convert_model_output(
280
+ self,
281
+ model_output: torch.Tensor,
282
+ *args,
283
+ sample: torch.Tensor = None,
284
+ **kwargs,
285
+ ) -> torch.Tensor:
286
+ r"""
287
+ Convert the model output to the corresponding type the UniPC algorithm needs.
288
+
289
+ Args:
290
+ model_output (`torch.Tensor`):
291
+ The direct output from the learned diffusion model.
292
+ timestep (`int`):
293
+ The current discrete timestep in the diffusion chain.
294
+ sample (`torch.Tensor`):
295
+ A current instance of a sample created by the diffusion process.
296
+
297
+ Returns:
298
+ `torch.Tensor`:
299
+ The converted model output.
300
+ """
301
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
302
+ if sample is None:
303
+ if len(args) > 1:
304
+ sample = args[1]
305
+ else:
306
+ raise ValueError(
307
+ "missing `sample` as a required keyward argument")
308
+ if timestep is not None:
309
+ deprecate(
310
+ "timesteps",
311
+ "1.0.0",
312
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
313
+ )
314
+
315
+ sigma = self.sigmas[self.step_index]
316
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
317
+
318
+ print("sigma_t ==>", self.step_index, sigma, sigma_t, alpha_t, sample.shape, model_output.shape)
319
+ if self.predict_x0:
320
+ if self.config.prediction_type == "flow_prediction":
321
+ sigma_t = self.sigmas[self.step_index]
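+ # flow matching forward process: x_t = (1 - sigma_t) * x0 + sigma_t * noise, and the model
+ # predicts the velocity v = noise - x0, hence x0 = x_t - sigma_t * v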
322
+ x0_pred = sample - sigma_t * model_output
323
+ else:
324
+ raise ValueError(
325
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
326
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
327
+ )
328
+
329
+ if self.config.thresholding:
330
+ x0_pred = self._threshold_sample(x0_pred)
331
+ print("self.config.thresholding", self.config.thresholding)
332
+ return x0_pred
333
+ else:
334
+ if self.config.prediction_type == "flow_prediction":
335
+ sigma_t = self.sigmas[self.step_index]
336
+ epsilon = sample - (1 - sigma_t) * model_output
337
+ else:
338
+ raise ValueError(
339
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
340
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
341
+ )
342
+
343
+ if self.config.thresholding:
344
+ sigma_t = self.sigmas[self.step_index]
345
+ x0_pred = sample - sigma_t * model_output
346
+ x0_pred = self._threshold_sample(x0_pred)
347
+ epsilon = model_output + x0_pred
348
+
349
+ return epsilon
350
+
351
+ def multistep_uni_p_bh_update(
352
+ self,
353
+ model_output: torch.Tensor,
354
+ *args,
355
+ sample: torch.Tensor = None,
356
+ order: int = None, # pyright: ignore
357
+ **kwargs,
358
+ ) -> torch.Tensor:
359
+ """
360
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
361
+
362
+ Args:
363
+ model_output (`torch.Tensor`):
364
+ The direct output from the learned diffusion model at the current timestep.
365
+ prev_timestep (`int`):
366
+ The previous discrete timestep in the diffusion chain.
367
+ sample (`torch.Tensor`):
368
+ A current instance of a sample created by the diffusion process.
369
+ order (`int`):
370
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
371
+
372
+ Returns:
373
+ `torch.Tensor`:
374
+ The sample tensor at the previous timestep.
375
+ """
376
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
377
+ "prev_timestep", None)
378
+ if sample is None:
379
+ if len(args) > 1:
380
+ sample = args[1]
381
+ else:
382
+ raise ValueError(
383
+ " missing `sample` as a required keyward argument")
384
+ if order is None:
385
+ if len(args) > 2:
386
+ order = args[2]
387
+ else:
388
+ raise ValueError(
389
+ " missing `order` as a required keyward argument")
390
+ if prev_timestep is not None:
391
+ deprecate(
392
+ "prev_timestep",
393
+ "1.0.0",
394
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
395
+ )
396
+ model_output_list = self.model_outputs
397
+
398
+ s0 = self.timestep_list[-1]
399
+ m0 = model_output_list[-1]
400
+ x = sample
401
+
402
+ if self.solver_p:
403
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
404
+ return x_t
405
+
406
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
407
+ self.step_index] # pyright: ignore
408
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
409
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
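+ # lambda = log(alpha) - log(sigma) is the half log-SNR; the UniPC update is expressed in h = lambda_t - lambda_s0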
410
+
411
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
412
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
413
+
414
+ h = lambda_t - lambda_s0
415
+ device = sample.device
416
+
417
+ rks = []
418
+ D1s = []
419
+ for i in range(1, order):
420
+ si = self.step_index - i # pyright: ignore
421
+ mi = model_output_list[-(i + 1)]
422
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
423
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
424
+ rk = (lambda_si - lambda_s0) / h
425
+ rks.append(rk)
426
+ D1s.append((mi - m0) / rk) # pyright: ignore
427
+
428
+ rks.append(1.0)
429
+ rks = torch.tensor(rks, device=device)
430
+
431
+ R = []
432
+ b = []
433
+
434
+ hh = -h if self.predict_x0 else h
435
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
436
+ h_phi_k = h_phi_1 / hh - 1
437
+
438
+ factorial_i = 1
439
+
440
+ if self.config.solver_type == "bh1":
441
+ B_h = hh
442
+ elif self.config.solver_type == "bh2":
443
+ B_h = torch.expm1(hh)
444
+ else:
445
+ raise NotImplementedError()
446
+
447
+ for i in range(1, order + 1):
448
+ R.append(torch.pow(rks, i - 1))
449
+ b.append(h_phi_k * factorial_i / B_h)
450
+ factorial_i *= i + 1
451
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
452
+
453
+ R = torch.stack(R)
454
+ b = torch.tensor(b, device=device)
455
+
456
+ if len(D1s) > 0:
457
+ D1s = torch.stack(D1s, dim=1) # (B, K)
458
+ # for order 2, we use a simplified version
459
+ if order == 2:
460
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
461
+ else:
462
+ rhos_p = torch.linalg.solve(R[:-1, :-1],
463
+ b[:-1]).to(device).to(x.dtype)
464
+ else:
465
+ D1s = None
466
+
467
+ if self.predict_x0:
468
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
469
+ if D1s is not None:
470
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
471
+ D1s) # pyright: ignore
472
+ else:
473
+ pred_res = 0
474
+ x_t = x_t_ - alpha_t * B_h * pred_res
475
+ else:
476
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
477
+ if D1s is not None:
478
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
479
+ D1s) # pyright: ignore
480
+ else:
481
+ pred_res = 0
482
+ x_t = x_t_ - sigma_t * B_h * pred_res
483
+
484
+ x_t = x_t.to(x.dtype)
485
+ return x_t
486
+
487
+ def multistep_uni_c_bh_update(
488
+ self,
489
+ this_model_output: torch.Tensor,
490
+ *args,
491
+ last_sample: torch.Tensor = None,
492
+ this_sample: torch.Tensor = None,
493
+ order: int = None, # pyright: ignore
494
+ **kwargs,
495
+ ) -> torch.Tensor:
496
+ """
497
+ One step for the UniC (B(h) version).
498
+
499
+ Args:
500
+ this_model_output (`torch.Tensor`):
501
+ The model outputs at `x_t`.
502
+ this_timestep (`int`):
503
+ The current timestep `t`.
504
+ last_sample (`torch.Tensor`):
505
+ The generated sample before the last predictor `x_{t-1}`.
506
+ this_sample (`torch.Tensor`):
507
+ The generated sample after the last predictor `x_{t}`.
508
+ order (`int`):
509
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
510
+
511
+ Returns:
512
+ `torch.Tensor`:
513
+ The corrected sample tensor at the current timestep.
514
+ """
515
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop(
516
+ "this_timestep", None)
517
+ if last_sample is None:
518
+ if len(args) > 1:
519
+ last_sample = args[1]
520
+ else:
521
+ raise ValueError(
522
+ " missing`last_sample` as a required keyward argument")
523
+ if this_sample is None:
524
+ if len(args) > 2:
525
+ this_sample = args[2]
526
+ else:
527
+ raise ValueError(
528
+ " missing`this_sample` as a required keyward argument")
529
+ if order is None:
530
+ if len(args) > 3:
531
+ order = args[3]
532
+ else:
533
+ raise ValueError(
534
+ " missing`order` as a required keyward argument")
535
+ if this_timestep is not None:
536
+ deprecate(
537
+ "this_timestep",
538
+ "1.0.0",
539
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
540
+ )
541
+
542
+ model_output_list = self.model_outputs
543
+
544
+ m0 = model_output_list[-1]
545
+ x = last_sample
546
+ x_t = this_sample
547
+ model_t = this_model_output
548
+
549
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
550
+ self.step_index - 1] # pyright: ignore
551
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
552
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
553
+
554
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
555
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
556
+
557
+ h = lambda_t - lambda_s0
558
+ device = this_sample.device
559
+
560
+ rks = []
561
+ D1s = []
562
+ for i in range(1, order):
563
+ si = self.step_index - (i + 1) # pyright: ignore
564
+ mi = model_output_list[-(i + 1)]
565
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
566
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
567
+ rk = (lambda_si - lambda_s0) / h
568
+ rks.append(rk)
569
+ D1s.append((mi - m0) / rk) # pyright: ignore
570
+
571
+ rks.append(1.0)
572
+ rks = torch.tensor(rks, device=device)
573
+
574
+ R = []
575
+ b = []
576
+
577
+ hh = -h if self.predict_x0 else h
578
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
579
+ h_phi_k = h_phi_1 / hh - 1
580
+
581
+ factorial_i = 1
582
+
583
+ if self.config.solver_type == "bh1":
584
+ B_h = hh
585
+ elif self.config.solver_type == "bh2":
586
+ B_h = torch.expm1(hh)
587
+ else:
588
+ raise NotImplementedError()
589
+
590
+ for i in range(1, order + 1):
591
+ R.append(torch.pow(rks, i - 1))
592
+ b.append(h_phi_k * factorial_i / B_h)
593
+ factorial_i *= i + 1
594
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
595
+
596
+ R = torch.stack(R)
597
+ b = torch.tensor(b, device=device)
598
+
599
+ if len(D1s) > 0:
600
+ D1s = torch.stack(D1s, dim=1)
601
+ else:
602
+ D1s = None
603
+
604
+ # for order 1, we use a simplified version
605
+ if order == 1:
606
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
607
+ else:
608
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
609
+
610
+ if self.predict_x0:
611
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
612
+ if D1s is not None:
613
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
614
+ else:
615
+ corr_res = 0
616
+ D1_t = model_t - m0
617
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
618
+ else:
619
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
620
+ if D1s is not None:
621
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
622
+ else:
623
+ corr_res = 0
624
+ D1_t = model_t - m0
625
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
626
+ x_t = x_t.to(x.dtype)
627
+ return x_t
628
+
629
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
630
+ if schedule_timesteps is None:
631
+ schedule_timesteps = self.timesteps
632
+
633
+ indices = (schedule_timesteps == timestep).nonzero()
634
+
635
+ # The sigma index that is taken for the **very** first `step`
636
+ # is always the second index (or the last index if there is only 1)
637
+ # This way we can ensure we don't accidentally skip a sigma in
638
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
639
+ pos = 1 if len(indices) > 1 else 0
640
+
641
+ return indices[pos].item()
642
+
643
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
644
+ def _init_step_index(self, timestep):
645
+ """
646
+ Initialize the step_index counter for the scheduler.
647
+ """
648
+
649
+ if self.begin_index is None:
650
+ if isinstance(timestep, torch.Tensor):
651
+ timestep = timestep.to(self.timesteps.device)
652
+ self._step_index = self.index_for_timestep(timestep)
653
+ else:
654
+ self._step_index = self._begin_index
655
+
656
+ def step(self,
657
+ model_output: torch.Tensor,
658
+ timestep: Union[int, torch.Tensor],
659
+ sample: torch.Tensor,
660
+ return_dict: bool = True,
661
+ generator=None) -> Union[SchedulerOutput, Tuple]:
662
+ """
663
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
664
+ the multistep UniPC.
665
+
666
+ Args:
667
+ model_output (`torch.Tensor`):
668
+ The direct output from learned diffusion model.
669
+ timestep (`int`):
670
+ The current discrete timestep in the diffusion chain.
671
+ sample (`torch.Tensor`):
672
+ A current instance of a sample created by the diffusion process.
673
+ return_dict (`bool`):
674
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
675
+
676
+ Returns:
677
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
678
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
679
+ tuple is returned where the first element is the sample tensor.
680
+
681
+ """
682
+ if self.num_inference_steps is None:
683
+ raise ValueError(
684
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
685
+ )
686
+
687
+ if self.step_index is None:
688
+ self._init_step_index(timestep)
689
+
690
+ print("self.step_index ==> ", self.step_index)
691
+
692
+ use_corrector = (
693
+ self.step_index > 0 and
694
+ self.step_index - 1 not in self.disable_corrector and
695
+ self.last_sample is not None # pyright: ignore
696
+ )
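+ # UniPC is a predictor-corrector scheme: the UniC corrector first refines the previous
+ # predictor result using the fresh model evaluation, then the UniP predictor advances the sample.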
697
+
698
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
699
+
700
+ if use_corrector:
701
+ sample = self.multistep_uni_c_bh_update(
702
+ this_model_output=model_output_convert,
703
+ last_sample=self.last_sample,
704
+ this_sample=sample,
705
+ order=self.this_order,
706
+ )
707
+
708
+ for i in range(self.config.solver_order - 1):
709
+ self.model_outputs[i] = self.model_outputs[i + 1]
710
+ self.timestep_list[i] = self.timestep_list[i + 1]
711
+
712
+ self.model_outputs[-1] = model_output_convert
713
+ self.timestep_list[-1] = timestep # pyright: ignore
714
+
715
+ if self.config.lower_order_final:
716
+ this_order = min(self.config.solver_order,
717
+ len(self.timesteps) -
718
+ self.step_index) # pyright: ignore
719
+ else:
720
+ this_order = self.config.solver_order
721
+
722
+ self.this_order = min(this_order,
723
+ self.lower_order_nums + 1) # warmup for multistep
724
+ assert self.this_order > 0
725
+
726
+ self.last_sample = sample
727
+ prev_sample = self.multistep_uni_p_bh_update(
728
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
729
+ sample=sample,
730
+ order=self.this_order,
731
+ )
732
+
733
+ if self.lower_order_nums < self.config.solver_order:
734
+ self.lower_order_nums += 1
735
+
736
+ # upon completion increase step index by one
737
+ self._step_index += 1 # pyright: ignore
738
+
739
+ if not return_dict:
740
+ return (prev_sample, model_output_convert)
741
+
742
+ return SchedulerOutput(prev_sample=prev_sample)
743
+
744
+ def scale_model_input(self, sample: torch.Tensor, *args,
745
+ **kwargs) -> torch.Tensor:
746
+ """
747
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
748
+ current timestep.
749
+
750
+ Args:
751
+ sample (`torch.Tensor`):
752
+ The input sample.
753
+
754
+ Returns:
755
+ `torch.Tensor`:
756
+ A scaled input sample.
757
+ """
758
+ return sample
759
+
760
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
761
+ def add_noise(
762
+ self,
763
+ original_samples: torch.Tensor,
764
+ noise: torch.Tensor,
765
+ timesteps: torch.IntTensor,
766
+ ) -> torch.Tensor:
767
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
768
+ sigmas = self.sigmas.to(
769
+ device=original_samples.device, dtype=original_samples.dtype)
770
+ if original_samples.device.type == "mps" and torch.is_floating_point(
771
+ timesteps):
772
+ # mps does not support float64
773
+ schedule_timesteps = self.timesteps.to(
774
+ original_samples.device, dtype=torch.float32)
775
+ timesteps = timesteps.to(
776
+ original_samples.device, dtype=torch.float32)
777
+ else:
778
+ schedule_timesteps = self.timesteps.to(original_samples.device)
779
+ timesteps = timesteps.to(original_samples.device)
780
+
781
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
782
+ if self.begin_index is None:
783
+ step_indices = [
784
+ self.index_for_timestep(t, schedule_timesteps)
785
+ for t in timesteps
786
+ ]
787
+ elif self.step_index is not None:
788
+ # add_noise is called after first denoising step (for inpainting)
789
+ step_indices = [self.step_index] * timesteps.shape[0]
790
+ else:
791
+ # add noise is called before first denoising step to create initial latent(img2img)
792
+ step_indices = [self.begin_index] * timesteps.shape[0]
793
+
794
+ sigma = sigmas[step_indices].flatten()
795
+ while len(sigma.shape) < len(original_samples.shape):
796
+ sigma = sigma.unsqueeze(-1)
797
+
798
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
799
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
800
+ return noisy_samples
801
+
802
+ def __len__(self):
803
+ return self.config.num_train_timesteps
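A minimal usage sketch for the scheduler above (not from the uploaded files): the dummy `denoiser`, the latent shape, and the `shift` value are assumptions for illustration; substitute the real Wan model and its configuration.

import torch

def denoiser(x, t):
    # stand-in for the real video diffusion model; it must return the predicted flow/velocity
    return torch.zeros_like(x)

scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=5.0)
scheduler.set_timesteps(num_inference_steps=50, device="cpu")

latents = torch.randn(1, 16, 21, 60, 104)  # assumed (B, C, T, H, W) latent layout
for t in scheduler.timesteps:
    velocity = denoiser(latents, t)                      # "flow_prediction" output
    latents = scheduler.step(velocity, t, latents).prev_sample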
wan/utils/prompt_extend.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import json
3
+ import math
4
+ import os
5
+ import random
6
+ import sys
7
+ import tempfile
8
+ from dataclasses import dataclass
9
+ from http import HTTPStatus
10
+ from typing import Optional, Union
11
+
12
+ import dashscope
13
+ import torch
14
+ from PIL import Image
15
+
16
+ try:
17
+ from flash_attn import flash_attn_varlen_func
18
+ FLASH_VER = 2
19
+ except ModuleNotFoundError:
20
+ flash_attn_varlen_func = None # for compatibility with CPU-only machines
21
+ FLASH_VER = None
22
+
23
+ LM_EN_SYS_PROMPT = "You are an advanced AI model tasked with generating and extending structured and detailed video captions. You must respond in the language used by the user."
24
+
25
+ @dataclass
26
+ class PromptOutput(object):
27
+ status: bool
28
+ prompt: str
29
+ seed: int
30
+ system_prompt: str
31
+ message: str
32
+
33
+ def add_custom_field(self, key: str, value) -> None:
34
+ self.__setattr__(key, value)
35
+
36
+
37
+ class PromptExpander:
38
+
39
+ def __init__(self, model_name, is_vl=False, device=0, **kwargs):
40
+ self.model_name = model_name
41
+ self.is_vl = is_vl
42
+ self.device = device
43
+
44
+ def extend_with_img(self,
45
+ prompt,
46
+ system_prompt,
47
+ image=None,
48
+ seed=-1,
49
+ *args,
50
+ **kwargs):
51
+ pass
52
+
53
+ def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
54
+ pass
55
+
56
+ def decide_system_prompt(self, tar_lang="en"):
57
+ return LM_EN_SYS_PROMPT
58
+
59
+ def __call__(self,
60
+ prompt,
61
+ tar_lang="en",
62
+ image=None,
63
+ seed=-1,
64
+ *args,
65
+ **kwargs):
66
+ system_prompt = self.decide_system_prompt(tar_lang=tar_lang)
67
+ if seed < 0:
68
+ seed = random.randint(0, sys.maxsize)
69
+ if image is not None and self.is_vl:
70
+ return self.extend_with_img(
71
+ prompt, system_prompt, image=image, seed=seed, *args, **kwargs)
72
+ elif not self.is_vl:
73
+ return self.extend(prompt, system_prompt, seed, *args, **kwargs)
74
+ else:
75
+ raise NotImplementedError
76
+
77
+
78
+ class QwenPromptExpander(PromptExpander):
79
+
80
+ def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
81
+ '''
82
+ Args:
83
+ model_name: Use predefined model names such as 'QwenVL2.5_7B' and 'Qwen2.5_14B',
84
+ which are specific versions of the Qwen model. Alternatively, you can use the
85
+ local path to a downloaded model or the model name from Hugging Face.
86
+ Detailed Breakdown:
87
+ Predefined Model Names:
88
+ * 'QwenVL2.5_7B' and 'Qwen2.5_14B' are specific versions of the Qwen model.
89
+ Local Path:
90
+ * You can provide the path to a model that you have downloaded locally.
91
+ Hugging Face Model Name:
92
+ * You can also specify the model name from Hugging Face's model hub.
93
+ is_vl: A flag indicating whether the task involves visual-language processing.
94
+ **kwargs: Additional keyword arguments that can be passed to the function or method.
95
+ '''
96
+ if model_name is None:
97
+ model_name = 'ZuluVision/MoviiGen1.1_Prompt_Rewriter'
98
+ super().__init__(model_name, is_vl, device, **kwargs)
99
+ self.model_name = model_name
100
+
101
+ if self.is_vl:
102
+ raise NotImplementedError("VL is not supported")
103
+
104
+ from transformers import AutoModelForCausalLM, AutoTokenizer
105
+ self.model = AutoModelForCausalLM.from_pretrained(
106
+ self.model_name,
107
+ torch_dtype=torch.float16
108
+ if "AWQ" in self.model_name else "auto",
109
+ attn_implementation="flash_attention_2"
110
+ if FLASH_VER == 2 else None,
111
+ device_map="cpu")
112
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
113
+
114
+ def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
115
+ self.model = self.model.to(self.device)
116
+ messages = [{
117
+ "role": "system",
118
+ "content": system_prompt
119
+ }, {
120
+ "role": "user",
121
+ "content": prompt
122
+ }]
123
+ text = self.tokenizer.apply_chat_template(
124
+ messages, tokenize=False, add_generation_prompt=True)
125
+ model_inputs = self.tokenizer([text],
126
+ return_tensors="pt").to(self.model.device)
127
+
128
+ generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
129
+ generated_ids = [
130
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(
131
+ model_inputs.input_ids, generated_ids)
132
+ ]
133
+
134
+ expanded_prompt = self.tokenizer.batch_decode(
135
+ generated_ids, skip_special_tokens=True)[0]
136
+ self.model = self.model.to("cpu")
137
+ return PromptOutput(
138
+ status=True,
139
+ prompt=expanded_prompt,
140
+ seed=seed,
141
+ system_prompt=system_prompt,
142
+ message=json.dumps({"content": expanded_prompt},
143
+ ensure_ascii=False))
144
+
145
+ def extend_with_img(self,
146
+ prompt,
147
+ system_prompt,
148
+ image: Union[Image.Image, str] = None,
149
+ seed=-1,
150
+ *args,
151
+ **kwargs):
152
+ self.model = self.model.to(self.device)
153
+ messages = [{
154
+ 'role': 'system',
155
+ 'content': [{
156
+ "type": "text",
157
+ "text": system_prompt
158
+ }]
159
+ }, {
160
+ "role":
161
+ "user",
162
+ "content": [
163
+ {
164
+ "type": "image",
165
+ "image": image,
166
+ },
167
+ {
168
+ "type": "text",
169
+ "text": prompt
170
+ },
171
+ ],
172
+ }]
173
+
174
+ # Preparation for inference
175
+ text = self.processor.apply_chat_template(
176
+ messages, tokenize=False, add_generation_prompt=True)
177
+ image_inputs, video_inputs = self.process_vision_info(messages)
178
+ inputs = self.processor(
179
+ text=[text],
180
+ images=image_inputs,
181
+ videos=video_inputs,
182
+ padding=True,
183
+ return_tensors="pt",
184
+ )
185
+ inputs = inputs.to(self.device)
186
+
187
+ # Inference: Generation of the output
188
+ generated_ids = self.model.generate(**inputs, max_new_tokens=512)
189
+ generated_ids_trimmed = [
190
+ out_ids[len(in_ids):]
191
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
192
+ ]
193
+ expanded_prompt = self.processor.batch_decode(
194
+ generated_ids_trimmed,
195
+ skip_special_tokens=True,
196
+ clean_up_tokenization_spaces=False)[0]
197
+ self.model = self.model.to("cpu")
198
+ return PromptOutput(
199
+ status=True,
200
+ prompt=expanded_prompt,
201
+ seed=seed,
202
+ system_prompt=system_prompt,
203
+ message=json.dumps({"content": expanded_prompt},
204
+ ensure_ascii=False))
205
+
206
+
207
+ if __name__ == "__main__":
208
+
209
+ seed = 100
210
+ prompt = "夏日海滩度假风格,一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松,表情悠闲,直视镜头。背景是模糊的海滩景色,海水清澈,远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松,仿佛在享受海风和阳光。近景特写,强调猫咪的细节和海滩的清新氛围。"
211
+ en_prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
212
+ # test cases for prompt extend
213
+ ds_model_name = "qwen-plus"
214
+ # for Qwen models, you can download the model from ModelScope or Hugging Face and use the local path as model_name
215
+ qwen_model_name = "./models/Qwen2.5-14B-Instruct/" # VRAM: 29136MiB
216
+ # qwen_model_name = "./models/Qwen2.5-14B-Instruct-AWQ/" # VRAM: 10414MiB
217
+
218
+ # test dashscope api
219
+ dashscope_prompt_expander = DashScopePromptExpander(
220
+ model_name=ds_model_name)
221
+ dashscope_result = dashscope_prompt_expander(prompt, tar_lang="ch")
222
+ print("LM dashscope result -> ch",
223
+ dashscope_result.prompt) #dashscope_result.system_prompt)
224
+ dashscope_result = dashscope_prompt_expander(prompt, tar_lang="en")
225
+ print("LM dashscope result -> en",
226
+ dashscope_result.prompt) #dashscope_result.system_prompt)
227
+ dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="ch")
228
+ print("LM dashscope en result -> ch",
229
+ dashscope_result.prompt) #dashscope_result.system_prompt)
230
+ dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="en")
231
+ print("LM dashscope en result -> en",
232
+ dashscope_result.prompt) #dashscope_result.system_prompt)
233
+ # # test qwen api
234
+ qwen_prompt_expander = QwenPromptExpander(
235
+ model_name=qwen_model_name, is_vl=False, device=0)
236
+ qwen_result = qwen_prompt_expander(prompt, tar_lang="ch")
237
+ print("LM qwen result -> ch",
238
+ qwen_result.prompt) #qwen_result.system_prompt)
239
+ qwen_result = qwen_prompt_expander(prompt, tar_lang="en")
240
+ print("LM qwen result -> en",
241
+ qwen_result.prompt) # qwen_result.system_prompt)
242
+ qwen_result = qwen_prompt_expander(en_prompt, tar_lang="ch")
243
+ print("LM qwen en result -> ch",
244
+ qwen_result.prompt) #, qwen_result.system_prompt)
245
+ qwen_result = qwen_prompt_expander(en_prompt, tar_lang="en")
246
+ print("LM qwen en result -> en",
247
+ qwen_result.prompt) # , qwen_result.system_prompt)
248
+ # test case for prompt-image extend
249
+ ds_model_name = "qwen-vl-max"
250
+ #qwen_model_name = "./models/Qwen2.5-VL-3B-Instruct/" #VRAM: 9686MiB
251
+ qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492
252
+ image = "./examples/i2v_input.JPG"
253
+
254
+ # test dashscope api; image_path is a local path here, so this test may need to be skipped
255
+ dashscope_prompt_expander = DashScopePromptExpander(
256
+ model_name=ds_model_name, is_vl=True)
257
+ dashscope_result = dashscope_prompt_expander(
258
+ prompt, tar_lang="ch", image=image, seed=seed)
259
+ print("VL dashscope result -> ch",
260
+ dashscope_result.prompt) #, dashscope_result.system_prompt)
261
+ dashscope_result = dashscope_prompt_expander(
262
+ prompt, tar_lang="en", image=image, seed=seed)
263
+ print("VL dashscope result -> en",
264
+ dashscope_result.prompt) # , dashscope_result.system_prompt)
265
+ dashscope_result = dashscope_prompt_expander(
266
+ en_prompt, tar_lang="ch", image=image, seed=seed)
267
+ print("VL dashscope en result -> ch",
268
+ dashscope_result.prompt) #, dashscope_result.system_prompt)
269
+ dashscope_result = dashscope_prompt_expander(
270
+ en_prompt, tar_lang="en", image=image, seed=seed)
271
+ print("VL dashscope en result -> en",
272
+ dashscope_result.prompt) # , dashscope_result.system_prompt)
273
+ # test qwen api
274
+ qwen_prompt_expander = QwenPromptExpander(
275
+ model_name=qwen_model_name, is_vl=True, device=0)
276
+ qwen_result = qwen_prompt_expander(
277
+ prompt, tar_lang="ch", image=image, seed=seed)
278
+ print("VL qwen result -> ch",
279
+ qwen_result.prompt) #, qwen_result.system_prompt)
280
+ qwen_result = qwen_prompt_expander(
281
+ prompt, tar_lang="en", image=image, seed=seed)
282
+ print("VL qwen result ->en",
283
+ qwen_result.prompt) # , qwen_result.system_prompt)
284
+ qwen_result = qwen_prompt_expander(
285
+ en_prompt, tar_lang="ch", image=image, seed=seed)
286
+ print("VL qwen vl en result -> ch",
287
+ qwen_result.prompt) #, qwen_result.system_prompt)
288
+ qwen_result = qwen_prompt_expander(
289
+ en_prompt, tar_lang="en", image=image, seed=seed)
290
+ print("VL qwen vl en result -> en",
291
+ qwen_result.prompt) # , qwen_result.system_prompt)
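A minimal sketch of driving the prompt expander above (not from the uploaded files), assuming a local Qwen2.5-Instruct checkpoint (any local path or Hugging Face model id should work) and a CUDA device for `device=0`.

expander = QwenPromptExpander(
    model_name="./models/Qwen2.5-14B-Instruct/", is_vl=False, device=0)
result = expander("a white cat surfing at the beach", tar_lang="en", seed=42)
if result.status:
    print(result.prompt)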
wan/utils/qwen_vl_utils.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copied from https://github.com/kq-chen/qwen-vl-utils
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import logging
7
+ import math
8
+ import os
9
+ import sys
10
+ import time
11
+ import warnings
12
+ from functools import lru_cache
13
+ from io import BytesIO
14
+
15
+ import requests
16
+ import torch
17
+ import torchvision
18
+ from packaging import version
19
+ from PIL import Image
20
+ from torchvision import io, transforms
21
+ from torchvision.transforms import InterpolationMode
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ IMAGE_FACTOR = 28
26
+ MIN_PIXELS = 4 * 28 * 28
27
+ MAX_PIXELS = 16384 * 28 * 28
28
+ MAX_RATIO = 200
29
+
30
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
31
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
32
+ VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
33
+ FRAME_FACTOR = 2
34
+ FPS = 2.0
35
+ FPS_MIN_FRAMES = 4
36
+ FPS_MAX_FRAMES = 768
37
+
38
+
39
+ def round_by_factor(number: int, factor: int) -> int:
40
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
41
+ return round(number / factor) * factor
42
+
43
+
44
+ def ceil_by_factor(number: int, factor: int) -> int:
45
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
46
+ return math.ceil(number / factor) * factor
47
+
48
+
49
+ def floor_by_factor(number: int, factor: int) -> int:
50
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
51
+ return math.floor(number / factor) * factor
52
+
53
+
54
+ def smart_resize(height: int,
55
+ width: int,
56
+ factor: int = IMAGE_FACTOR,
57
+ min_pixels: int = MIN_PIXELS,
58
+ max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
59
+ """
60
+ Rescales the image so that the following conditions are met:
61
+
62
+ 1. Both dimensions (height and width) are divisible by 'factor'.
63
+
64
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
65
+
66
+ 3. The aspect ratio of the image is maintained as closely as possible.
67
+ """
68
+ if max(height, width) / min(height, width) > MAX_RATIO:
69
+ raise ValueError(
70
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
71
+ )
72
+ h_bar = max(factor, round_by_factor(height, factor))
73
+ w_bar = max(factor, round_by_factor(width, factor))
74
+ if h_bar * w_bar > max_pixels:
75
+ beta = math.sqrt((height * width) / max_pixels)
76
+ h_bar = floor_by_factor(height / beta, factor)
77
+ w_bar = floor_by_factor(width / beta, factor)
78
+ elif h_bar * w_bar < min_pixels:
79
+ beta = math.sqrt(min_pixels / (height * width))
80
+ h_bar = ceil_by_factor(height * beta, factor)
81
+ w_bar = ceil_by_factor(width * beta, factor)
82
+ return h_bar, w_bar
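+ # e.g. smart_resize(1080, 1920) -> (1092, 1932): both sides divisible by 28 and the pixel count stays within [MIN_PIXELS, MAX_PIXELS]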
83
+
84
+
85
+ def fetch_image(ele: dict[str, str | Image.Image],
86
+ size_factor: int = IMAGE_FACTOR) -> Image.Image:
87
+ if "image" in ele:
88
+ image = ele["image"]
89
+ else:
90
+ image = ele["image_url"]
91
+ image_obj = None
92
+ if isinstance(image, Image.Image):
93
+ image_obj = image
94
+ elif image.startswith("http://") or image.startswith("https://"):
95
+ image_obj = Image.open(requests.get(image, stream=True).raw)
96
+ elif image.startswith("file://"):
97
+ image_obj = Image.open(image[7:])
98
+ elif image.startswith("data:image"):
99
+ if "base64," in image:
100
+ _, base64_data = image.split("base64,", 1)
101
+ data = base64.b64decode(base64_data)
102
+ image_obj = Image.open(BytesIO(data))
103
+ else:
104
+ image_obj = Image.open(image)
105
+ if image_obj is None:
106
+ raise ValueError(
107
+ f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
108
+ )
109
+ image = image_obj.convert("RGB")
110
+ ## resize
111
+ if "resized_height" in ele and "resized_width" in ele:
112
+ resized_height, resized_width = smart_resize(
113
+ ele["resized_height"],
114
+ ele["resized_width"],
115
+ factor=size_factor,
116
+ )
117
+ else:
118
+ width, height = image.size
119
+ min_pixels = ele.get("min_pixels", MIN_PIXELS)
120
+ max_pixels = ele.get("max_pixels", MAX_PIXELS)
121
+ resized_height, resized_width = smart_resize(
122
+ height,
123
+ width,
124
+ factor=size_factor,
125
+ min_pixels=min_pixels,
126
+ max_pixels=max_pixels,
127
+ )
128
+ image = image.resize((resized_width, resized_height))
129
+
130
+ return image
131
+
132
+
133
+ def smart_nframes(
134
+ ele: dict,
135
+ total_frames: int,
136
+ video_fps: int | float,
137
+ ) -> int:
138
+ """calculate the number of frames for video used for model inputs.
139
+
140
+ Args:
141
+ ele (dict): a dict containing the video configuration.
142
+ support either `fps` or `nframes`:
143
+ - nframes: the number of frames to extract for model inputs.
144
+ - fps: the fps to extract frames for model inputs.
145
+ - min_frames: the minimum number of frames of the video, only used when fps is provided.
146
+ - max_frames: the maximum number of frames of the video, only used when fps is provided.
147
+ total_frames (int): the original total number of frames of the video.
148
+ video_fps (int | float): the original fps of the video.
149
+
150
+ Raises:
151
+ ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
152
+
153
+ Returns:
154
+ int: the number of frames for video used for model inputs.
155
+ """
156
+ assert not ("fps" in ele and
157
+ "nframes" in ele), "Only accept either `fps` or `nframes`"
158
+ if "nframes" in ele:
159
+ nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
160
+ else:
161
+ fps = ele.get("fps", FPS)
162
+ min_frames = ceil_by_factor(
163
+ ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
164
+ max_frames = floor_by_factor(
165
+ ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
166
+ FRAME_FACTOR)
167
+ nframes = total_frames / video_fps * fps
168
+ nframes = min(max(nframes, min_frames), max_frames)
169
+ nframes = round_by_factor(nframes, FRAME_FACTOR)
170
+ if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
171
+ raise ValueError(
172
+ f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
173
+ )
174
+ return nframes
175
+
176
+
177
+ def _read_video_torchvision(ele: dict,) -> torch.Tensor:
178
+ """read video using torchvision.io.read_video
179
+
180
+ Args:
181
+ ele (dict): a dict containing the video configuration.
182
+ support keys:
183
+ - video: the path of video. support "file://", "http://", "https://" and local path.
184
+ - video_start: the start time of video.
185
+ - video_end: the end time of video.
186
+ Returns:
187
+ torch.Tensor: the video tensor with shape (T, C, H, W).
188
+ """
189
+ video_path = ele["video"]
190
+ if version.parse(torchvision.__version__) < version.parse("0.19.0"):
191
+ if "http://" in video_path or "https://" in video_path:
192
+ warnings.warn(
193
+ "torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0."
194
+ )
195
+ if "file://" in video_path:
196
+ video_path = video_path[7:]
197
+ st = time.time()
198
+ video, audio, info = io.read_video(
199
+ video_path,
200
+ start_pts=ele.get("video_start", 0.0),
201
+ end_pts=ele.get("video_end", None),
202
+ pts_unit="sec",
203
+ output_format="TCHW",
204
+ )
205
+ total_frames, video_fps = video.size(0), info["video_fps"]
206
+ logger.info(
207
+ f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
208
+ )
209
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
210
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long()
211
+ video = video[idx]
212
+ return video
213
+
214
+
215
+ def is_decord_available() -> bool:
216
+ import importlib.util
217
+
218
+ return importlib.util.find_spec("decord") is not None
219
+
220
+
221
+ def _read_video_decord(ele: dict,) -> torch.Tensor:
222
+ """read video using decord.VideoReader
223
+
224
+ Args:
225
+ ele (dict): a dict containing the video configuration.
226
+ support keys:
227
+ - video: the path of video. support "file://", "http://", "https://" and local path.
228
+ - video_start: the start time of video.
229
+ - video_end: the end time of video.
230
+ Returns:
231
+ torch.Tensor: the video tensor with shape (T, C, H, W).
232
+ """
233
+ import decord
234
+ video_path = ele["video"]
235
+ st = time.time()
236
+ vr = decord.VideoReader(video_path)
237
+ # TODO: support start_pts and end_pts
238
+ if 'video_start' in ele or 'video_end' in ele:
239
+ raise NotImplementedError(
240
+ "not support start_pts and end_pts in decord for now.")
241
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
242
+ logger.info(
243
+ f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
244
+ )
245
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
246
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
247
+ video = vr.get_batch(idx).asnumpy()
248
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
249
+ return video
250
+
251
+
252
+ VIDEO_READER_BACKENDS = {
253
+ "decord": _read_video_decord,
254
+ "torchvision": _read_video_torchvision,
255
+ }
256
+
257
+ FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
258
+
259
+
260
+ @lru_cache(maxsize=1)
261
+ def get_video_reader_backend() -> str:
262
+ if FORCE_QWENVL_VIDEO_READER is not None:
263
+ video_reader_backend = FORCE_QWENVL_VIDEO_READER
264
+ elif is_decord_available():
265
+ video_reader_backend = "decord"
266
+ else:
267
+ video_reader_backend = "torchvision"
268
+ print(
269
+ f"qwen-vl-utils using {video_reader_backend} to read video.",
270
+ file=sys.stderr)
271
+ return video_reader_backend
272
+
273
+
274
+ def fetch_video(
275
+ ele: dict,
276
+ image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
277
+ if isinstance(ele["video"], str):
278
+ video_reader_backend = get_video_reader_backend()
279
+ video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
280
+ nframes, _, height, width = video.shape
281
+
282
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
283
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
284
+ max_pixels = max(
285
+ min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
286
+ int(min_pixels * 1.05))
287
+ max_pixels = ele.get("max_pixels", max_pixels)
288
+ if "resized_height" in ele and "resized_width" in ele:
289
+ resized_height, resized_width = smart_resize(
290
+ ele["resized_height"],
291
+ ele["resized_width"],
292
+ factor=image_factor,
293
+ )
294
+ else:
295
+ resized_height, resized_width = smart_resize(
296
+ height,
297
+ width,
298
+ factor=image_factor,
299
+ min_pixels=min_pixels,
300
+ max_pixels=max_pixels,
301
+ )
302
+ video = transforms.functional.resize(
303
+ video,
304
+ [resized_height, resized_width],
305
+ interpolation=InterpolationMode.BICUBIC,
306
+ antialias=True,
307
+ ).float()
308
+ return video
309
+ else:
310
+ assert isinstance(ele["video"], (list, tuple))
311
+ process_info = ele.copy()
312
+ process_info.pop("type", None)
313
+ process_info.pop("video", None)
314
+ images = [
315
+ fetch_image({
316
+ "image": video_element,
317
+ **process_info
318
+ },
319
+ size_factor=image_factor)
320
+ for video_element in ele["video"]
321
+ ]
322
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
323
+ if len(images) < nframes:
324
+ images.extend([images[-1]] * (nframes - len(images)))
325
+ return images
326
+
327
+
328
+ def extract_vision_info(
329
+ conversations: list[dict] | list[list[dict]]) -> list[dict]:
330
+ vision_infos = []
331
+ if isinstance(conversations[0], dict):
332
+ conversations = [conversations]
333
+ for conversation in conversations:
334
+ for message in conversation:
335
+ if isinstance(message["content"], list):
336
+ for ele in message["content"]:
337
+ if ("image" in ele or "image_url" in ele or
338
+ "video" in ele or
339
+ ele["type"] in ("image", "image_url", "video")):
340
+ vision_infos.append(ele)
341
+ return vision_infos
342
+
343
+
344
+ def process_vision_info(
345
+ conversations: list[dict] | list[list[dict]],
346
+ ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
347
+ None]:
348
+ vision_infos = extract_vision_info(conversations)
349
+ ## Read images or videos
350
+ image_inputs = []
351
+ video_inputs = []
352
+ for vision_info in vision_infos:
353
+ if "image" in vision_info or "image_url" in vision_info:
354
+ image_inputs.append(fetch_image(vision_info))
355
+ elif "video" in vision_info:
356
+ video_inputs.append(fetch_video(vision_info))
357
+ else:
358
+ raise ValueError("image, image_url or video should in content.")
359
+ if len(image_inputs) == 0:
360
+ image_inputs = None
361
+ if len(video_inputs) == 0:
362
+ video_inputs = None
363
+ return image_inputs, video_inputs
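A minimal sketch for the vision helpers above (not from the uploaded files); the image path is hypothetical.

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "file:///path/to/frame.jpg"},
        {"type": "text", "text": "Describe this scene."},
    ],
}]
image_inputs, video_inputs = process_vision_info(messages)
# image_inputs: a one-element list of PIL images resized by smart_resize; video_inputs: None here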
wan/utils/utils.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import binascii
4
+ import os
5
+ import os.path as osp
6
+
7
+ import imageio
8
+ import torch
9
+ import torchvision
10
+
11
+ __all__ = ['cache_video', 'cache_image', 'str2bool']
12
+
13
+
14
+ def rand_name(length=8, suffix=''):
15
+ name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
16
+ if suffix:
17
+ if not suffix.startswith('.'):
18
+ suffix = '.' + suffix
19
+ name += suffix
20
+ return name
21
+
22
+
23
+ def cache_video(tensor,
24
+ save_file=None,
25
+ fps=30,
26
+ suffix='.mp4',
27
+ nrow=8,
28
+ normalize=True,
29
+ value_range=(-1, 1),
30
+ retry=5):
31
+ # cache file
32
+ cache_file = osp.join('/tmp', rand_name(
33
+ suffix=suffix)) if save_file is None else save_file
34
+
35
+ # save to cache
36
+ error = None
37
+ for _ in range(retry):
38
+ try:
39
+ # preprocess
40
+ tensor = tensor.clamp(min(value_range), max(value_range))
41
+ tensor = torch.stack([
42
+ torchvision.utils.make_grid(
43
+ u, nrow=nrow, normalize=normalize, value_range=value_range)
44
+ for u in tensor.unbind(2)
45
+ ],
46
+ dim=1).permute(1, 2, 3, 0)
47
+ tensor = (tensor * 255).type(torch.uint8).cpu()
48
+
49
+ # write video
50
+ writer = imageio.get_writer(
51
+ cache_file, fps=fps, codec='libx264', quality=8)
52
+ for frame in tensor.numpy():
53
+ writer.append_data(frame)
54
+ writer.close()
55
+ return cache_file
56
+ except Exception as e:
57
+ error = e
58
+ continue
59
+ else:
60
+ print(f'cache_video failed, error: {error}', flush=True)
61
+ return None
62
+
63
+
64
+ def cache_image(tensor,
65
+ save_file,
66
+ nrow=8,
67
+ normalize=True,
68
+ value_range=(-1, 1),
69
+ retry=5):
70
+ # cache file
71
+ suffix = osp.splitext(save_file)[1]
72
+ if suffix.lower() not in [
73
+ '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
74
+ ]:
75
+ suffix = '.png'
76
+
77
+ # save to cache
78
+ error = None
79
+ for _ in range(retry):
80
+ try:
81
+ tensor = tensor.clamp(min(value_range), max(value_range))
82
+ torchvision.utils.save_image(
83
+ tensor,
84
+ save_file,
85
+ nrow=nrow,
86
+ normalize=normalize,
87
+ value_range=value_range)
88
+ return save_file
89
+ except Exception as e:
90
+ error = e
91
+ continue
92
+
93
+
94
+ def str2bool(v):
95
+ """
96
+ Convert a string to a boolean.
97
+
98
+ Supported true values: 'yes', 'true', 't', 'y', '1'
99
+ Supported false values: 'no', 'false', 'f', 'n', '0'
100
+
101
+ Args:
102
+ v (str): String to convert.
103
+
104
+ Returns:
105
+ bool: Converted boolean value.
106
+
107
+ Raises:
108
+ argparse.ArgumentTypeError: If the value cannot be converted to boolean.
109
+ """
110
+ if isinstance(v, bool):
111
+ return v
112
+ v_lower = v.lower()
113
+ if v_lower in ('yes', 'true', 't', 'y', '1'):
114
+ return True
115
+ elif v_lower in ('no', 'false', 'f', 'n', '0'):
116
+ return False
117
+ else:
118
+ raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
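A minimal sketch for the helpers above (not from the uploaded files); the (C, T, H, W) layout and the [-1, 1] value range are assumptions about how the caller prepares the tensor.

import argparse
import torch

fake_video = torch.rand(3, 16, 64, 64) * 2 - 1  # assumed (C, T, H, W) tensor in [-1, 1]
cache_video(fake_video[None], save_file="demo.mp4", fps=16, nrow=1,
            normalize=True, value_range=(-1, 1))

parser = argparse.ArgumentParser()
parser.add_argument("--offload_model", type=str2bool, default=True)
args = parser.parse_args(["--offload_model", "false"])  # parsed as False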