# NOTE: page-header residue from the Hugging Face Spaces file viewer ("Spaces: Running on Zero"); not part of the module.
from typing import List, Optional, Tuple, Union

import torch
from tqdm import tqdm

from .inference import Inference
class InferenceDAMO(Inference):
    """Iterative denoising loop with optional classifier-free guidance.

    Uses ``self.unet``, ``self.scheduler`` and ``self.guidance_scale``; none of
    them is defined in this class, so they are presumably provided by
    ``Inference`` (verify against the base class).
    """

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,
        uncond_context: torch.Tensor=None,
        start_time: int = 0,
        null_embedding: List[torch.Tensor]=None,
    ):
        """Denoise ``latent`` over ``self.scheduler.timesteps[start_time:]``.

        Args:
            latent: Starting latent; replaced by the scheduler's
                ``prev_sample`` after every step.
            context: Conditioning passed to the UNet as its third argument.
            uncond_context: Unconditional context for guidance. Only used when
                ``null_embedding`` is ``None``.
            start_time: Index into ``self.scheduler.timesteps`` at which the
                loop starts.
            null_embedding: Optional per-step unconditional embeddings. When
                given it takes precedence over ``uncond_context``.
                NOTE(review): it is indexed by the loop position (0 at
                ``start_time``), not by the absolute schedule index; confirm
                that this is intended when ``start_time > 0``.

        Returns:
            A dict with ``'latent'`` (the latent after the last step),
            ``'all_latent'`` (detached latent after each step) and
            ``'all_pred'`` (detached ``pred_original_sample`` of each step,
            i.e. the x0 estimate).
        """
        # Guidance needs a scale above 1 and some form of unconditional input.
        has_negative_input = not (uncond_context is None and null_embedding is None)
        use_cfg = self.guidance_scale > 1 and has_negative_input

        latent_history = []
        x0_history = []  # x0_hat

        remaining_steps = self.scheduler.timesteps[start_time:]
        for step_idx, timestep in enumerate(tqdm(remaining_steps)):
            timestep = int(timestep)

            if not use_cfg:
                unet_latent, unet_context = latent, context
            else:
                # Unconditional half first, conditional half second; the
                # chunk() below relies on this order.
                negative = uncond_context if null_embedding is None else null_embedding[step_idx]
                unet_latent = torch.cat([latent, latent], dim=0)
                unet_context = torch.cat([negative, context], dim=0)

            timestep_batch = torch.full(
                (len(unet_latent),),
                timestep,
                device=unet_latent.device,
                dtype=torch.long,
            )
            eps = self.unet(unet_latent, timestep_batch, unet_context)

            if use_cfg:
                eps_uncond, eps_cond = eps.chunk(2, dim=0)
                eps = eps_uncond + self.guidance_scale * (eps_cond - eps_uncond)

            step_output = self.scheduler.step(eps, timestep, latent)
            latent = step_output.prev_sample

            latent_history.append(latent.detach())
            x0_history.append(step_output.pred_original_sample.detach())

        return dict(
            latent=latent,
            all_latent=latent_history,
            all_pred=x0_history,
        )
