# RelightVid / pl_trainer/diffusion.py
# Uploaded by aleafy; commit 0a63786 ("Start fresh").
import torch
from torch import nn
import pytorch_lightning as pl
from misc_utils.model_utils import default, instantiate_from_config
from diffusers import DDPMScheduler
from safetensors.torch import load_file
def mean_flat(tensor):
    """Take the mean over all non-batch dimensions.

    Args:
        tensor: Tensor whose first dimension is treated as the batch
            dimension.

    Returns:
        The mean over every dimension except the first. A tensor with fewer
        than two dimensions has no non-batch dimensions and is returned
        unchanged.
    """
    reduce_dims = tuple(range(1, tensor.dim()))
    if not reduce_dims:
        # Never forward an empty ``dim`` to ``mean``: PyTorch may interpret
        # it as "reduce every dimension", which would drop the batch axis.
        return tensor
    return tensor.mean(dim=reduce_dims)
class DDPM(pl.LightningModule):
    """Base diffusion module: a denoising network plus a DDPM noise schedule.

    Holds the denoiser, a ``DDPMScheduler`` built from ``beta_schedule_args``
    and an element-wise (``reduction='none'``) loss function.
    """

    def __init__(
        self,
        unet,
        beta_schedule_args=None,
        prediction_type='epsilon',
        loss_fn='l2',
        optim_args=None,
        base_path=None,
        **kwargs
    ):
        """
        Args:
            unet: a denoising model such as a UNet.
            beta_schedule_args: keyword arguments for ``DDPMScheduler``.
                ``None`` selects the defaults defined below. The mapping is
                copied; the caller's object is never modified.
            prediction_type: stored on the module and forwarded to the
                scheduler as ``prediction_type``.
            loss_fn: ``'l2'``/``'mse'``, ``'l1'``/``'mae'``, or a config dict
                handed to ``instantiate_from_config``.
            optim_args: keyword arguments stored as ``self.optim_args``
                (``None`` means no extra arguments).
            base_path: stored as ``self.base_path``; not used by this class.
            **kwargs: forwarded to ``pl.LightningModule.__init__``.

        Raises:
            NotImplementedError: if ``loss_fn`` is not one of the supported
                values.
        """
        super().__init__(**kwargs)
        if beta_schedule_args is None:
            # NOTE(review): Stable Diffusion configs commonly use
            # beta_end=0.012; confirm that 0.0012 is intended here.
            beta_schedule_args = {
                'beta_start': 0.00085,
                'beta_end': 0.0012,
                'num_train_timesteps': 1000,
                'beta_schedule': 'scaled_linear',
                'clip_sample': False,
                'thresholding': False,
            }
        # Build a new dict instead of calling ``.update`` so that neither a
        # shared default nor the caller's mapping is mutated.
        beta_schedule_args = {
            **beta_schedule_args,
            'prediction_type': prediction_type,
        }

        self.unet = unet
        self.prediction_type = prediction_type
        self.set_beta_schedule(beta_schedule_args)
        self.num_timesteps = beta_schedule_args['num_train_timesteps']
        self.optim_args = {} if optim_args is None else dict(optim_args)
        self.loss = loss_fn
        self.base_path = base_path

        if loss_fn in ('l2', 'mse'):
            self.loss_fn = nn.MSELoss(reduction='none')
        elif loss_fn in ('l1', 'mae'):
            self.loss_fn = nn.L1Loss(reduction='none')
        elif isinstance(loss_fn, dict):
            self.loss_fn = instantiate_from_config(loss_fn)
        else:
            raise NotImplementedError(f'unsupported loss_fn: {loss_fn!r}')

    def set_beta_schedule(self, beta_schedule_args):
        """Store the schedule arguments and (re)build the DDPM scheduler."""
        self.beta_schedule_args = beta_schedule_args
        self.scheduler = DDPMScheduler(**beta_schedule_args)

    @torch.no_grad()
    def add_noise(self, x, t, noise=None):
        """Diffuse ``x`` to timestep(s) ``t`` using the scheduler.

        Args:
            x: clean sample(s).
            t: timestep(s) passed to ``scheduler.add_noise``.
            noise: noise to add; drawn with ``torch.randn_like(x)`` when
                ``None``.

        Returns:
            The result of ``self.scheduler.add_noise(x, noise, t)``.
        """
        if noise is None:
            # Sample only when needed so that caller-supplied noise does not
            # cost an extra allocation and random draw.
            noise = torch.randn_like(x)
        return self.scheduler.add_noise(x, noise, t)

    def predict_x_0_from_x_t(self, model_output: torch.Tensor, t: torch.LongTensor, x_t: torch.Tensor):
        r"""Recover the predicted x_0 from the model output.

        Reverse of Eq. (4) in the DDPM paper:
        \hat{x}_0 = 1 / sqrt(\bar{a}) * x_t - sqrt((1 - \bar{a}) / \bar{a}) * noise

        Args:
            model_output: output of the denoising network.
            t: timestep index/indices into ``scheduler.alphas_cumprod``.
            x_t: noisy sample(s) at timestep ``t``.

        Returns:
            ``model_output`` itself when ``prediction_type == 'sample'``;
            otherwise the x_0 estimate computed by treating ``model_output``
            as the predicted noise.
        """
        if self.prediction_type == 'sample':
            return model_output

        # Remaining case: the training target is epsilon (the noise).
        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype)
        alpha_bar_t = alphas_cumprod[t]
        sqrt_recip_alphas_cumprod = torch.sqrt(1. / alpha_bar_t).flatten()
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alpha_bar_t - 1.).flatten()
        # Append singleton dimensions so the per-sample coefficients
        # broadcast against x_t.
        while sqrt_recip_alphas_cumprod.dim() < x_t.dim():
            sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod.unsqueeze(-1)
            sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod.unsqueeze(-1)
        return sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * model_output

    def predict_x_tm1_from_x_t(self, model_output, t, x_t):
        """Predict x_{t-1} from x_t and the model output via ``scheduler.step``."""
        return self.scheduler.step(model_output, t, x_t).prev_sample
