import torch
from einops import rearrange
from typing import Any, Optional

from diffusers.utils.import_utils import is_xformers_available

from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor


class ReferenceOnlyAttnProc(torch.nn.Module):
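    """Attention-processor wrapper for reference-only conditioning.

    When ``enabled``, the processor either writes this layer's encoder hidden
    states into a shared ``ref_dict`` (reference pass) or reads them back and
    appends them to the key/value tokens before delegating to the wrapped
    multi-view processor (``chained_proc``).
    """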
    def __init__(
        self,
        chained_proc,
        enabled: bool = False,
        name: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: Optional[dict] = None, is_cfg_guidance=False, num_views=4,
        multiview_attention=True,
        cross_domain_attention=False,
    ) -> Any:
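        """Route attention through ``chained_proc`` according to ``mode``.

        Modes: ``'w'`` caches encoder states in ``ref_dict`` (reference pass);
        ``'r'`` pools tokens across ``num_views`` views and appends the cached
        reference tokens (denoising pass); ``'m'`` appends the cached tokens
        without pooling; ``'n'`` pools across views with no reference injection.
        """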
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        if self.enabled:
            if mode == 'w':
                # Reference pass: cache this layer's tokens for the later read pass.
                ref_dict[self.name] = encoder_hidden_states
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=1,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            elif mode == 'r':
                # Denoising pass: pool tokens across views, append the cached
                # reference tokens, and broadcast the result back to every view.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                if self.name in ref_dict:
                    encoder_hidden_states = torch.cat(
                        [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
                    ).unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=num_views,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            elif mode == 'm':
                # Mix mode: append the cached reference tokens without pooling
                # across views, then run attention like the sibling branches so
                # `res` is always defined.
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=num_views,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            elif mode == 'n':
                # No-reference mode: pool tokens across views and broadcast them
                # back without injecting any cached reference features.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=num_views,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            else:
                raise ValueError(f"unknown attention mode: {mode}")
        else:
            res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
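    """UNet wrapper that implements reference-only conditioning.

    Each forward call runs the wrapped UNet twice: first on a noised condition
    latent in write mode to cache per-layer reference features, then on the
    actual sample in read mode so every self-attention layer can also attend
    to those cached features.
    """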
    def __init__(self, unet, train_sched, val_sched) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Wrap every attention processor; only self-attention ("attn1") layers
        # get the reference read/write behaviour enabled.
        unet_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if is_xformers_available():
                default_attn_proc = XFormersMVAttnProcessor()
            else:
                default_attn_proc = MVAttnProcessor()
            unet_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name)
        self.unet.set_attn_processor(unet_attn_procs)

    def __getattr__(self, name: str):
        # Delegate missing attributes to the wrapped UNet so this module can be
        # used as a drop-in replacement for it.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels,
                     ref_dict, is_cfg_guidance, **kwargs):
        if is_cfg_guidance:
            # Skip the unconditional batch entry: the reference pass only needs
            # the conditional branch.
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        # Run the UNet in write mode purely for its side effect of filling ref_dict.
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        # Noise the condition latent to the current timestep so the reference
        # pass sees inputs with the same noise statistics as the sample.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        # Pass 1 (write): cache reference attention features into ref_dict.
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        # Pass 2 (read): denoise the sample while attending to the cached features.
        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                res.to(dtype=weight_dtype) for res in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )
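

if __name__ == "__main__":
    # Minimal shape-level smoke test for ReferenceOnlyAttnProc (an illustrative
    # sketch, not part of the pipeline). The stub below stands in for
    # MVAttnProcessor and simply returns the key/value stream it receives, so
    # the printed shape shows what the wrapped processor would attend over.
    class _StubProc:
        def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                     attention_mask=None, **kwargs):
            return encoder_hidden_states

    proc = ReferenceOnlyAttnProc(_StubProc(), enabled=True, name="attn1.processor")
    b, t, d, c = 2, 4, 16, 8  # batch, views, tokens, channels (made-up sizes)
    ref_dict = {}
    # Write pass: the reference branch runs at batch size b and caches its tokens.
    proc(None, torch.randn(b, d, c), mode="w", ref_dict=ref_dict)
    # Read pass: per-view tokens (b*t, d, c) are pooled across views, the cached
    # reference tokens are appended, and the result is broadcast to every view.
    out = proc(None, torch.randn(b * t, d, c), mode="r", ref_dict=ref_dict, num_views=t)
    print(out.shape)  # torch.Size([8, 80, 8]) == (b*t, t*d + d, c)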