class InferenceDAMO_PTP(Inference):
    """Denoise an "old" and a "new" trajectory side by side in three phases.

    Both trajectories start from the same latent. The old one is always
    denoised with ``old_context``. The new one copies the old result for the
    first phase, uses ``old_to_new_context`` for the middle phase and
    ``context`` for the last phase. (The class name suggests prompt-to-prompt
    editing; that intent is not stated in this file.)

    Uses ``self.unet``, ``self.scheduler``, ``self.guidance_scale`` and
    ``self.num_ddim_steps``; none of them is defined in this class, so they
    are presumably provided by ``Inference``.
    """

    def _guided_denoise_step(self, latent, context, t, uncond_context=None):
        """Run one UNet + scheduler step, with classifier-free guidance if enabled.

        Shared implementation behind ``infer_old_context`` and
        ``infer_new_context``, which previously carried duplicated copies of
        this logic.

        Args:
            latent: Current latent.
            context: Conditioning for the UNet. May be a 2-element list/tuple,
                in which case the unconditional context is prepended to each
                element and a tuple is passed to the UNet.
            t: Integer timestep.
            uncond_context: Unconditional context; guidance is applied only
                when this is not ``None`` and ``self.guidance_scale > 1``.

        Returns:
            ``(prev_sample, pred_original_sample)`` from ``self.scheduler.step``.
        """
        use_cfg = self.guidance_scale > 1 and (uncond_context is not None)

        if use_cfg:
            # Unconditional half first, conditional half second; the chunk()
            # below relies on this order.
            latent_input = torch.cat([latent, latent], dim=0)
            if isinstance(context, (list, tuple)):
                context_input = (
                    torch.cat([uncond_context, context[0]], dim=0),
                    torch.cat([uncond_context, context[1]], dim=0),
                )
            else:
                context_input = torch.cat([uncond_context, context], dim=0)
        else:
            latent_input = latent
            context_input = context

        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )

        if use_cfg:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

        step_output = self.scheduler.step(noise_pred, t, latent)
        return step_output.prev_sample, step_output.pred_original_sample

    def infer_old_context(self, latent, context, t, uncond_context=None):
        """Take one denoising step for the old trajectory.

        Returns ``(next_latent, pred_original_sample)``. A list/tuple
        ``context`` is now handled the same way as in ``infer_new_context``
        (it used to fail inside ``torch.cat`` when guidance was active).
        """
        return self._guided_denoise_step(latent, context, t, uncond_context)

    def infer_new_context(self, latent, context, t, uncond_context=None):
        """Take one denoising step for the new trajectory.

        ``context`` may be a tensor or a 2-element list/tuple of tensors.
        Returns ``(next_latent, pred_original_sample)``.
        """
        return self._guided_denoise_step(latent, context, t, uncond_context)

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,  # used when > ca_end_time
        old_context: Optional[torch.Tensor] = None,  # used when < sa_end_time
        old_to_new_context: Optional[Union[Tuple, List]] = None,  # used when sa_end_time < t < ca_end_time
        uncond_context: Optional[torch.Tensor] = None,
        sa_end_time: float = 0.3,
        ca_end_time: float = 0.8,
        start_time: int = 0,
    ):
        """Denoise the old and new trajectories over ``timesteps[start_time:]``.

        The phase is chosen from the loop index ``i`` (0 at ``start_time``):

        * ``i < sa_end_time * num_ddim_steps``: the new trajectory copies the
          old trajectory's step result.
        * ``i < ca_end_time * num_ddim_steps``: the new trajectory is denoised
          with ``old_to_new_context``.
        * otherwise: the new trajectory is denoised with ``context``.

        NOTE(review): ``i`` counts from ``start_time``, while the thresholds
        are fractions of ``self.num_ddim_steps``; confirm this is the intended
        reference when ``start_time > 0``.

        Args:
            latent: Starting latent, cloned for each trajectory.
            context: Conditioning for the new trajectory in the last phase.
            old_context: Conditioning for the old trajectory (every step).
            old_to_new_context: Conditioning for the new trajectory in the
                middle phase.
            uncond_context: Unconditional context for classifier-free guidance.
            sa_end_time: Fraction of ``num_ddim_steps`` ending the first phase.
            ca_end_time: Fraction of ``num_ddim_steps`` ending the middle phase.
            start_time: Index into ``self.scheduler.timesteps`` to start from.

        Returns:
            A dict with the final latents (``'latent'``, ``'latent_old'``) and
            the detached per-step histories (``'all_latent'``, ``'all_pred'``,
            ``'all_latent_old'``, ``'all_pred_old'``).

        Raises:
            AssertionError: If ``sa_end_time >= ca_end_time``.
        """
        # Kept as an assert so callers that catch AssertionError keep working;
        # note that it is skipped when Python runs with -O.
        assert sa_end_time < ca_end_time, f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively"

        # Phase boundaries do not change during the loop; compute them once.
        sa_end_step = sa_end_time * self.num_ddim_steps
        ca_end_step = ca_end_time * self.num_ddim_steps

        all_latent = []
        all_pred = []
        all_latent_old = []
        all_pred_old = []

        old_latent = latent.clone()
        new_latent = latent.clone()

        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)

            old_latent, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)

            if i < sa_end_step:
                # First phase: the new trajectory simply follows the old one.
                new_latent, pred_new = old_latent, pred_old
            else:
                # Middle phase uses the blended context, last phase the target one.
                new_context = old_to_new_context if i < ca_end_step else context
                new_latent, pred_new = self.infer_new_context(
                    new_latent, new_context, t, uncond_context
                )

            all_latent.append(new_latent.detach())
            all_pred.append(pred_new.detach())
            all_latent_old.append(old_latent.detach())
            all_pred_old.append(pred_old.detach())

        return {
            'latent': new_latent,
            'latent_old': old_latent,
            'all_latent': all_latent,
            'all_pred': all_pred,
            'all_latent_old': all_latent_old,
            'all_pred_old': all_pred_old,
        }