class DDPMTraining(DDPM):
    """``DDPM`` extended with training/validation/test steps and an optimizer."""

    def __init__(
        self,
        unet,
        beta_schedule_args,
        prediction_type='epsilon',
        loss_fn='l2',
        optim_args=None,
        log_args=None,  # recorded together with all arguments by save_hyperparameters
        ddim_sampling_steps=20,
        guidance_scale=5.,
        **kwargs
    ):
        """
        Args:
            unet: denoising model, called as ``unet(x_t, t, **model_kwargs)``.
            beta_schedule_args: keyword arguments for the noise scheduler.
            prediction_type: ``'epsilon'`` makes the noise the training
                target; any other value makes ``x_0`` the target.
            loss_fn: loss specification forwarded to ``DDPM``.
            optim_args: keyword arguments for ``torch.optim.Adam``. ``None``
                selects ``{'lr': 1e-3, 'weight_decay': 5e-4}``.
            log_args: stored as ``self.log_args``; ``None`` means ``{}``.
            ddim_sampling_steps: number of DDIM steps used in validation.
            guidance_scale: stored as ``self.guidance_scale``.
            **kwargs: forwarded to ``DDPM.__init__``.
        """
        # Resolve the ``None`` sentinels before anything else so that the
        # values seen by ``save_hyperparameters`` are the effective ones and
        # no dict object is shared between instances.
        if optim_args is None:
            optim_args = {'lr': 1e-3, 'weight_decay': 5e-4}
        if log_args is None:
            log_args = {}
        super().__init__(
            unet=unet,
            beta_schedule_args=beta_schedule_args,
            prediction_type=prediction_type,
            loss_fn=loss_fn,
            optim_args=optim_args,
            **kwargs)
        self.log_args = log_args
        self.call_save_hyperparameters()
        self.ddim_sampling_steps = ddim_sampling_steps
        self.guidance_scale = guidance_scale

    def call_save_hyperparameters(self):
        """Kept as a separate method so that subclasses can override it."""
        self.save_hyperparameters(ignore=['unet'])

    def process_batch(self, x_0, mode):
        """Turn clean samples into a noised model input and a training target.

        In ``'train'`` mode a random timestep is drawn per sample; otherwise
        every sample uses the last timestep (``num_timesteps - 1``).

        Args:
            x_0: clean samples; the first dimension is the batch.
            mode: one of ``'train'``, ``'val'``, ``'test'``.

        Returns:
            A dict with keys ``'model_input'`` (noised samples),
            ``'model_target'`` (the noise for ``'epsilon'`` prediction,
            otherwise ``x_0``), ``'t'`` (timesteps) and ``'model_kwargs'``
            (an empty dict that subclasses may fill).
        """
        assert mode in ['train', 'val', 'test']
        batch_size = x_0.shape[0]
        noise = torch.randn_like(x_0)
        if mode == 'train':
            t = torch.randint(0, self.num_timesteps, (batch_size,), device=x_0.device).long()
        else:
            t = torch.full((batch_size,), self.num_timesteps - 1, device=x_0.device, dtype=torch.long)
        x_t = self.add_noise(x_0, t, noise=noise)

        model_target = noise if self.prediction_type == 'epsilon' else x_0
        return {
            'model_input': x_t,
            'model_target': model_target,
            't': t,
            'model_kwargs': {},
        }

    def forward(self, x):
        """Run ``validation_step`` on ``x`` with batch index 0."""
        return self.validation_step(x, 0)

    def get_loss(self, pred, target, t):
        """Return the scalar loss: per-sample mean, then mean over the batch.

        ``t`` is accepted for interface compatibility and is not used.
        """
        per_sample = mean_flat(self.loss_fn(pred, target))
        return per_sample.mean()

    def get_hdr_loss(self, fg_mask, pred, pred_combine):
        """Return the loss between ``pred`` and ``pred_combine`` weighted by ``fg_mask``.

        The element-wise loss is multiplied by ``fg_mask`` and then averaged
        over all elements (masked-out elements still count in the
        denominator).
        """
        # presumably all three tensors are shaped (1, 16, 4, 64, 64) — TODO confirm
        masked_loss = fg_mask * self.loss_fn(pred, pred_combine)
        return mean_flat(masked_loss).mean()

    def training_step(self, batch, batch_idx):
        """Run one denoising training step and log ``train_loss``.

        Returns:
            A dict with ``'loss'``, ``'model_input'``, ``'model_output'`` and
            ``'x_0_hat'`` (the x_0 estimate derived from the prediction).
        """
        self.clip_denoised = False
        processed_batch = self.process_batch(batch, mode='train')
        x_t = processed_batch['model_input']
        y = processed_batch['model_target']
        t = processed_batch['t']
        model_kwargs = processed_batch['model_kwargs']

        pred = self.unet(x_t, t, **model_kwargs)
        loss = self.get_loss(pred, y, t)
        with torch.no_grad():
            # x_0_hat is an auxiliary output that does not enter the loss, so
            # keep its computation out of the autograd graph.
            x_0_hat = self.predict_x_0_from_x_t(pred, t, x_t)
        self.log('train_loss', loss)
        return {
            'loss': loss,
            'model_input': x_t,
            'model_output': pred,
            'x_0_hat': x_0_hat
        }

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Sample from Gaussian noise with a DDIM scheduler.

        The batch is passed through ``process_batch`` only to obtain the
        input shape and the ``model_kwargs``; sampling starts from fresh
        noise.

        Returns:
            A dict with ``'x_pred'`` (the final sample) and ``'x_hist'``
            (the x_0 estimates of all steps stacked along dimension 1).
        """
        from diffusers import DDIMScheduler
        scheduler = DDIMScheduler(**self.beta_schedule_args)
        scheduler.set_timesteps(self.ddim_sampling_steps)

        processed_batch = self.process_batch(batch, mode='val')
        model_kwargs = processed_batch['model_kwargs']
        x_t = torch.randn_like(processed_batch['model_input'])

        x_hist = []
        for t in scheduler.timesteps:
            t_ = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
            model_output = self.unet(x_t, t_, **model_kwargs)
            x_hist.append(
                self.predict_x_0_from_x_t(model_output, t_, x_t)
            )
            x_t = scheduler.step(model_output, t, x_t).prev_sample

        return {
            'x_pred': x_t,
            'x_hist': torch.stack(x_hist, dim=1),
        }

    def test_step(self, batch, batch_idx):
        """Test is usually not used in a sampling problem; reuses validation."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.optim_args``."""
        return torch.optim.Adam(self.parameters(), **self.optim_args)
