import torch
import torch.nn.functional as nnf
import numpy as np
from typing import Optional, Union, Tuple, List, Callable, Dict
from functools import partial

from tqdm import tqdm
from einops import rearrange
from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

from misc_utils.flow_utils import warp_image, RAFTFlow, resize_flow

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
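
# Hedged usage sketch (illustrative only, dummy shapes): combine a CFG prediction
# as usual, then pull its per-sample std back toward the text-conditional branch
# to counter over-exposure.
def _example_rescale_noise_cfg():
    noise_pred_uncond = torch.randn(2, 4, 64, 64)
    noise_pred_text = torch.randn(2, 4, 64, 64)
    noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
    return rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
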
class Inference():
def __init__(
self,
unet,
scheduler='ddim',
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_ddim_steps=20, guidance_scale=5,
):
self.unet = unet
if scheduler == 'ddim':
scheduler_cls = DDIMScheduler
scheduler_kwargs = {'set_alpha_to_one': False, 'steps_offset': 1, 'clip_sample': False}
elif scheduler == 'ddpm':
scheduler_cls = DDPMScheduler
scheduler_kwargs = {'clip_sample': False}
else:
raise NotImplementedError()
self.scheduler = scheduler_cls(
beta_start = beta_start,
beta_end = beta_end,
beta_schedule = beta_schedule,
**scheduler_kwargs
)
self.scheduler.set_timesteps(num_ddim_steps)
self.num_ddim_steps = num_ddim_steps
self.guidance_scale = guidance_scale
@torch.no_grad()
def __call__(
self,
latent: torch.Tensor,
context: torch.Tensor,
        uncond_context: Optional[torch.Tensor] = None,
        start_time: int = 0,
        null_embedding: Optional[List[torch.Tensor]] = None,
context_kwargs={},
model_kwargs={},
):
all_latent = []
all_pred = [] # x0_hat
do_classifier_free_guidance = self.guidance_scale > 1 and ((uncond_context is not None) or (null_embedding is not None))
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
if do_classifier_free_guidance:
latent_input = torch.cat([latent, latent], dim=0)
if null_embedding is not None:
context_input = torch.cat([null_embedding[i], context], dim=0)
else:
context_input = torch.cat([uncond_context, context], dim=0)
else:
latent_input = latent
context_input = context
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context={ 'text': context_input, **context_kwargs},
**model_kwargs
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
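
# Hedged usage sketch (illustrative, not part of the original pipeline): assumes
# `unet` follows the call convention used in `Inference.__call__` (it returns the
# noise-prediction tensor directly and accepts a `context` dict), and that
# `latent`, `context`, and `uncond_context` are precomputed.
def _example_inference(unet, latent, context, uncond_context):
    sampler = Inference(unet, scheduler='ddim', num_ddim_steps=20, guidance_scale=5)
    out = sampler(latent, context, uncond_context=uncond_context)
    return out['latent'], out['all_pred'][-1]  # final latent and last x0_hat
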
class InferenceIP2PEditRef(Inference):
def zeros(self, x):
return torch.zeros_like(x)
@torch.no_grad()
def __call__(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
edit_cond: torch.Tensor,
text_cfg = 7.5,
img_cfg = 1.2,
edit_cfg = 1.2,
start_time: int = 0,
):
'''
latent1 | latent2 | latent3 | latent4
text x x x v
edit x x v v
img x v v v
'''
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond), self.zeros(edit_cond)], dim=1)
latent2 = torch.cat([latent, img_cond, self.zeros(edit_cond)], dim=1)
latent3 = torch.cat([latent, img_cond, edit_cond], dim=1)
latent4 = latent3.clone()
latent_input = torch.cat([latent1, latent2, latent3, latent4], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_uncond, text_cond], dim=0)
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context={ 'text': context_input},
)
noise_pred1, noise_pred2, noise_pred3, noise_pred4 = noise_pred.chunk(4, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
edit_cfg * (noise_pred3 - noise_pred2) +
text_cfg * (noise_pred4 - noise_pred3)
) # when edit_cfg == img_cfg, noise_pred2 is not used
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
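
# Hedged usage sketch (illustrative): `latent`, `img_cond`, and `edit_cond` are
# concatenated along the channel dim (dim=1) inside `__call__`, so they are
# assumed to share batch and spatial sizes; cfg values mirror the defaults.
def _example_edit_ref(unet, latent, text_cond, text_uncond, img_cond, edit_cond):
    sampler = InferenceIP2PEditRef(unet, scheduler='ddim', num_ddim_steps=20)
    out = sampler(
        latent, text_cond, text_uncond, img_cond, edit_cond,
        text_cfg=7.5, img_cfg=1.2, edit_cfg=1.2,
    )
    return out['latent']
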
class InferenceIP2PVideo(Inference):
def zeros(self, x):
return torch.zeros_like(x)
@torch.no_grad()
def __call__(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
text_cfg = 7.5,
img_cfg = 1.2,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3
text x x v
img x v v
'''
# all_latent = []
# all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
encoder_hidden_states=context_input,
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
del noise_pred, noise_pred1, noise_pred2, noise_pred3, pred_samples
del latent_input, context_input
torch.cuda.empty_cache()
# all_latent.append(latent.detach())
# all_pred.append(pred.detach())
return {
'latent': latent,
# 'all_latent': all_latent,
# 'all_pred': all_pred
}
@torch.no_grad()
def second_clip_forward(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
latent_ref: torch.Tensor,
noise_correct_step: float = 1.,
text_cfg = 7.5,
img_cfg = 1.2,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3
text x x v
img x v v
'''
num_ref_frames = latent_ref.shape[1]
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
encoder_hidden_states=context_input,
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            # Long Video Sampling Correction (LVSC), applied at inference time:
            # invert x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps to recover the noise
            # that would reproduce latent_ref, then nudge the prediction of the
            # reference frames (and, via the mean delta, the remaining frames) toward it.
if noise_correct_step * self.num_ddim_steps > i:
alpha_prod_t = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w
delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True)
noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
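
# Hedged sketch of running one clip and then a follow-up clip with the Long Video
# Sampling Correction above (illustrative; assumes latents are (b, f, c, h, w),
# the UNet is a diffusers-style 3D UNet returning `.sample`, and the first
# `n_ref` frames of the second clip overlap the end of clip 1).
def _example_two_clip_video(unet, latent_clip1, latent_clip2,
                            text_cond, text_uncond, img_cond1, img_cond2, n_ref=1):
    sampler = InferenceIP2PVideo(unet, num_ddim_steps=20)
    out1 = sampler(latent_clip1, text_cond, text_uncond, img_cond1,
                   text_cfg=7.5, img_cfg=1.2)
    latent_ref = out1['latent'][:, -n_ref:]  # clean latents of the overlapping frames
    out2 = sampler.second_clip_forward(
        latent_clip2, text_cond, text_uncond, img_cond2,
        latent_ref=latent_ref, noise_correct_step=0.6,
        text_cfg=7.5, img_cfg=1.2,
    )
    return out1['latent'], out2['latent']
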
class InferenceIP2PVideoEnsemble(Inference):
def zeros(self, x):
return torch.zeros_like(x)
@torch.no_grad()
def __call__(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
text_cfg = 7.5,
img_cfg = 1.2,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3
text x x v
img x v v
'''
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
encoder_hidden_states=context_input,
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
            # average the updated latent over the ensemble members (batch dim)
latent = latent.mean(dim=0, keepdim=True).repeat(latent.shape[0], 1, 1, 1, 1)
# latent = latent[[0]].repeat(latent.shape[0], 1, 1, 1, 1)
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
@torch.no_grad()
def second_clip_forward(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
latent_ref: torch.Tensor,
noise_correct_step: float = 1.,
text_cfg = 7.5,
img_cfg = 1.2,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3
text x x v
img x v v
'''
num_ref_frames = latent_ref.shape[1]
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
encoder_hidden_states=context_input,
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            # Long Video Sampling Correction (LVSC), applied at inference time.
if noise_correct_step * self.num_ddim_steps > i:
alpha_prod_t = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w
delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True)
noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
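
# Hedged usage sketch (illustrative): this variant averages the updated latent
# over the batch dimension at every step, so `latent` is assumed to stack several
# ensemble members of the same clip along dim 0, with the text embeddings
# expanded to the matching batch size.
def _example_ip2p_video_ensemble(unet, latent, text_cond, text_uncond, img_cond):
    sampler = InferenceIP2PVideoEnsemble(unet, num_ddim_steps=20)
    out = sampler(latent, text_cond, text_uncond, img_cond, text_cfg=7.5, img_cfg=1.2)
    return out['latent'][0]  # members are identical after the per-step averaging
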
class InferenceIP2PVideoHDR(Inference):
def zeros(self, x):
return torch.zeros_like(x)
@torch.no_grad()
def __call__(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,#(1,77,768)
hdr_cond: torch.Tensor, #(1,3,768)
img_cond: torch.Tensor,
text_cfg = 7.5,
img_cfg = 1.2,
hdr_cfg = 7.5,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3 | latent4
text x x v v
img x v v v
hdr x x x v
'''
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent4 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3, latent4], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond, text_cond], dim=0) #(4,77,768)
hdr_uncond = self.zeros(hdr_cond)
hdr_input = torch.cat([hdr_uncond, hdr_uncond, hdr_uncond, hdr_cond]) #(4,3,768)
model_kwargs1 = {'hdr_latents': hdr_input, 'encoder_hidden_states': context_input}
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                **model_kwargs1,  # forwards both encoder_hidden_states and hdr_latents to the HDR-aware UNet
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3, noise_pred4 = noise_pred.chunk(4, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2) +
hdr_cfg * (noise_pred4 - noise_pred3)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
@torch.no_grad()
def second_clip_forward(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
latent_ref: torch.Tensor,
noise_correct_step: float = 1.,
text_cfg = 7.5,
img_cfg = 1.2,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3
text x x v
img x v v
'''
num_ref_frames = latent_ref.shape[1]
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
encoder_hidden_states=context_input,
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            # Long Video Sampling Correction (LVSC), applied at inference time.
if noise_correct_step * self.num_ddim_steps > i:
alpha_prod_t = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w
delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True)
noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
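
# Hedged usage sketch (illustrative): per the shape comments in `__call__`,
# `text_uncond` is (1, 77, 768) and `hdr_cond` is (1, 3, 768); the underlying
# UNet is assumed to accept the extra `hdr_latents` conditioning.
def _example_ip2p_video_hdr(unet, latent, text_cond, text_uncond, hdr_cond, img_cond):
    sampler = InferenceIP2PVideoHDR(unet, num_ddim_steps=20)
    out = sampler(
        latent, text_cond, text_uncond, hdr_cond, img_cond,
        text_cfg=7.5, img_cfg=1.2, hdr_cfg=7.5,
    )
    return out['latent']
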
class InferenceIP2PVideoOpticalFlow(InferenceIP2PVideo):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        self.flow_estimator = RAFTFlow().cuda()  # RAFT-based optical-flow estimator
def obtain_delta_noise(self, delta_noise_ref, flow):
flow = resize_flow(flow, delta_noise_ref.shape[2:])
        warped_delta_noise_ref = warp_image(delta_noise_ref, flow)  # warp the reference-frame noise delta along the flow
valid_mask = torch.ones_like(delta_noise_ref)[:, :1]
valid_mask = warp_image(valid_mask, flow)
return warped_delta_noise_ref, valid_mask
def obtain_flow_batched(self, ref_images, query_images):
        ref_images = ref_images.to()  # note: .to() with no arguments is a no-op
warp_funcs = []
for query_image in query_images:
query_image = query_image.unsqueeze(0).repeat(len(ref_images), 1, 1, 1)
            flow = self.flow_estimator(query_image, ref_images)  # estimate flow from each query frame to the reference frames
warp_func = partial(self.obtain_delta_noise, flow=flow)
warp_funcs.append(warp_func)
return warp_funcs
@torch.no_grad()
def second_clip_forward(
self,
latent: torch.Tensor,
text_cond: torch.Tensor,
text_uncond: torch.Tensor,
img_cond: torch.Tensor,
latent_ref: torch.Tensor,
ref_images: torch.Tensor,
query_images: torch.Tensor,
noise_correct_step: float = 1.,
text_cfg = 7.5,
img_cfg = 1.2,
start_time: int = 0,
guidance_rescale: float = 0.0,
):
'''
latent1 | latent2 | latent3
text x x v
img x v v
'''
        assert ref_images.shape[0] == 1, 'only batch size 1 is supported'
warp_funcs = self.obtain_flow_batched(ref_images[0], query_images[0])
num_ref_frames = latent_ref.shape[1]
all_latent = []
all_pred = [] # x0_hat
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
latent2 = torch.cat([latent, img_cond], dim=2)
latent3 = latent2.clone()
latent_input = torch.cat([latent1, latent2, latent3], dim=0)
context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
encoder_hidden_states=context_input,
).sample
noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
noise_pred = (
noise_pred1 +
img_cfg * (noise_pred2 - noise_pred1) +
text_cfg * (noise_pred3 - noise_pred2)
)
if guidance_rescale > 0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
if noise_correct_step * self.num_ddim_steps > i:
alpha_prod_t = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w
delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
for refed_index, warp_func in zip(range(num_ref_frames, noise_pred.shape[1]), warp_funcs):
delta_noise_remaining, delta_noise_mask = warp_func(delta_noise_ref[0])
mask_sum = delta_noise_mask[None].sum(dim=1, keepdim=True)
delta_noise_remaining = torch.where(
mask_sum > 0.5,
delta_noise_remaining[None].sum(dim=1, keepdim=True) / mask_sum,
0.
)
noise_pred[:, refed_index: refed_index+1] += torch.where(
mask_sum > 0.5,
delta_noise_remaining,
0
                    )  # apply the warped noise delta to the current frame so the noise change across frames follows the motion in the video
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
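
# Hedged usage sketch (illustrative): in addition to the inputs used by
# `InferenceIP2PVideo.second_clip_forward`, this variant takes RGB frames so it
# can warp the reference noise correction along the estimated flow.
# `ref_images` / `query_images` are assumed to be (1, n_ref, 3, H, W) and
# (1, n_query, 3, H, W) image batches on the same device as the model.
def _example_flow_second_clip(unet, latent_clip2, text_cond, text_uncond,
                              img_cond2, latent_ref, ref_images, query_images):
    sampler = InferenceIP2PVideoOpticalFlow(unet, num_ddim_steps=20)
    out = sampler.second_clip_forward(
        latent_clip2, text_cond, text_uncond, img_cond2,
        latent_ref=latent_ref,
        ref_images=ref_images, query_images=query_images,
        noise_correct_step=0.6, text_cfg=7.5, img_cfg=1.2,
    )
    return out['latent']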