class InferenceDAMO_PTP_v2(Inference):
    """Three-phase old/new denoising with a joint first phase.

    In the first phase both trajectories go through the UNet in one batch
    while ``ptp_sa_replace`` is switched on in every ``CrossAttention``
    module. Afterwards the flag is switched off and the trajectories are
    denoised separately: the new one with ``old_to_new_context`` in the middle
    phase and with ``context`` in the last phase.

    Uses ``self.unet``, ``self.scheduler``, ``self.guidance_scale`` and
    ``self.num_ddim_steps``; none of them is defined in this class, so they
    are presumably provided by ``Inference``.
    """

    def set_ptp_in_xattn_layers(self, prompt_to_prompt: bool, num_frames=1):
        """Set ``ptp_sa_replace`` and ``num_frames`` on every ``CrossAttention`` module.

        Modules are matched by class name. How the attributes are consumed is
        decided by the attention implementation, which is not part of this
        file.
        """
        for m in self.unet.modules():
            if m.__class__.__name__ == 'CrossAttention':
                m.ptp_sa_replace = prompt_to_prompt
                m.num_frames = num_frames

    def infer_both_with_sa_replace(self, old_latent, new_latent, old_context, new_context, t, uncond_context=None):
        """Denoise the old and new latents together in a single UNet call.

        With guidance the batch is ``[old, new, old, new]`` latents paired with
        ``[uncond, uncond, old_context, new_context]``; without guidance it is
        ``[old, new]`` paired with ``[old_context, new_context]``.

        Returns:
            ``(old_latent, new_latent, old_pred, new_pred)`` where the latents
            are the scheduler's ``prev_sample`` and the preds its
            ``pred_original_sample``.
        """
        use_cfg = self.guidance_scale > 1 and (uncond_context is not None)

        if use_cfg:
            latent_input = torch.cat([old_latent, new_latent, old_latent, new_latent], dim=0)
            context_input = torch.cat([uncond_context, uncond_context, old_context, new_context], dim=0)
        else:
            latent_input = torch.cat([old_latent, new_latent], dim=0)
            context_input = torch.cat([old_context, new_context], dim=0)

        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )

        if use_cfg:
            # First half is unconditional, second half conditional.
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # What remains is [old, new]; split it back into the two trajectories.
        noise_pred_old, noise_pred_new = noise_pred.chunk(2, dim=0)
        step_old = self.scheduler.step(noise_pred_old, t, old_latent)
        step_new = self.scheduler.step(noise_pred_new, t, new_latent)
        return (
            step_old.prev_sample,
            step_new.prev_sample,
            step_old.pred_original_sample,
            step_new.pred_original_sample,
        )

    def _guided_denoise_step(self, latent, context, t, uncond_context=None):
        """Run one UNet + scheduler step, with classifier-free guidance if enabled.

        Shared implementation behind ``infer_old_context`` and
        ``infer_new_context``, which previously carried duplicated copies of
        this logic.

        Args:
            latent: Current latent.
            context: Conditioning for the UNet. May be a 2-element list/tuple,
                in which case the unconditional context is prepended to each
                element and a tuple is passed to the UNet.
            t: Integer timestep.
            uncond_context: Unconditional context; guidance is applied only
                when this is not ``None`` and ``self.guidance_scale > 1``.

        Returns:
            ``(prev_sample, pred_original_sample)`` from ``self.scheduler.step``.
        """
        use_cfg = self.guidance_scale > 1 and (uncond_context is not None)

        if use_cfg:
            # Unconditional half first, conditional half second; the chunk()
            # below relies on this order.
            latent_input = torch.cat([latent, latent], dim=0)
            if isinstance(context, (list, tuple)):
                context_input = (
                    torch.cat([uncond_context, context[0]], dim=0),
                    torch.cat([uncond_context, context[1]], dim=0),
                )
            else:
                context_input = torch.cat([uncond_context, context], dim=0)
        else:
            latent_input = latent
            context_input = context

        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )

        if use_cfg:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

        step_output = self.scheduler.step(noise_pred, t, latent)
        return step_output.prev_sample, step_output.pred_original_sample

    def infer_old_context(self, latent, context, t, uncond_context=None):
        """Take one denoising step for the old trajectory.

        Returns ``(next_latent, pred_original_sample)``. A list/tuple
        ``context`` is now handled the same way as in ``infer_new_context``
        (it used to fail inside ``torch.cat`` when guidance was active).
        """
        return self._guided_denoise_step(latent, context, t, uncond_context)

    def infer_new_context(self, latent, context, t, uncond_context=None):
        """Take one denoising step for the new trajectory.

        ``context`` may be a tensor or a 2-element list/tuple of tensors.
        Returns ``(next_latent, pred_original_sample)``.
        """
        return self._guided_denoise_step(latent, context, t, uncond_context)

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,  # used when > ca_end_time
        old_context: Optional[torch.Tensor] = None,  # used when < sa_end_time
        old_to_new_context: Optional[Union[Tuple, List]] = None,  # used when sa_end_time < t < ca_end_time
        uncond_context: Optional[torch.Tensor] = None,
        sa_end_time: float = 0.3,
        ca_end_time: float = 0.8,
        start_time: int = 0,
    ):
        """Denoise the old and new trajectories over ``timesteps[start_time:]``.

        The phase is chosen from the loop index ``i`` (0 at ``start_time``):

        * ``i < sa_end_time * num_ddim_steps``: ``ptp_sa_replace`` is enabled
          and both trajectories are denoised jointly with ``old_context`` and
          ``context``.
        * ``i < ca_end_time * num_ddim_steps``: the flag is disabled; the new
          trajectory is denoised with ``old_to_new_context``.
        * otherwise: the flag is disabled; the new trajectory is denoised with
          ``context``.

        NOTE(review): ``i`` counts from ``start_time``, while the thresholds
        are fractions of ``self.num_ddim_steps``; confirm this is the intended
        reference when ``start_time > 0``.
        NOTE(review): the attention flag is left in whatever state the last
        step set. If the loop ends inside the first phase it stays enabled on
        ``self.unet``; confirm whether callers expect a reset.

        Args:
            latent: Starting latent, cloned for each trajectory.
            context: Conditioning for the new trajectory (first and last phase).
            old_context: Conditioning for the old trajectory (every step).
            old_to_new_context: Conditioning for the new trajectory in the
                middle phase.
            uncond_context: Unconditional context for classifier-free guidance.
            sa_end_time: Fraction of ``num_ddim_steps`` ending the first phase.
            ca_end_time: Fraction of ``num_ddim_steps`` ending the middle phase.
            start_time: Index into ``self.scheduler.timesteps`` to start from.

        Returns:
            A dict with the final latents (``'latent'``, ``'latent_old'``) and
            the detached per-step histories (``'all_latent'``, ``'all_pred'``,
            ``'all_latent_old'``, ``'all_pred_old'``).

        Raises:
            AssertionError: If ``sa_end_time >= ca_end_time``.
        """
        # Kept as an assert so callers that catch AssertionError keep working;
        # note that it is skipped when Python runs with -O.
        assert sa_end_time < ca_end_time, f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively"

        # Phase boundaries do not change during the loop; compute them once.
        sa_end_step = sa_end_time * self.num_ddim_steps
        ca_end_step = ca_end_time * self.num_ddim_steps

        all_latent = []
        all_pred = []
        all_latent_old = []
        all_pred_old = []

        old_latent = latent.clone()
        new_latent = latent.clone()

        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)

            if i < sa_end_step:
                # assumes dim 2 of `latent` is the frame axis - TODO confirm
                self.set_ptp_in_xattn_layers(True, num_frames=latent.shape[2])
                old_latent, new_latent, pred_old, pred_new = self.infer_both_with_sa_replace(
                    old_latent, new_latent, old_context, context, t, uncond_context
                )
            else:
                self.set_ptp_in_xattn_layers(False)
                old_latent, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
                # Middle phase uses the blended context, last phase the target one.
                new_context = old_to_new_context if i < ca_end_step else context
                new_latent, pred_new = self.infer_new_context(
                    new_latent, new_context, t, uncond_context
                )

            all_latent.append(new_latent.detach())
            all_pred.append(pred_new.detach())
            all_latent_old.append(old_latent.detach())
            all_pred_old.append(pred_old.detach())

        return {
            'latent': new_latent,
            'latent_old': old_latent,
            'all_latent': all_latent,
            'all_pred': all_pred,
            'all_latent_old': all_latent_old,
            'all_pred_old': all_pred_old,
        }