import copy
import random
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel

from scepter.modules.model.network.ldm import LatentDiffusion
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we

from ..utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist,
)


class TextEmbedding(nn.Module):
    """Learnable positional embedding for the text identifier tokens."""

    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


@MODELS.register_class()
class LdmACE(LatentDiffusion):
    para_dict = LatentDiffusion.para_dict
    para_dict['DECODER_BIAS'] = {
        'value': 0,
        'description': 'Bias added to decoded images, in 0-255 pixel units.'
    }

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        # Downsample pixel-space masks to the latent resolution.
        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)

        self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
                                               False)
        if self.use_text_pos_embeddings:
            # Frozen positional embeddings, one slot per text identifier;
            # filled lazily on the first call to cond_stage_embeddings.
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False)
        else:
            self.text_position_embeddings = None

        self.logger.info(self.model)

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        # Encode each image separately; images in a batch may differ in size.
        return [
            self.scale_factor *
            self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
            for i in x
        ]

    @torch.no_grad()
    def decode_first_stage(self, z):
        return [
            self.first_stage_model._decode(1. / self.scale_factor *
                                           i.to(torch.float16)) for i in z
        ]

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        # Lazily fill the identifier position embeddings the first time they
        # are needed (they are initialised to all zeros).
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, _ = self.cond_stage_model.encode(
                self.text_indentifers, return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                # Prepend the final (target) instruction embedding; keep the
                # full instruction list only when edit images are present.
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError
        return cont_, cont_mask_

    def limit_batch_data(self, batch_data_list, log_num):
        # Truncate every non-None field of the batch to at most log_num
        # samples; used to cap how many test samples are processed/logged.
        if log_num and log_num > 0:
            batch_data_list_limited = []
            for sub_data in batch_data_list:
                if sub_data is not None:
                    sub_data = sub_data[:log_num]
                batch_data_list_limited.append(sub_data)
            return batch_data_list_limited
        else:
            return batch_data_list

    def forward_train(self,
                      edit_image=[],
                      edit_image_mask=[],
                      image=None,
                      image_mask=None,
                      noise=None,
                      prompt=[],
                      **kwargs):
        '''
        Args:
            edit_image: list of lists of edit images, one inner list per sample
            edit_image_mask: list of lists of edit image masks
            image: list of target images
            image_mask: list of target image masks
            noise: default is None; generated automatically when not given
            prompt: list of lists of text instructions, one inner list per
                sample, e.g. [['make the sky pink']]
            **kwargs:
        Returns:
            dict with the training loss and probe data
        '''
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        # gc_seg: gradient-checkpointing segment setting, passed through to
        # the model.
        gc_seg = kwargs.pop('gc_seg', [])
        gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
        context = {}
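
        # Encode target images into latents and pack the variable-sized
        # latents into one tensor, keeping the original shapes for unpacking.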
        image = to_device(image)
        x_start = self.encode_first_stage(image, **kwargs)
        x_start, x_shapes = pack_imagelist_into_tensor(x_start)
        n, _, _ = x_start.shape
        t = torch.randint(0, self.num_timesteps, (n, ),
                          device=x_start.device).long()
        context['x_shapes'] = x_shapes
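
        # Downsample target-image masks to the latent resolution (None-safe).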
        image_mask = to_device(image_mask, strict=False)
        context['x_mask'] = [self.interpolate_func(i) for i in image_mask
                             ] if image_mask is not None else [None] * n
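
        # Encode the per-sample instruction lists into text embeddings and
        # reorder them to match the edit images (see cond_stage_embeddings).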
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        try:
            cont, cont_mask = self.cond_stage_model.encode_list(
                prompt_, return_mask=True)
        except Exception as e:
            # Re-raise instead of continuing with undefined embeddings.
            self.logger.error(f'Text encoding failed: {e}, prompts: {prompt_}')
            raise
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        context['crossattn'] = cont
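
        # Encode edit (reference) images and their masks as extra conditioning.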
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if m is None:
                m = [None] * len(u) if u is not None else [None]
            e_img.append(
                self.encode_first_stage(u, **kwargs) if u is not None else u)
            e_mask.append([
                self.interpolate_func(i) if i is not None else None for i in m
            ])
        context['edit'], context['edit_mask'] = e_img, e_mask
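
        # Standard diffusion training loss on the packed latents.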
        loss = self.diffusion.loss(
            x_0=x_start,
            t=t,
            noise=noise,
            model=self.model,
            model_kwargs={
                'cond': context,
                'mask': cont_mask,
                'gc_seg': gc_seg,
                'text_position_embeddings':
                self.text_position_embeddings.pos if hasattr(
                    self.text_position_embeddings, 'pos') else None
            },
            **kwargs)
        loss = loss.mean()
        ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
        return ret

    @torch.no_grad()
    def forward_test(self,
                     edit_image=[],
                     edit_image_mask=[],
                     image=None,
                     image_mask=None,
                     prompt=[],
                     n_prompt=[],
                     sampler='ddim',
                     sample_steps=20,
                     guide_scale=4.5,
                     guide_rescale=0.5,
                     log_num=-1,
                     seed=2024,
                     **kwargs):
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None

        kwargs.pop('gc_seg', -1)  # Drop the training-only option.

        context, null_context = {}, {}

        # Optionally cap how many samples are run/logged.
        prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
            [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
            log_num)
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(seed)
        # The negative prompt mirrors the positive prompt structure with the
        # final instruction blanked; the n_prompt argument is ignored here.
        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            if isinstance(nn_p, str):
                n_prompt[nn_p_id] = ['']
            elif isinstance(nn_p, list):
                n_prompt[nn_p_id][-1] = ''
            else:
                raise NotImplementedError
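
        # Encode target images and draw per-image noise from the seeded
        # generator, then pack the noise list into one tensor.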
        image = to_device(image)
        x = self.encode_first_stage(image, **kwargs)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        context['x_shapes'] = null_context['x_shapes'] = x_shapes

        image_mask = to_device(image_mask, strict=False)
        cond_mask = [self.interpolate_func(i) for i in image_mask
                     ] if image_mask is not None else [None] * len(image)
        context['x_mask'] = null_context['x_mask'] = cond_mask
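
        # Encode positive and negative instructions for classifier-free
        # guidance.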
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        cont, cont_mask = self.cond_stage_model.encode_list(prompt_,
                                                            return_mask=True)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = self.cond_stage_model.encode_list(
            n_prompt, return_mask=True)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        context['crossattn'] = cont
        null_context['crossattn'] = null_cont

        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])
        null_context['edit'] = context['edit'] = e_img
        null_context['edit_mask'] = context['edit_mask'] = e_mask
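
        # Run the sampler with classifier-free guidance (two model_kwargs
        # branches) when guide_scale > 1, otherwise a single conditional pass.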
        model = self.model_ema if self.use_ema and self.eval_ema else self.model
        embedding_context = model.no_sync if isinstance(
            model, FullyShardedDataParallel) else nullcontext
        pos_embed = self.text_position_embeddings.pos if hasattr(
            self.text_position_embeddings, 'pos') else None
        with embedding_context():
            samples = self.diffusion.sample(
                sampler=sampler,
                noise=noise,
                model=model,
                model_kwargs=[{
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings': pos_embed
                }, {
                    'cond': null_context,
                    'mask': null_cont_mask,
                    'text_position_embeddings': pos_embed
                }] if guide_scale is not None and guide_scale > 1 else {
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings': pos_embed
                },
                steps=sample_steps,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                show_progress=True,
                **kwargs)
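
        # Unpack latents, decode them to images, and assemble per-sample
        # outputs for logging/visualization.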
        samples = unpack_tensor_into_imagelist(samples, x_shapes)
        x_samples = self.decode_first_stage(samples)
        outputs = list()
        for i in range(len(prompt)):
            rec_img = torch.clamp(
                (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
                min=0.0,
                max=1.0)
            rec_img = rec_img.squeeze(0)
            edit_imgs, edit_img_masks = [], []
            if edit_image is not None and edit_image[i] is not None:
                if edit_image_mask[i] is None:
                    edit_image_mask[i] = [None] * len(edit_image[i])
                for edit_img, edit_mask in zip(edit_image[i],
                                               edit_image_mask[i]):
                    edit_img = torch.clamp((edit_img + 1.0) / 2.0,
                                           min=0.0,
                                           max=1.0)
                    edit_imgs.append(edit_img.squeeze(0))
                    if edit_mask is None:
                        edit_mask = torch.ones_like(edit_img[[0], :, :])
                    edit_img_masks.append(edit_mask)
            one_tup = {
                'reconstruct_image': rec_img,
                'instruction': prompt[i],
                'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
                'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
            }
            if image is not None:
                if image_mask is None:
                    image_mask = [None] * len(image)
                ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
                one_tup['target_image'] = ori_img.squeeze(0)
                one_tup['target_mask'] = image_mask[i] if image_mask[
                    i] is not None else torch.ones_like(ori_img[[0], :, :])
            outputs.append(one_tup)
        return outputs

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            LdmACE.para_dict,
                            set_name=True)
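

# A minimal usage sketch (assumptions: a YAML config supplying the usual
# LatentDiffusion fields plus TEXT_IDENTIFIER / USE_TEXT_POS_EMBEDDINGS; the
# config path, tensors, and prompt below are hypothetical illustrations):
#
#     from scepter.modules.utils.config import Config
#     cfg = Config(load=True, cfg_file='configs/ace_model.yaml')
#     model = MODELS.build(cfg.MODEL, logger=None).to(we.device_id)
#     target = torch.randn(3, 512, 512)  # target image in [-1, 1]
#     outputs = model.forward_test(edit_image=[[]],
#                                  edit_image_mask=[[]],
#                                  image=[target],
#                                  image_mask=[None],
#                                  prompt=[['make the sky pink']],
#                                  sampler='ddim',
#                                  sample_steps=20,
#                                  guide_scale=4.5)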