class DDPMLDMTraining(DDPMTraining):
    """Latent diffusion variant: the diffusion process runs on VAE latents."""

    def __init__(
        self, *args,
        vae,
        unet_init_weights=None,
        vae_init_weights=None,
        scale_factor=0.18215,
        **kwargs
    ):
        """
        Args:
            *args: forwarded to ``DDPMTraining.__init__``.
            vae: autoencoder used to encode images to latents and back.
            unet_init_weights: optional checkpoint path for the UNet.
            vae_init_weights: optional checkpoint path for the VAE; may be
                left as ``None`` (e.g. in the config file).
            scale_factor: latents are multiplied by this after encoding and
                divided by it before decoding.
            **kwargs: forwarded to ``DDPMTraining.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.vae = vae
        self.scale_factor = scale_factor
        self.initialize_unet(unet_init_weights)
        self.initialize_vqvae(vae_init_weights)

    @staticmethod
    def _load_state_dict_file(path):
        """Load a state dict from a ``.safetensors`` file or a torch checkpoint."""
        if str(path).endswith('.safetensors'):
            return load_file(path)
        # NOTE: ``torch.load`` unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources.
        return torch.load(path, map_location='cpu')

    def initialize_unet(self, unet_init_weights):
        """Load UNet weights from ``unet_init_weights`` when a path is given."""
        if unet_init_weights is not None:
            print(f'INFO: initialize denoising UNet from {unet_init_weights}')
            sd = self._load_state_dict_file(unet_init_weights)
            self.unet.load_state_dict(sd)

    def initialize_vqvae(self, vqvae_init_weights):
        """Load VAE weights from ``vqvae_init_weights`` and freeze the VAE.

        Does nothing when ``vqvae_init_weights`` is ``None``.
        """
        if vqvae_init_weights is not None:
            print(f'INFO: initialize VQVAE from {vqvae_init_weights}')
            sd = self._load_state_dict_file(vqvae_init_weights)
            self.vae.load_state_dict(sd)
            # The VAE is frozen as well.
            # NOTE(review): freezing only happens when weights are loaded;
            # confirm this is intended when no checkpoint path is given.
            for param in self.vae.parameters():
                param.requires_grad = False

    def call_save_hyperparameters(self):
        """Kept as a separate method so that subclasses can override it."""
        self.save_hyperparameters(ignore=['unet', 'vae'])

    @torch.no_grad()
    def encode_image_to_latent(self, x):
        """Encode ``x`` with the VAE and return the scaled posterior mean."""
        return self.vae.encode(x).latent_dist.mean * self.scale_factor

    @torch.no_grad()
    def decode_latent_to_image(self, x):
        """Undo the latent scaling and decode with the VAE.

        The division must mirror the multiplication applied in
        ``encode_image_to_latent``.
        """
        # NOTE(review): the raw result of ``vae.decode`` is returned; if the
        # VAE returns an output object rather than a tensor, callers that
        # stack the results need ``.sample`` — confirm against the VAE used.
        x = x / self.scale_factor
        return self.vae.decode(x)

    def process_batch(self, x_0, mode):
        """Encode the images to latents, then defer to the parent class."""
        latents = self.encode_image_to_latent(x_0)
        return super().process_batch(latents, mode)

    def training_step(self, batch, batch_idx):
        """Run the parent training step and decode ``x_0_hat`` to image space."""
        res_dict = super().training_step(batch, batch_idx)
        res_dict['x_0_hat'] = self.decode_latent_to_image(res_dict['x_0_hat'])
        return res_dict
class DDIMLDMTextTraining(DDPMLDMTraining):
    """Latent diffusion with a text encoder for conditioning and DDIM sampling."""

    def __init__(
        self, *args,
        text_model,
        text_model_init_weights=None,
        **kwargs
    ):
        """
        Args:
            *args: forwarded to ``DDPMLDMTraining.__init__``.
            text_model: text encoder exposing ``encode``.
            text_model_init_weights: optional checkpoint path for the text
                model; leave as ``None`` to skip loading.
            **kwargs: forwarded to ``DDPMLDMTraining.__init__``.
        """
        super().__init__(
            *args, **kwargs
        )
        self.text_model = text_model
        self.initialize_text_model(text_model_init_weights)

    def initialize_text_model(self, text_model_init_weights):
        """Load text-model weights and freeze the text model.

        Does nothing when ``text_model_init_weights`` is ``None``.
        """
        if text_model_init_weights is not None:
            print(f'INFO: initialize text model from {text_model_init_weights}')
            # NOTE: ``torch.load`` unpickles the file, which can execute
            # arbitrary code; only load checkpoints from trusted sources.
            sd = torch.load(text_model_init_weights, map_location='cpu')
            self.text_model.load_state_dict(sd)
            # No gradients are propagated into the text model.
            # NOTE(review): freezing only happens when weights are loaded;
            # confirm this is intended when no checkpoint path is given.
            for param in self.text_model.parameters():
                param.requires_grad = False

    def call_save_hyperparameters(self):
        """Kept as a separate method so that subclasses can override it."""
        self.save_hyperparameters(ignore=['unet', 'vae', 'text_model'])

    @torch.no_grad()
    def encode_text(self, x):
        """Encode text with the text model; a tuple is converted to a list first."""
        if isinstance(x, tuple):
            x = list(x)
        return self.text_model.encode(x)

    def process_batch(self, batch, mode):
        """Build the model inputs from a batch with ``'image'`` and ``'text'``.

        The image is handled by the parent class; the encoded text is added
        to ``model_kwargs`` as ``{'context': {'text': ...}}``.
        """
        x_0 = batch['image']
        text = batch['text']
        processed_batch = super().process_batch(x_0, mode)
        # NOTE(review): ``text`` is wrapped in a list before encoding; if
        # batch['text'] is already a list/tuple of strings this yields a
        # nested list — verify against the data loader and text model.
        processed_batch['model_kwargs'].update({
            'context': {'text': self.encode_text([text])}
        })
        return processed_batch

    @torch.no_grad()
    def sampling(self, image_shape=(1, 4, 64, 64), text='', negative_text=None):
        '''Generate samples from noise with DDIM, optionally with classifier-free guidance.

        Guidance is applied when ``self.guidance_scale > 1`` and
        ``negative_text`` is not ``None``. The whole loop runs without
        gradient tracking.

        Args:
            image_shape: shape of the initial latent noise.
            text: prompt.
            negative_text: negative prompt, or ``None`` to disable guidance.

        Returns:
            A dict with ``'x_pred'`` (decoded final sample) and ``'x_hist'``
            (decoded x_0 estimates of all steps stacked along dimension 1).

        Usage:
            sampled = self.sampling(text='a cat on the tree', negative_text='')
            x = sampled['x_pred'][0].permute(1, 2, 0).detach().cpu().numpy()
            x = x / 2 + 0.5
            plt.imshow(x)
            y = sampled['x_hist'][0, 10].permute(1, 2, 0).detach().cpu().numpy()
            y = y / 2 + 0.5
            plt.imshow(y)
        '''
        from diffusers import DDIMScheduler
        scheduler = DDIMScheduler(**self.beta_schedule_args)
        scheduler.set_timesteps(self.ddim_sampling_steps)
        x_t = torch.randn(*image_shape, device=self.device)

        do_cfg = self.guidance_scale > 1. and negative_text is not None
        # NOTE(review): a single prompt (pair) is encoded regardless of
        # image_shape[0]; confirm the UNet handles batch sizes > 1.
        prompts = [text, negative_text] if do_cfg else [text]
        context = {'text': self.encode_text(prompts)}

        x_hist = []
        for t in scheduler.timesteps:
            # With guidance the latent is duplicated: the first half pairs
            # with the prompt, the second half with the negative prompt.
            model_input = torch.cat([x_t] * 2) if do_cfg else x_t
            t_ = torch.full((model_input.shape[0],), t, device=x_t.device, dtype=torch.long)
            model_output = self.unet(model_input, t_, context)

            if do_cfg:
                model_output_positive, model_output_negative = model_output.chunk(2)
                model_output = model_output_negative + self.guidance_scale * (model_output_positive - model_output_negative)
            x_hist.append(
                self.decode_latent_to_image(self.predict_x_0_from_x_t(model_output, t_[:x_t.shape[0]], x_t))
            )
            x_t = scheduler.step(model_output, t, x_t).prev_sample

        return {
            'x_pred': self.decode_latent_to_image(x_t),
            'x_hist': torch.stack(x_hist, dim=1),
        }