import types |
|
import inspect |
|
from typing import Callable, Dict, List, Optional, Union, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
from transformers import T5EncoderModel, T5TokenizerFast |
|
|
|
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
|
from diffusers.image_processor import PipelineImageInput |
|
from diffusers.loaders import FromSingleFileMixin |
|
from diffusers.pipelines.ltx.pipeline_ltx_image2video import LTXImageToVideoPipeline |
|
from diffusers.models.autoencoders import AutoencoderKLLTXVideo |
|
from diffusers.models.transformers import LTXVideoTransformer3DModel |
|
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
|
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from diffusers.video_processor import VideoProcessor |
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
|
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput |
|
from diffusers.models.attention_processor import Attention |
|
from diffusers.models.transformers.transformer_ltx import apply_rotary_emb |
|
|
|
import torch.nn.functional as F |
|
|
|
if is_torch_xla_available(): |
|
import torch_xla.core.xla_model as xm |
|
|
|
XLA_AVAILABLE = True |
|
else: |
|
XLA_AVAILABLE = False |
|
|
|
logger = logging.get_logger(__name__) |
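
# A minimal usage sketch (assumptions, not part of this file: the "Lightricks/LTX-Video" checkpoint,
# `diffusers.utils.load_image` / `export_to_video`, a CUDA device, and an illustrative layer index):
#
#     pipe = LTXImageToVideoSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
#     pipe.to("cuda")
#     image = load_image("path/to/first_frame.png")
#     video = pipe(
#         image=image,
#         prompt="A waterfall cascading through a lush forest",
#         stg_mode="STG-R",
#         stg_applied_layers_idx=[19],
#         stg_scale=1.0,
#     ).frames[0]
#     export_to_video(video, "output.mp4", fps=24)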
|
|
|
def forward_with_stg( |
|
self, |
|
hidden_states: torch.Tensor, |
|
encoder_hidden_states: torch.Tensor, |
|
temb: torch.Tensor, |
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
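    # STG-R ("residual") perturbation: with classifier-free guidance and STG both active, the batch
    # is laid out as [unconditional, text-conditional, perturbed] (assuming a single prompt). The
    # perturbed rows are cached here and written back at the end, so this transformer block acts as
    # an identity mapping for them.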
|
|
|
hidden_states_ptb = hidden_states[2:] |
|
encoder_hidden_states_ptb = encoder_hidden_states[2:] |
|
|
|
batch_size = hidden_states.size(0) |
|
norm_hidden_states = self.norm1(hidden_states) |
|
|
|
num_ada_params = self.scale_shift_table.shape[0] |
|
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) |
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa |
|
|
|
attn_hidden_states = self.attn1( |
|
hidden_states=norm_hidden_states, |
|
encoder_hidden_states=None, |
|
image_rotary_emb=image_rotary_emb, |
|
) |
|
hidden_states = hidden_states + attn_hidden_states * gate_msa |
|
|
|
attn_hidden_states = self.attn2( |
|
hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
image_rotary_emb=None, |
|
attention_mask=encoder_attention_mask, |
|
) |
|
hidden_states = hidden_states + attn_hidden_states |
|
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp |
|
|
|
ff_output = self.ff(norm_hidden_states) |
|
hidden_states = hidden_states + ff_output * gate_mlp |
|
|
|
hidden_states[2:] = hidden_states_ptb |
|
encoder_hidden_states[2:] = encoder_hidden_states_ptb |
|
|
|
return hidden_states |
|
|
|
class STGLTXVideoAttentionProcessor2_0: |
|
    r"""
    Processor for implementing scaled dot-product attention with spatio-temporal guidance (STG-A), as used in the
    LTX model. It applies query/key normalization and rotary embeddings; for the perturbed portion of the batch,
    the attention output is replaced by the value projection (an identity attention map).
    """
|
|
|
def __init__(self): |
|
if not hasattr(F, "scaled_dot_product_attention"): |
|
            raise ImportError(
                "STGLTXVideoAttentionProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )
|
|
|
def __call__( |
|
self, |
|
attn: Attention, |
|
hidden_states: torch.Tensor, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
image_rotary_emb: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
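        # STG-A ("attention") perturbation: the batch is laid out as
        # [unconditional, text-conditional, perturbed]. The first two chunks run through regular
        # scaled dot-product attention; the perturbed chunk is handled further below with its
        # attention map replaced by the identity.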
|
|
|
hidden_states_uncond, hidden_states_text, hidden_states_perturb = hidden_states.chunk(3) |
|
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_text]) |
|
|
|
emb_sin, emb_cos = image_rotary_emb |
|
emb_sin_uncond, emb_sin_text, emb_sin_perturb = emb_sin.chunk(3) |
|
emb_cos_uncond, emb_cos_text, emb_cos_perturb = emb_cos.chunk(3) |
|
emb_sin_org = torch.cat([emb_sin_uncond, emb_sin_text]) |
|
emb_cos_org = torch.cat([emb_cos_uncond, emb_cos_text]) |
|
|
|
image_rotary_emb_org = (emb_sin_org, emb_cos_org) |
|
image_rotary_emb_perturb = (emb_sin_perturb, emb_cos_perturb) |
|
|
|
|
|
assert encoder_hidden_states is None |
|
batch_size, sequence_length, _ = ( |
|
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
|
) |
|
|
|
if attention_mask is not None: |
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) |
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states_org = hidden_states_org |
|
|
|
query_org = attn.to_q(hidden_states_org) |
|
key_org = attn.to_k(encoder_hidden_states_org) |
|
value_org = attn.to_v(encoder_hidden_states_org) |
|
|
|
query_org = attn.norm_q(query_org) |
|
key_org = attn.norm_k(key_org) |
|
|
|
if image_rotary_emb is not None: |
|
query_org = apply_rotary_emb(query_org, image_rotary_emb_org) |
|
key_org = apply_rotary_emb(key_org, image_rotary_emb_org) |
|
|
|
query_org = query_org.unflatten(2, (attn.heads, -1)).transpose(1, 2) |
|
key_org = key_org.unflatten(2, (attn.heads, -1)).transpose(1, 2) |
|
value_org = value_org.unflatten(2, (attn.heads, -1)).transpose(1, 2) |
|
|
|
hidden_states_org = F.scaled_dot_product_attention( |
|
query_org, key_org, value_org, attn_mask=attention_mask, dropout_p=0.0, is_causal=False |
|
) |
|
hidden_states_org = hidden_states_org.transpose(1, 2).flatten(2, 3) |
|
hidden_states_org = hidden_states_org.to(query_org.dtype) |
|
|
|
hidden_states_org = attn.to_out[0](hidden_states_org) |
|
hidden_states_org = attn.to_out[1](hidden_states_org) |
|
|
|
|
|
batch_size, sequence_length, _ = hidden_states_perturb.shape |
|
|
|
if attention_mask is not None: |
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) |
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states_perturb = hidden_states_perturb |
|
|
|
query_perturb = attn.to_q(hidden_states_perturb) |
|
key_perturb = attn.to_k(encoder_hidden_states_perturb) |
|
value_perturb = attn.to_v(encoder_hidden_states_perturb) |
|
|
|
query_perturb = attn.norm_q(query_perturb) |
|
key_perturb = attn.norm_k(key_perturb) |
|
|
|
if image_rotary_emb is not None: |
|
query_perturb = apply_rotary_emb(query_perturb, image_rotary_emb_perturb) |
|
key_perturb = apply_rotary_emb(key_perturb, image_rotary_emb_perturb) |
|
|
|
query_perturb = query_perturb.unflatten(2, (attn.heads, -1)).transpose(1, 2) |
|
key_perturb = key_perturb.unflatten(2, (attn.heads, -1)).transpose(1, 2) |
|
value_perturb = value_perturb.unflatten(2, (attn.heads, -1)).transpose(1, 2) |
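
        # Perturbed path: skip the softmax(QK^T)V computation and use the value projection directly,
        # i.e. attention with an identity attention map.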
|
|
|
hidden_states_perturb = value_perturb |
|
|
|
hidden_states_perturb = hidden_states_perturb.transpose(1, 2).flatten(2, 3) |
|
hidden_states_perturb = hidden_states_perturb.to(query_perturb.dtype) |
|
|
|
hidden_states_perturb = attn.to_out[0](hidden_states_perturb) |
|
hidden_states_perturb = attn.to_out[1](hidden_states_perturb) |
|
|
|
|
|
hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0) |
|
|
|
return hidden_states |
|
|
|
|
|
def calculate_shift( |
|
image_seq_len, |
|
base_seq_len: int = 256, |
|
max_seq_len: int = 4096, |
|
base_shift: float = 0.5, |
|
max_shift: float = 1.16, |
|
): |
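    # Linearly interpolate the flow-matching shift `mu` from `base_shift` to `max_shift` as the token
    # sequence length grows from `base_seq_len` to `max_seq_len`.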
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
|
b = base_shift - m * base_seq_len |
|
mu = image_seq_len * m + b |
|
return mu |
|
|
|
|
|
|
|
def retrieve_timesteps( |
|
scheduler, |
|
num_inference_steps: Optional[int] = None, |
|
device: Optional[Union[str, torch.device]] = None, |
|
timesteps: Optional[List[int]] = None, |
|
sigmas: Optional[List[float]] = None, |
|
**kwargs, |
|
): |
|
r""" |
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
|
|
|
Args: |
|
scheduler (`SchedulerMixin`): |
|
The scheduler to get timesteps from. |
|
num_inference_steps (`int`): |
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` |
|
must be `None`. |
|
device (`str` or `torch.device`, *optional*): |
|
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
timesteps (`List[int]`, *optional*): |
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, |
|
`num_inference_steps` and `sigmas` must be `None`. |
|
sigmas (`List[float]`, *optional*): |
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, |
|
`num_inference_steps` and `timesteps` must be `None`. |
|
|
|
Returns: |
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
|
second element is the number of inference steps. |
|
""" |
|
if timesteps is not None and sigmas is not None: |
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") |
|
if timesteps is not None: |
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
|
if not accepts_timesteps: |
|
raise ValueError( |
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
f" timestep schedules. Please check whether you are using the correct scheduler." |
|
) |
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
num_inference_steps = len(timesteps) |
|
elif sigmas is not None: |
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
|
if not accept_sigmas: |
|
raise ValueError( |
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
                f" sigma schedules. Please check whether you are using the correct scheduler."
|
) |
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
num_inference_steps = len(timesteps) |
|
else: |
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
return timesteps, num_inference_steps |
|
|
|
|
|
|
|
def retrieve_latents( |
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
|
): |
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
|
return encoder_output.latent_dist.sample(generator) |
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
|
return encoder_output.latent_dist.mode() |
|
elif hasattr(encoder_output, "latents"): |
|
return encoder_output.latents |
|
else: |
|
raise AttributeError("Could not access latents of provided encoder_output") |
|
|
|
|
|
class LTXImageToVideoSTGPipeline(LTXImageToVideoPipeline): |
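    r"""
    LTX image-to-video pipeline with spatio-temporal guidance (STG). It extends `LTXImageToVideoPipeline`
    with two perturbation modes: "STG-A" swaps the attention processor of selected self-attention layers,
    and "STG-R" makes selected transformer blocks act as identity mappings for the perturbed batch.
    """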
|
def extract_layers(self, file_path="./unet_info.txt"): |
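        # Collect the transformer's self-attention ("attn1") modules and, as a side effect, dump their
        # qualified names to `file_path` for inspection. The returned list is indexed by
        # `stg_applied_layers_idx` when swapping in the STG attention processor.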
|
layers = [] |
|
with open(file_path, "w") as f: |
|
for name, module in self.transformer.named_modules(): |
|
if "attn1" in name and "to" not in name and "add" not in name and "norm" not in name: |
|
f.write(f"{name}\n") |
|
layer_type = name.split(".")[0].split("_")[0] |
|
layers.append((name, module)) |
|
|
|
return layers |
|
|
|
def replace_layer_processor(self, layers, replace_processor, target_layers_idx=[]): |
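        # Swap the attention processor of the selected self-attention layers (used by the STG-A mode).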
|
for layer_idx in target_layers_idx: |
|
layers[layer_idx][1].processor = replace_processor |
|
|
|
return |
|
|
|
@property |
|
def do_spatio_temporal_guidance(self): |
|
return self._stg_scale > 0.0 |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
image: PipelineImageInput = None, |
|
prompt: Union[str, List[str]] = None, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
height: int = 512, |
|
width: int = 704, |
|
num_frames: int = 161, |
|
frame_rate: int = 25, |
|
num_inference_steps: int = 50, |
|
timesteps: List[int] = None, |
|
guidance_scale: float = 3, |
|
num_videos_per_prompt: Optional[int] = 1, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.Tensor] = None, |
|
prompt_embeds: Optional[torch.Tensor] = None, |
|
prompt_attention_mask: Optional[torch.Tensor] = None, |
|
negative_prompt_embeds: Optional[torch.Tensor] = None, |
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
|
max_sequence_length: int = 128, |
|
stg_mode: Optional[str] = "STG-R", |
|
stg_applied_layers_idx: Optional[List[int]] = [35], |
|
stg_scale: Optional[float] = 1.0, |
|
do_rescaling: Optional[bool] = False, |
|
decode_timestep: Union[float, List[float]] = 0.0, |
|
decode_noise_scale: Optional[Union[float, List[float]]] = None, |
|
): |
|
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): |
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs |
|
|
|
layers = self.extract_layers() |
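
        # 1. Check inputs. Raise an error if they are not valid.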
|
|
|
|
|
self.check_inputs( |
|
prompt=prompt, |
|
height=height, |
|
width=width, |
|
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
prompt_attention_mask=prompt_attention_mask, |
|
negative_prompt_attention_mask=negative_prompt_attention_mask, |
|
) |
|
|
|
self._stg_scale = stg_scale |
|
self._guidance_scale = guidance_scale |
|
self._interrupt = False |
|
|
|
if self.do_spatio_temporal_guidance: |
|
if stg_mode == "STG-A": |
|
|
replace_processor = STGLTXVideoAttentionProcessor2_0() |
|
self.replace_layer_processor(layers, replace_processor, stg_applied_layers_idx) |
|
elif stg_mode == "STG-R": |
|
for i in stg_applied_layers_idx: |
|
self.transformer.transformer_blocks[i].forward = types.MethodType(forward_with_stg, self.transformer.transformer_blocks[i]) |
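
        # 2. Define call parameters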
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
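
        # 3. Prepare text embeddings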
|
|
|
|
|
( |
|
prompt_embeds, |
|
prompt_attention_mask, |
|
negative_prompt_embeds, |
|
negative_prompt_attention_mask, |
|
) = self.encode_prompt( |
|
prompt=prompt, |
|
negative_prompt=negative_prompt, |
|
do_classifier_free_guidance=self.do_classifier_free_guidance, |
|
num_videos_per_prompt=num_videos_per_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
prompt_attention_mask=prompt_attention_mask, |
|
negative_prompt_attention_mask=negative_prompt_attention_mask, |
|
max_sequence_length=max_sequence_length, |
|
device=device, |
|
) |
|
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
|
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) |
|
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0) |
|
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0) |
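
        # 4. Prepare latent variables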
|
|
|
|
|
if latents is None: |
|
image = self.video_processor.preprocess(image, height=height, width=width) |
|
image = image.to(device=device, dtype=prompt_embeds.dtype) |
|
|
|
num_channels_latents = self.transformer.config.in_channels |
|
latents, conditioning_mask = self.prepare_latents( |
|
image, |
|
batch_size * num_videos_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
num_frames, |
|
torch.float32, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance: |
|
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) |
|
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance: |
|
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask]) |
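
        # 5. Prepare timesteps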
|
|
|
|
|
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 |
|
latent_height = height // self.vae_spatial_compression_ratio |
|
latent_width = width // self.vae_spatial_compression_ratio |
|
video_sequence_length = latent_num_frames * latent_height * latent_width |
|
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) |
|
mu = calculate_shift( |
|
video_sequence_length, |
|
self.scheduler.config.base_image_seq_len, |
|
self.scheduler.config.max_image_seq_len, |
|
self.scheduler.config.base_shift, |
|
self.scheduler.config.max_shift, |
|
) |
|
timesteps, num_inference_steps = retrieve_timesteps( |
|
self.scheduler, |
|
num_inference_steps, |
|
device, |
|
timesteps, |
|
sigmas=sigmas, |
|
mu=mu, |
|
) |
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) |
|
self._num_timesteps = len(timesteps) |
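
        # 6. Prepare the RoPE interpolation scale for the latent video grid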
|
|
|
|
|
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio |
|
rope_interpolation_scale = ( |
|
1 / latent_frame_rate, |
|
self.vae_spatial_compression_ratio, |
|
self.vae_spatial_compression_ratio, |
|
) |
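
        # 7. Denoising loop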
|
|
|
|
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
if self.interrupt: |
|
continue |
|
|
|
                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
                    latent_model_input = torch.cat([latents] * 2)
                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
                    latent_model_input = torch.cat([latents] * 3)
                else:
                    latent_model_input = latents
|
|
|
latent_model_input = latent_model_input.to(prompt_embeds.dtype) |
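
                # Broadcast the current timestep across the batch and zero it out for the conditioning
                # (first-frame image) tokens via `conditioning_mask`.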
|
|
|
|
|
timestep = t.expand(latent_model_input.shape[0]) |
|
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) |
|
|
|
noise_pred = self.transformer( |
|
hidden_states=latent_model_input, |
|
encoder_hidden_states=prompt_embeds, |
|
timestep=timestep, |
|
encoder_attention_mask=prompt_attention_mask, |
|
num_frames=latent_num_frames, |
|
height=latent_height, |
|
width=latent_width, |
|
rope_interpolation_scale=rope_interpolation_scale, |
|
return_dict=False, |
|
)[0] |
|
noise_pred = noise_pred.float() |
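
                # Combine the predictions: standard classifier-free guidance, plus the spatio-temporal
                # guidance term `stg_scale * (text - perturbed)` when STG is active.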
|
|
|
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
timestep, _ = timestep.chunk(2) |
|
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance: |
|
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) |
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) \ |
|
+ self._stg_scale * (noise_pred_text - noise_pred_perturb) |
|
timestep, _, _ = timestep.chunk(3) |
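
                # Optionally rescale the combined prediction so its standard deviation stays close to
                # that of the text-conditional branch (similar in spirit to CFG rescaling).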
|
|
|
if do_rescaling: |
|
rescaling_scale = 0.7 |
|
factor = noise_pred_text.std() / noise_pred.std() |
|
factor = rescaling_scale * factor + (1 - rescaling_scale) |
|
noise_pred = noise_pred * factor |
|
|
|
|
|
noise_pred = self._unpack_latents( |
|
noise_pred, |
|
latent_num_frames, |
|
latent_height, |
|
latent_width, |
|
self.transformer_spatial_patch_size, |
|
self.transformer_temporal_patch_size, |
|
) |
|
latents = self._unpack_latents( |
|
latents, |
|
latent_num_frames, |
|
latent_height, |
|
latent_width, |
|
self.transformer_spatial_patch_size, |
|
self.transformer_temporal_patch_size, |
|
) |
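
                # The first latent frame holds the conditioning image and is kept as-is; the scheduler
                # step is applied only to the remaining frames before re-packing.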
|
|
|
noise_pred = noise_pred[:, :, 1:] |
|
noise_latents = latents[:, :, 1:] |
|
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] |
|
|
|
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) |
|
latents = self._pack_latents( |
|
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size |
|
) |
|
|
|
if callback_on_step_end is not None: |
|
callback_kwargs = {} |
|
for k in callback_on_step_end_tensor_inputs: |
|
callback_kwargs[k] = locals()[k] |
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
|
|
|
latents = callback_outputs.pop("latents", latents) |
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
|
|
if XLA_AVAILABLE: |
|
xm.mark_step() |
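
        # 8. Decode the latents into video frames (unless latent output was requested)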
|
|
|
if output_type == "latent": |
|
video = latents |
|
else: |
|
latents = self._unpack_latents( |
|
latents, |
|
latent_num_frames, |
|
latent_height, |
|
latent_width, |
|
self.transformer_spatial_patch_size, |
|
self.transformer_temporal_patch_size, |
|
) |
|
latents = self._denormalize_latents( |
|
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor |
|
) |
|
latents = latents.to(prompt_embeds.dtype) |
|
|
|
if not self.vae.config.timestep_conditioning: |
|
timestep = None |
|
else: |
|
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) |
|
if not isinstance(decode_timestep, list): |
|
decode_timestep = [decode_timestep] * batch_size |
|
if decode_noise_scale is None: |
|
decode_noise_scale = decode_timestep |
|
elif not isinstance(decode_noise_scale, list): |
|
decode_noise_scale = [decode_noise_scale] * batch_size |
|
|
|
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) |
|
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ |
|
:, None, None, None, None |
|
] |
|
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise |
|
|
|
video = self.vae.decode(latents, timestep, return_dict=False)[0] |
|
video = self.video_processor.postprocess_video(video, output_type=output_type) |
|
|
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return (video,) |
|
|
|
return LTXPipelineOutput(frames=video) |
|
|