import torch


def normalise2(tensor):
    """Rescale pixel values from [0, 1] to [-1, 1]."""
    return (tensor * 2. - 1.).clamp(-1, 1)


def tfg_data(dataloader, face_hide_percentage, use_ref, use_audio):
    """Yield (img_batch, model_kwargs) pairs indefinitely by cycling the dataloader."""
    def inf_gen(generator):
        while True:
            yield from generator

    data = inf_gen(dataloader)
    for batch in data:
        img_batch, model_kwargs = tfg_process_batch(
            batch, face_hide_percentage, use_ref, use_audio
        )
        yield img_batch, model_kwargs


def tfg_process_batch(batch, face_hide_percentage, use_ref=False, use_audio=False,
                      sampling_use_gt_for_ref=False, noise=None):
    """Flatten a video batch into frames and attach all conditioning inputs."""
    model_kwargs = {}
    # [B, F, C, H, W] -> [B*F, C, H, W]: treat every frame as an independent sample.
    B, F, C, H, W = batch["image"].shape
    img_batch = normalise2(batch["image"].reshape(B * F, C, H, W).contiguous())
    model_kwargs = tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise)
    if use_ref:
        model_kwargs = tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref)
    if use_audio:
        model_kwargs = tfg_add_audio(batch, model_kwargs)
    return img_batch, model_kwargs


def tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref=False):
    """Attach reference frames (assumes a single reference per frame, nrefer = 1)."""
    # [B, nframes, C, H, W] -> [B*nframes, C, H, W]
    if sampling_use_gt_for_ref:
        # During sampling, optionally reuse the ground-truth frames as references.
        B, F, C, H, W = batch["image"].shape
        model_kwargs["ref_img"] = normalise2(batch["image"].reshape(B * F, C, H, W).contiguous())
    else:
        _, _, C, H, W = batch["ref_img"].shape
        model_kwargs["ref_img"] = normalise2(batch["ref_img"].reshape(-1, C, H, W).contiguous())
    return model_kwargs


def tfg_add_audio(batch, model_kwargs):
    """Attach mel-spectrogram conditioning."""
    # The UNet expects per-frame mel windows of shape [B*F, h, w].
    B, F, _, h, w = batch["indiv_mels"].shape  # [B, F, 1, h, w]
    model_kwargs["indiv_mels"] = batch["indiv_mels"].squeeze(dim=2).reshape(B * F, h, w)
    # The sync loss expects a full mel window of shape [B, 1, 80, 16].
    if "mel" in batch:
        model_kwargs["mel"] = batch["mel"]  # [B, 1, h, w]
    return model_kwargs


def tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise=None):
    """Hide the lower `face_hide_percentage` of every frame behind noise."""
    B, C, H, W = img_batch.shape
    # Build the mask on the same device as the images so the blend below also works on GPU.
    mask = torch.zeros(B, 1, H, W, device=img_batch.device)
    mask_start_idx = int(H * (1. - face_hide_percentage))
    mask[:, :, mask_start_idx:, :] = 1.  # 1 = hidden (bottom rows), 0 = visible
    if noise is None:
        noise = torch.randn_like(img_batch)
    assert noise.shape == img_batch.shape, "Noise shape != Image shape"
    # Keep the visible region, replace the masked region with noise.
    cond_img = img_batch * (1. - mask) + mask * noise
    model_kwargs["cond_img"] = cond_img
    model_kwargs["mask"] = mask
    return model_kwargs


def get_n_params(model):
    """Return the total number of parameters in `model`."""
    return sum(p.numel() for p in model.parameters())
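

# A minimal, illustrative usage sketch (not part of the original file): the dummy
# batch shape and face_hide_percentage value below are assumptions, chosen only to
# show how tfg_process_batch flattens frames and builds the masked conditioning inputs.
if __name__ == "__main__":
    batch = {"image": torch.rand(2, 5, 3, 64, 64)}  # 2 clips x 5 frames, values in [0, 1]
    img_batch, model_kwargs = tfg_process_batch(batch, face_hide_percentage=0.5)
    print(img_batch.shape)                     # torch.Size([10, 3, 64, 64]), values in [-1, 1]
    print(model_kwargs["cond_img"].shape)      # torch.Size([10, 3, 64, 64])
    print(model_kwargs["mask"].sum().item())   # 10 * 1 * 32 * 64 = 20480.0 masked pixels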