# RelightVid: pl_trainer/inference/inference_damo.py
import torch
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
from .inference import Inference
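
# Note: the `Inference` base class imported above is assumed to provide
# `self.unet`, `self.scheduler` (with `.timesteps` and `.step(...)`),
# `self.guidance_scale`, and `self.num_ddim_steps`; the samplers below rely
# on those attributes.
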
class InferenceDAMO(Inference):
    """Plain denoising loop with optional classifier-free guidance and
    optional per-step null embeddings."""

@torch.no_grad()
    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,
        uncond_context: Optional[torch.Tensor] = None,
        start_time: int = 0,
        null_embedding: Optional[List[torch.Tensor]] = None,
    ):
all_latent = []
all_pred = [] # x0_hat
do_classifier_free_guidance = self.guidance_scale > 1 and ((uncond_context is not None) or (null_embedding is not None))
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
if do_classifier_free_guidance:
latent_input = torch.cat([latent, latent], dim=0)
                if null_embedding is not None:
                    # Per-step null embedding; assumed to be ordered to match
                    # `self.scheduler.timesteps[start_time:]`.
                    context_input = torch.cat([null_embedding[i], context], dim=0)
else:
context_input = torch.cat([uncond_context, context], dim=0)
else:
latent_input = latent
context_input = context
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context_input,
)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
                # Classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
all_latent.append(latent.detach())
all_pred.append(pred.detach())
return {
'latent': latent,
'all_latent': all_latent,
'all_pred': all_pred
}
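
# Usage sketch (hypothetical; the constructor arguments are assumptions about
# how the `Inference` base class is wired):
#
#   sampler = InferenceDAMO(unet=unet, scheduler=scheduler, guidance_scale=7.5)
#   out = sampler(
#       latent=torch.randn(1, 4, 16, 32, 32),  # e.g. a B x C x F x H x W video latent
#       context=text_embedding,                # conditional embedding
#       uncond_context=null_embedding,         # enables classifier-free guidance
#   )
#   final_latent = out['latent']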


class InferenceDAMO_PTP(Inference):
    """Two-branch prompt-to-prompt style sampler: an 'old' trajectory keeps the
    source conditioning, while a 'new' trajectory switches its conditioning at
    the sa_end_time and ca_end_time fractions of the denoising schedule."""

def infer_old_context(self, latent, context, t, uncond_context=None):
do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
if do_classifier_free_guidance:
latent_input = torch.cat([latent, latent], dim=0)
context_input = torch.cat([uncond_context, context], dim=0)
else:
latent_input = latent
context_input = context
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context_input,
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
return latent, pred

    def infer_new_context(self, latent, context, t, uncond_context=None):
do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
if do_classifier_free_guidance:
latent_input = torch.cat([latent, latent], dim=0)
            if isinstance(context, (list, tuple)):
                # Blended (old, new) context pair: prepend the unconditional
                # embedding to each element.
                context_input = (
                    torch.cat([uncond_context, context[0]], dim=0),
                    torch.cat([uncond_context, context[1]], dim=0),
                )
else:
context_input = torch.cat([uncond_context, context], dim=0)
else:
latent_input = latent
context_input = context
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context_input,
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
return latent, pred

    @torch.no_grad()
    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,  # new-branch conditioning once i / num_ddim_steps >= ca_end_time
        old_context: Optional[torch.Tensor] = None,  # old-branch conditioning at every step
        old_to_new_context: Optional[Union[Tuple, List]] = None,  # new branch while sa_end_time <= i / num_ddim_steps < ca_end_time
        uncond_context: Optional[torch.Tensor] = None,
        sa_end_time: float = 0.3,
        ca_end_time: float = 0.8,
        start_time: int = 0,
    ):
assert sa_end_time < ca_end_time, f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively"
all_latent = []
all_pred = []
all_latent_old = []
all_pred_old = []
old_latent = latent.clone()
new_latent = latent.clone()
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
            # The old branch is always denoised with its original conditioning.
            old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
            if i < sa_end_time * self.num_ddim_steps:
                # Early steps: the new branch simply mirrors the old branch.
                new_latent_next_t, pred_new = old_latent_next_t, pred_old
            elif sa_end_time * self.num_ddim_steps <= i < ca_end_time * self.num_ddim_steps:
                # Middle steps: the new branch uses the blended old-to-new conditioning.
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, old_to_new_context, t, uncond_context
                )
            else:
                # Late steps: the new branch uses the new context only.
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, context, t, uncond_context
                )
old_latent = old_latent_next_t
new_latent = new_latent_next_t
all_latent.append(new_latent_next_t.detach())
all_pred.append(pred_new.detach())
all_latent_old.append(old_latent_next_t.detach())
all_pred_old.append(pred_old.detach())
return {
'latent': new_latent,
'latent_old': old_latent,
'all_latent': all_latent,
'all_pred': all_pred,
'all_latent_old': all_latent_old,
'all_pred_old': all_pred_old,
}
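
# Usage sketch (hypothetical; passing the (old, new) pair as
# `old_to_new_context` is an assumption about what the attention layers
# expect during the blended middle phase):
#
#   sampler = InferenceDAMO_PTP(unet=unet, scheduler=scheduler, guidance_scale=7.5)
#   out = sampler(
#       latent=inverted_latent,                # e.g. obtained via DDIM inversion
#       context=new_text_embedding,
#       old_context=old_text_embedding,
#       old_to_new_context=(old_text_embedding, new_text_embedding),
#       uncond_context=null_embedding,
#       sa_end_time=0.3,
#       ca_end_time=0.8,
#   )
#   edited, reconstruction = out['latent'], out['latent_old']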


class InferenceDAMO_PTP_v2(Inference):
    """Variant of InferenceDAMO_PTP that additionally performs prompt-to-prompt
    self-attention replacement inside the UNet's attention layers during the
    early denoising steps."""

    def set_ptp_in_xattn_layers(self, prompt_to_prompt: bool, num_frames: int = 1):
        # Toggle self-attention replacement on every CrossAttention module;
        # the attention implementation is expected to read these flags.
        for m in self.unet.modules():
            if m.__class__.__name__ == 'CrossAttention':
                m.ptp_sa_replace = prompt_to_prompt
                m.num_frames = num_frames

    def infer_both_with_sa_replace(self, old_latent, new_latent, old_context, new_context, t, uncond_context=None):
        do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
        if do_classifier_free_guidance:
            # Batch layout: [old_uncond, new_uncond, old_cond, new_cond], so the
            # first chunk(2) separates uncond/cond and the second separates old/new.
            latent_input = torch.cat([old_latent, new_latent, old_latent, new_latent], dim=0)
            context_input = torch.cat([uncond_context, uncond_context, old_context, new_context], dim=0)
        else:
            latent_input = torch.cat([old_latent, new_latent], dim=0)
            context_input = torch.cat([old_context, new_context], dim=0)
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context_input,
)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # Split the combined batch back into the old and new branches.
        noise_pred_old, noise_pred_new = noise_pred.chunk(2, dim=0)
pred_samples_old = self.scheduler.step(noise_pred_old, t, old_latent)
pred_samples_new = self.scheduler.step(noise_pred_new, t, new_latent)
old_latent = pred_samples_old.prev_sample
new_latent = pred_samples_new.prev_sample
old_pred = pred_samples_old.pred_original_sample
new_pred = pred_samples_new.pred_original_sample
return old_latent, new_latent, old_pred, new_pred

    def infer_old_context(self, latent, context, t, uncond_context=None):
do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
if do_classifier_free_guidance:
latent_input = torch.cat([latent, latent], dim=0)
context_input = torch.cat([uncond_context, context], dim=0)
else:
latent_input = latent
context_input = context
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context_input,
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
return latent, pred

    def infer_new_context(self, latent, context, t, uncond_context=None):
do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
if do_classifier_free_guidance:
latent_input = torch.cat([latent, latent], dim=0)
            if isinstance(context, (list, tuple)):
                # Blended (old, new) context pair: prepend the unconditional
                # embedding to each element.
                context_input = (
                    torch.cat([uncond_context, context[0]], dim=0),
                    torch.cat([uncond_context, context[1]], dim=0),
                )
else:
context_input = torch.cat([uncond_context, context], dim=0)
else:
latent_input = latent
context_input = context
noise_pred = self.unet(
latent_input,
torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
context_input,
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
pred_samples = self.scheduler.step(noise_pred, t, latent)
latent = pred_samples.prev_sample
pred = pred_samples.pred_original_sample
return latent, pred

    @torch.no_grad()
    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,  # new-branch conditioning once i / num_ddim_steps >= ca_end_time
        old_context: Optional[torch.Tensor] = None,  # old-branch conditioning at every step
        old_to_new_context: Optional[Union[Tuple, List]] = None,  # new branch while sa_end_time <= i / num_ddim_steps < ca_end_time
        uncond_context: Optional[torch.Tensor] = None,
        sa_end_time: float = 0.3,
        ca_end_time: float = 0.8,
        start_time: int = 0,
    ):
assert sa_end_time < ca_end_time, f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively"
all_latent = []
all_pred = []
all_latent_old = []
all_pred_old = []
old_latent = latent.clone()
new_latent = latent.clone()
for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
t = int(t)
            if i < sa_end_time * self.num_ddim_steps:
                # Early steps: denoise both branches jointly with self-attention
                # replacement enabled; num_frames is read from the latent's frame
                # dimension (layout assumed to be B x C x F x H x W).
                self.set_ptp_in_xattn_layers(True, num_frames=latent.shape[2])
                old_latent_next_t, new_latent_next_t, pred_old, pred_new = self.infer_both_with_sa_replace(
                    old_latent, new_latent, old_context, context, t, uncond_context
                )
            elif sa_end_time * self.num_ddim_steps <= i < ca_end_time * self.num_ddim_steps:
                # Middle steps: branches denoise independently; the new branch
                # uses the blended old-to-new conditioning.
                self.set_ptp_in_xattn_layers(False)
                old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, old_to_new_context, t, uncond_context
                )
            else:
                # Late steps: the new branch uses the new context only.
                self.set_ptp_in_xattn_layers(False)
                old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, context, t, uncond_context
                )
old_latent = old_latent_next_t
new_latent = new_latent_next_t
all_latent.append(new_latent_next_t.detach())
all_pred.append(pred_new.detach())
all_latent_old.append(old_latent_next_t.detach())
all_pred_old.append(pred_old.detach())
return {
'latent': new_latent,
'latent_old': old_latent,
'all_latent': all_latent,
'all_pred': all_pred,
'all_latent_old': all_latent_old,
'all_pred_old': all_pred_old,
}
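
# Usage sketch (hypothetical): same call signature as InferenceDAMO_PTP, but
# the early steps also share self-attention maps between the two branches via
# the CrossAttention modules' `ptp_sa_replace` flag:
#
#   sampler = InferenceDAMO_PTP_v2(unet=unet, scheduler=scheduler, guidance_scale=7.5)
#   out = sampler(
#       latent=inverted_latent,
#       context=new_text_embedding,
#       old_context=old_text_embedding,
#       old_to_new_context=(old_text_embedding, new_text_embedding),
#       uncond_context=null_embedding,
#   )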