import torch
import torch.nn.functional as nnf
import numpy as np
from typing import Optional, Union, Tuple, List, Callable, Dict
from functools import partial
from tqdm import tqdm
from einops import rearrange
from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
from misc_utils.flow_utils import warp_image, RAFTFlow, resize_flow

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
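
# In effect, the function computes
#     noise_cfg <- gamma * noise_cfg * (std(noise_pred_text) / std(noise_cfg)) + (1 - gamma) * noise_cfg
# with gamma = guidance_rescale, where the standard deviations are taken per sample over all
# non-batch dimensions. gamma = 0 leaves the CFG output unchanged; gamma = 1 fully matches the
# standard deviation of the conditional prediction passed in as `noise_pred_text`.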


class Inference():
    def __init__(
        self,
        unet,
        scheduler='ddim',
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        num_ddim_steps=20, guidance_scale=5,
    ):
        self.unet = unet
        if scheduler == 'ddim':
            scheduler_cls = DDIMScheduler
            scheduler_kwargs = {'set_alpha_to_one': False, 'steps_offset': 1, 'clip_sample': False}
        elif scheduler == 'ddpm':
            scheduler_cls = DDPMScheduler
            scheduler_kwargs = {'clip_sample': False}
        else:
            raise NotImplementedError()
        self.scheduler = scheduler_cls(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            **scheduler_kwargs
        )
        self.scheduler.set_timesteps(num_ddim_steps)
        self.num_ddim_steps = num_ddim_steps
        self.guidance_scale = guidance_scale
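
    # A minimal usage sketch (illustrative only, not part of this file): it shows how the base
    # sampler is typically constructed and invoked. `my_unet`, `latent`, `text_emb` and
    # `uncond_emb` are assumed to be a compatible UNet and pre-computed tensors.
    #
    #     sampler = Inference(my_unet, scheduler='ddim', num_ddim_steps=20, guidance_scale=5)
    #     out = sampler(latent, context=text_emb, uncond_context=uncond_emb)
    #     final_latent = out['latent']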

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,
        uncond_context: torch.Tensor = None,
        start_time: int = 0,
        null_embedding: List[torch.Tensor] = None,
        context_kwargs={},
        model_kwargs={},
    ):
        all_latent = []
        all_pred = []  # x0_hat
        do_classifier_free_guidance = self.guidance_scale > 1 and ((uncond_context is not None) or (null_embedding is not None))
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            if do_classifier_free_guidance:
                latent_input = torch.cat([latent, latent], dim=0)
                if null_embedding is not None:
                    context_input = torch.cat([null_embedding[i], context], dim=0)
                else:
                    context_input = torch.cat([uncond_context, context], dim=0)
            else:
                latent_input = latent
                context_input = context
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                context={'text': context_input, **context_kwargs},
                **model_kwargs
            )
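            # Standard classifier-free guidance: the batch was duplicated into an unconditional
            # half and a conditional half, so the two predictions are combined as
            #   eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond).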
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }


class InferenceIP2PEditRef(Inference):
    def zeros(self, x):
        return torch.zeros_like(x)

    def __call__(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        edit_cond: torch.Tensor,
        text_cfg=7.5,
        img_cfg=1.2,
        edit_cfg=1.2,
        start_time: int = 0,
    ):
        '''
                 latent1 | latent2 | latent3 | latent4
        text        x    |    x    |    x    |    v
        edit        x    |    x    |    v    |    v
        img         x    |    v    |    v    |    v
        '''
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond), self.zeros(edit_cond)], dim=1)
            latent2 = torch.cat([latent, img_cond, self.zeros(edit_cond)], dim=1)
            latent3 = torch.cat([latent, img_cond, edit_cond], dim=1)
            latent4 = latent3.clone()
            latent_input = torch.cat([latent1, latent2, latent3, latent4], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_uncond, text_cond], dim=0)
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                context={'text': context_input},
            )
            noise_pred1, noise_pred2, noise_pred3, noise_pred4 = noise_pred.chunk(4, dim=0)
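            # Cascaded classifier-free guidance over three conditions (in the style of
            # InstructPix2Pix): start from the fully unconditional branch and add one guidance
            # term per condition,
            #   eps = eps1 + img_cfg*(eps2 - eps1) + edit_cfg*(eps3 - eps2) + text_cfg*(eps4 - eps3),
            # where each difference isolates the effect of enabling exactly one extra condition.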
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                edit_cfg * (noise_pred3 - noise_pred2) +
                text_cfg * (noise_pred4 - noise_pred3)
            )  # when edit_cfg == img_cfg, noise_pred2 cancels out of the sum
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }


class InferenceIP2PVideo(Inference):
    def zeros(self, x):
        return torch.zeros_like(x)

    def __call__(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        text_cfg=7.5,
        img_cfg=1.2,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3
        text        x    |    x    |    v
        img         x    |    v    |    v
        '''
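        # Latents are (b, f, c, h, w): the image condition is concatenated on the channel dim
        # (dim=2) of every frame, and the input is rearranged to (b, c, f, h, w) before the 3D
        # UNet call, then back afterwards. Per-step intermediates are not stored here (the
        # commented-out lists below) to keep GPU memory low for long clips.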
        # all_latent = []
        # all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                encoder_hidden_states=context_input,
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            del noise_pred, noise_pred1, noise_pred2, noise_pred3, pred_samples
            del latent_input, context_input
            torch.cuda.empty_cache()
            # all_latent.append(latent.detach())
            # all_pred.append(pred.detach())
        return {
            'latent': latent,
            # 'all_latent': all_latent,
            # 'all_pred': all_pred
        }

    def second_clip_forward(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        latent_ref: torch.Tensor,
        noise_correct_step: float = 1.,
        text_cfg=7.5,
        img_cfg=1.2,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3
        text        x    |    x    |    v
        img         x    |    v    |    v
        '''
        num_ref_frames = latent_ref.shape[1]
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                encoder_hidden_states=context_input,
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            # Long Video Sampling Correction (LVSC), applied at the inference stage.
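            # The correction below inverts the forward-diffusion relation
            #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
            # for the reference frames: given their already-denoised latents `latent_ref`,
            # `noise_ref` is the noise that exactly reproduces the current noisy latent. The
            # reference frames' predictions are corrected towards it, and the mean correction
            # is propagated to the remaining frames so the new clip stays consistent with the
            # previously generated one.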
            if noise_correct_step * self.num_ddim_steps > i:
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5)  # b 1 c h w
                delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
                delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True)
                noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
                noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }


class InferenceIP2PVideoEnsemble(Inference):
    def zeros(self, x):
        return torch.zeros_like(x)

    def __call__(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        text_cfg=7.5,
        img_cfg=1.2,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3
        text        x    |    x    |    v
        img         x    |    v    |    v
        '''
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                encoder_hidden_states=context_input,
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
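            # Ensemble step: after each denoising update, the per-member latents are averaged
            # over the batch dimension and the mean is broadcast back to every member, keeping
            # all ensemble members on a single shared trajectory.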
            # average over all three samples.
            latent = latent.mean(dim=0, keepdim=True).repeat(latent.shape[0], 1, 1, 1, 1)
            # latent = latent[[0]].repeat(latent.shape[0], 1, 1, 1, 1)
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }

    def second_clip_forward(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        latent_ref: torch.Tensor,
        noise_correct_step: float = 1.,
        text_cfg=7.5,
        img_cfg=1.2,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3
        text        x    |    x    |    v
        img         x    |    v    |    v
        '''
        num_ref_frames = latent_ref.shape[1]
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                encoder_hidden_states=context_input,
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            # Long Video Sampling Correction (LVSC), applied at the inference stage.
            if noise_correct_step * self.num_ddim_steps > i:
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5)  # b 1 c h w
                delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
                delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True)
                noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
                noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }


class InferenceIP2PVideoHDR(Inference):
    def zeros(self, x):
        return torch.zeros_like(x)

    def __call__(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,  # (1, 77, 768)
        hdr_cond: torch.Tensor,     # (1, 3, 768)
        img_cond: torch.Tensor,
        text_cfg=7.5,
        img_cfg=1.2,
        hdr_cfg=7.5,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3 | latent4
        text        x    |    x    |    v    |    v
        img         x    |    v    |    v    |    v
        hdr         x    |    x    |    x    |    v
        '''
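        # Same image/text cascade as InferenceIP2PVideo, plus a fourth branch that enables the
        # HDR embedding, so the combination gains an extra hdr_cfg * (eps4 - eps3) term. The
        # unconditional HDR input is an all-zeros embedding of the same shape as `hdr_cond`.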
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent4 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3, latent4], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond, text_cond], dim=0)  # (4, 77, 768)
            hdr_uncond = self.zeros(hdr_cond)
            hdr_input = torch.cat([hdr_uncond, hdr_uncond, hdr_uncond, hdr_cond])  # (4, 3, 768)
            model_kwargs1 = {'hdr_latents': hdr_input, 'encoder_hidden_states': context_input}
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                **model_kwargs1,  # pass hdr_latents and encoder_hidden_states as separate keyword arguments
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3, noise_pred4 = noise_pred.chunk(4, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2) +
                hdr_cfg * (noise_pred4 - noise_pred3)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }

    def second_clip_forward(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        latent_ref: torch.Tensor,
        noise_correct_step: float = 1.,
        text_cfg=7.5,
        img_cfg=1.2,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3
        text        x    |    x    |    v
        img         x    |    v    |    v
        '''
        num_ref_frames = latent_ref.shape[1]
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                encoder_hidden_states=context_input,
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            # Long Video Sampling Correction (LVSC), applied at the inference stage.
            if noise_correct_step * self.num_ddim_steps > i:
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5)  # b 1 c h w
                delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
                delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True)
                noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
                noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }


class InferenceIP2PVideoOpticalFlow(InferenceIP2PVideo):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.flow_estimator = RAFTFlow().cuda()  # RAFT-based optical flow estimator
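
    # Flow-guided variant of the long-video correction: instead of broadcasting the mean
    # reference-frame correction to every remaining frame, the per-pixel correction of the
    # reference frames is warped into each query frame along estimated optical flow, so the
    # injected noise change follows the motion of objects in the video.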
    def obtain_delta_noise(self, delta_noise_ref, flow):
        flow = resize_flow(flow, delta_noise_ref.shape[2:])
        warped_delta_noise_ref = warp_image(delta_noise_ref, flow)  # warp the reference-frame noise correction along the flow
        valid_mask = torch.ones_like(delta_noise_ref)[:, :1]
        valid_mask = warp_image(valid_mask, flow)
        return warped_delta_noise_ref, valid_mask

    def obtain_flow_batched(self, ref_images, query_images):
        warp_funcs = []
        for query_image in query_images:
            query_image = query_image.unsqueeze(0).repeat(len(ref_images), 1, 1, 1)
            flow = self.flow_estimator(query_image, ref_images)  # estimate optical flow between the query frame and each reference frame
            warp_func = partial(self.obtain_delta_noise, flow=flow)
            warp_funcs.append(warp_func)
        return warp_funcs
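
    # `obtain_flow_batched` returns one callable per query frame; each callable closes over
    # the flow between that query frame and all reference frames and, when applied to the
    # reference-frame correction, yields the warped correction plus a validity mask marking
    # pixels that received a value from the warp.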

    def second_clip_forward(
        self,
        latent: torch.Tensor,
        text_cond: torch.Tensor,
        text_uncond: torch.Tensor,
        img_cond: torch.Tensor,
        latent_ref: torch.Tensor,
        ref_images: torch.Tensor,
        query_images: torch.Tensor,
        noise_correct_step: float = 1.,
        text_cfg=7.5,
        img_cfg=1.2,
        start_time: int = 0,
        guidance_rescale: float = 0.0,
    ):
        '''
                 latent1 | latent2 | latent3
        text        x    |    x    |    v
        img         x    |    v    |    v
        '''
        assert ref_images.shape[0] == 1, 'only support batch size 1'
        warp_funcs = self.obtain_flow_batched(ref_images[0], query_images[0])
        num_ref_frames = latent_ref.shape[1]
        all_latent = []
        all_pred = []  # x0_hat
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
            latent2 = torch.cat([latent, img_cond], dim=2)
            latent3 = latent2.clone()
            latent_input = torch.cat([latent1, latent2, latent3], dim=0)
            context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0)
            latent_input = rearrange(latent_input, 'b f c h w -> b c f h w')
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                encoder_hidden_states=context_input,
            ).sample
            noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w')
            noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0)
            noise_pred = (
                noise_pred1 +
                img_cfg * (noise_pred2 - noise_pred1) +
                text_cfg * (noise_pred3 - noise_pred2)
            )
            if guidance_rescale > 0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale)
            if noise_correct_step * self.num_ddim_steps > i:
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5)  # b 1 c h w
                delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames]
                noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref
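                # For each remaining frame, warp the reference-frame correction along the
                # estimated flow, average it over the reference frames weighted by the warped
                # validity mask, and only apply it where at least one reference pixel landed
                # (mask_sum > 0.5); unmatched pixels keep their original prediction.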
                for refed_index, warp_func in zip(range(num_ref_frames, noise_pred.shape[1]), warp_funcs):
                    delta_noise_remaining, delta_noise_mask = warp_func(delta_noise_ref[0])
                    mask_sum = delta_noise_mask[None].sum(dim=1, keepdim=True)
                    delta_noise_remaining = torch.where(
                        mask_sum > 0.5,
                        delta_noise_remaining[None].sum(dim=1, keepdim=True) / mask_sum,
                        0.
                    )
                    noise_pred[:, refed_index: refed_index + 1] += torch.where(
                        mask_sum > 0.5,
                        delta_noise_remaining,
                        0
                    )  # apply the warped noise correction to the current frame so that noise changes between frames follow the motion of objects in the video
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred
        }