import torch
from torch import nn
import pytorch_lightning as pl
from misc_utils.model_utils import default, instantiate_from_config
from diffusers import DDPMScheduler
from safetensors.torch import load_file

def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
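# For illustration: for a tensor of shape (B, C, H, W), mean_flat returns a
# (B,)-shaped tensor holding each sample's mean over C, H and W, e.g.
#
#   x = torch.randn(8, 4, 64, 64)
#   mean_flat(x).shape  # torch.Size([8])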

class DDPM(pl.LightningModule):
    def __init__(
        self,
        unet,
        beta_schedule_args={
            'beta_start': 0.00085,
            'beta_end': 0.0012,
            'num_train_timesteps': 1000,
            'beta_schedule': 'scaled_linear',
            'clip_sample': False,
            'thresholding': False,
        },
        prediction_type='epsilon',
        loss_fn='l2',
        optim_args={},
        base_path=None,
        **kwargs
    ):
        '''
        unet: a denoising model such as a UNet
        beta_schedule_args: a dictionary containing the configuration of the beta schedule
        '''
        super().__init__(**kwargs)
        self.unet = unet
        self.prediction_type = prediction_type
        # copy before adding prediction_type so the (shared) default dict is not mutated in place
        beta_schedule_args = dict(beta_schedule_args, prediction_type=prediction_type)
        self.set_beta_schedule(beta_schedule_args)
        self.num_timesteps = beta_schedule_args['num_train_timesteps']
        self.optim_args = optim_args
        self.loss = loss_fn
        self.base_path = base_path
        if loss_fn == 'l2' or loss_fn == 'mse':
            self.loss_fn = nn.MSELoss(reduction='none')
        elif loss_fn == 'l1' or loss_fn == 'mae':
            self.loss_fn = nn.L1Loss(reduction='none')
        elif isinstance(loss_fn, dict):
            self.loss_fn = instantiate_from_config(loss_fn)
        else:
            raise NotImplementedError(f'unsupported loss_fn: {loss_fn}')
    def set_beta_schedule(self, beta_schedule_args):
        self.beta_schedule_args = beta_schedule_args
        self.scheduler = DDPMScheduler(**beta_schedule_args)

    def add_noise(self, x, t, noise=None):
        noise = default(noise, torch.randn_like(x))
        return self.scheduler.add_noise(x, noise, t)
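    # Forward process: DDPMScheduler.add_noise samples x_t from q(x_t | x_0):
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    # A minimal sketch of the equivalent manual computation for a scalar timestep t:
    #
    #   a_bar = self.scheduler.alphas_cumprod[int(t)]
    #   x_t = a_bar.sqrt() * x + (1 - a_bar).sqrt() * noise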
    def predict_x_0_from_x_t(self, model_output: torch.Tensor, t: torch.LongTensor, x_t: torch.Tensor):
        '''Recover the predicted x_0 from the predicted noise. Reverse of Eq. (4) in the DDPM paper:
        x_0_hat = x_t / sqrt(alpha_bar_t) - sqrt((1 - alpha_bar_t) / alpha_bar_t) * noise
        '''
        # alternative: the scheduler also exposes this value as pred_original_sample:
        # return self.scheduler.step(model_output, int(t), x_t).pred_original_sample
        if self.prediction_type == 'sample':
            # the model directly predicts x_0
            return model_output
        # prediction_type == 'epsilon': the training target is the noise
        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype)
        sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t]).flatten()
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t] - 1.).flatten()
        while len(sqrt_recip_alphas_cumprod.shape) < len(x_t.shape):
            sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod.unsqueeze(-1)
            sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod.unsqueeze(-1)
        return sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * model_output
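    # Derivation (for reference): the forward process gives
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
    # so solving for x_0 with the predicted eps yields
    #   x_0_hat = x_t / sqrt(alpha_bar_t) - sqrt(1 / alpha_bar_t - 1) * eps,
    # which is exactly the expression returned above.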
    def predict_x_tm1_from_x_t(self, model_output, t, x_t):
        '''predict x_{t-1} from x_t and model_output'''
        return self.scheduler.step(model_output, t, x_t).prev_sample

class DDPMTraining(DDPM):  # adds training / validation / optimizer steps on top of DDPM
    def __init__(
        self,
        unet,
        beta_schedule_args,
        prediction_type='epsilon',
        loss_fn='l2',
        optim_args={
            'lr': 1e-3,
            'weight_decay': 5e-4
        },
        log_args={},  # for recording all constructor arguments via self.save_hyperparameters
        ddim_sampling_steps=20,
        guidance_scale=5.,
        **kwargs
    ):
        super().__init__(
            unet=unet,
            beta_schedule_args=beta_schedule_args,
            prediction_type=prediction_type,
            loss_fn=loss_fn,
            optim_args=optim_args,
            **kwargs)
        self.log_args = log_args
        self.call_save_hyperparameters()
        self.ddim_sampling_steps = ddim_sampling_steps
        self.guidance_scale = guidance_scale

    def call_save_hyperparameters(self):
        '''written as a separate function so that subclasses can overwrite it'''
        self.save_hyperparameters(ignore=['unet'])
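    # Note: `unet` is an nn.Module, not a config value, so it is excluded from the
    # logged hyperparameters; its weights are still stored in the regular state_dict.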
    def process_batch(self, x_0, mode):
        assert mode in ['train', 'val', 'test']
        b, *_ = x_0.shape
        noise = torch.randn_like(x_0)
        if mode == 'train':
            # random timestep per sample
            t = torch.randint(0, self.num_timesteps, (b,), device=x_0.device).long()
            x_t = self.add_noise(x_0, t, noise=noise)
        else:
            # fixed (last) timestep for validation / test
            t = torch.full((b,), self.num_timesteps - 1, device=x_0.device, dtype=torch.long)
            x_t = self.add_noise(x_0, t, noise=noise)
        model_kwargs = {}
        '''returns a dict with
        1) model_input: the noised sample x_t,
        2) model_target: the prediction target (noise or x_0, depending on prediction_type),
        3) t: the timestep condition,
        4) model_kwargs: extra keyword arguments for the denoising model
        '''
        if self.prediction_type == 'epsilon':
            return {
                'model_input': x_t,
                'model_target': noise,
                't': t,
                'model_kwargs': model_kwargs
            }
        else:
            return {
                'model_input': x_t,
                'model_target': x_0,
                't': t,
                'model_kwargs': model_kwargs
            }
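    # For illustration, with prediction_type == 'epsilon' and a batch of shape
    # (B, 4, 64, 64), process_batch returns roughly:
    #
    #   {
    #       'model_input': x_t,      # (B, 4, 64, 64) noised sample
    #       'model_target': noise,   # (B, 4, 64, 64) the added noise
    #       't': t,                  # (B,) long tensor of timesteps
    #       'model_kwargs': {},      # filled by subclasses (e.g. text conditioning)
    #   }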
    def forward(self, x):
        return self.validation_step(x, 0)

    def get_loss(self, pred, target, t):
        loss_raw = self.loss_fn(pred, target)
        loss_flat = mean_flat(loss_raw)  # per-sample loss
        return loss_flat.mean()
    def get_hdr_loss(self, fg_mask, pred, pred_combine):
        # fg_mask, pred and pred_combine all share the same shape, e.g. (1, 16, 4, 64, 64)
        # TODO: print the shapes and double-check that they are correct
        loss_raw = self.loss_fn(pred, pred_combine)  # (1, 16, 4, 64, 64)
        masked_loss = fg_mask * loss_raw  # only penalize the foreground region
        loss_flat = mean_flat(masked_loss)
        return loss_flat.mean()
    def training_step(self, batch, batch_idx):
        self.clip_denoised = False
        processed_batch = self.process_batch(batch, mode='train')
        x_t = processed_batch['model_input']
        y = processed_batch['model_target']
        t = processed_batch['t']
        model_kwargs = processed_batch['model_kwargs']
        pred = self.unet(x_t, t, **model_kwargs)
        loss = self.get_loss(pred, y, t)
        x_0_hat = self.predict_x_0_from_x_t(pred, t, x_t)
        self.log('train_loss', loss)
        return {
            'loss': loss,
            'model_input': x_t,
            'model_output': pred,
            'x_0_hat': x_0_hat
        }
    def validation_step(self, batch, batch_idx):
        from diffusers import DDIMScheduler
        scheduler = DDIMScheduler(**self.beta_schedule_args)
        scheduler.set_timesteps(self.ddim_sampling_steps)
        processed_batch = self.process_batch(batch, mode='val')
        x_t = torch.randn_like(processed_batch['model_input'])
        x_hist = []
        timesteps = scheduler.timesteps
        for i, t in enumerate(timesteps):
            t_ = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
            model_output = self.unet(x_t, t_, **processed_batch['model_kwargs'])
            x_hist.append(
                self.predict_x_0_from_x_t(model_output, t_, x_t)
            )
            x_t = scheduler.step(model_output, t, x_t).prev_sample
        return {
            'x_pred': x_t,
            'x_hist': torch.stack(x_hist, dim=1),
        }
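    # 'x_pred' is the final sample after all DDIM steps; 'x_hist' stacks the per-step
    # x_0 estimates along dim=1, i.e. it has shape (B, ddim_sampling_steps, ...).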
    def test_step(self, batch, batch_idx):
        '''Test is usually not used in a sampling problem'''
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), **self.optim_args)
        return optimizer
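    # A minimal training sketch (hypothetical `MyUNet` and `train_loader`, not part
    # of this module):
    #
    #   model = DDPMTraining(
    #       unet=MyUNet(),
    #       beta_schedule_args={'num_train_timesteps': 1000, 'beta_schedule': 'scaled_linear'},
    #       optim_args={'lr': 1e-4, 'weight_decay': 0.0},
    #   )
    #   trainer = pl.Trainer(max_steps=10_000)
    #   trainer.fit(model, train_loader)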

class DDPMLDMTraining(DDPMTraining):  # latent diffusion (LDM): diffusion runs in the VAE latent space
    def __init__(
        self, *args,
        vae,
        unet_init_weights=None,
        vae_init_weights=None,
        scale_factor=0.18215,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.vae = vae
        self.scale_factor = scale_factor
        self.initialize_unet(unet_init_weights)
        self.initialize_vqvae(vae_init_weights)  # can be set to None in the config file

    def initialize_unet(self, unet_init_weights):
        if unet_init_weights is not None:
            print(f'INFO: initialize denoising UNet from {unet_init_weights}')
            sd = torch.load(unet_init_weights, map_location='cpu')
            self.unet.load_state_dict(sd)

    def initialize_vqvae(self, vqvae_init_weights):
        # loading the VAE weights ultimately goes through this init function
        if vqvae_init_weights is not None:
            print(f'INFO: initialize VQVAE from {vqvae_init_weights}')
            if '.safetensors' in vqvae_init_weights:
                sd = load_file(vqvae_init_weights)
            else:
                sd = torch.load(vqvae_init_weights, map_location='cpu')
            self.vae.load_state_dict(sd)
        for param in self.vae.parameters():
            param.requires_grad = False  # the VAE parameters are frozen as well
    def call_save_hyperparameters(self):
        '''written as a separate function so that subclasses can overwrite it'''
        self.save_hyperparameters(ignore=['unet', 'vae'])

    def encode_image_to_latent(self, x):
        # NOTE: changed from `self.vae.encode(x) * self.scale_factor`
        return self.vae.encode(x).latent_dist.mean * self.scale_factor

    def decode_latent_to_image(self, x):
        # must stay consistent with encode_image_to_latent: the latent was multiplied
        # by scale_factor there, so divide by it here
        x = x / self.scale_factor
        return self.vae.decode(x)
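    # Round-trip sketch (assumes `self.vae` follows the diffusers AutoencoderKL
    # interface, as implied by `.latent_dist` above):
    #
    #   z = self.encode_image_to_latent(image)   # latent, already multiplied by scale_factor
    #   recon = self.decode_latent_to_image(z)   # decoded image (in the VAE's output type)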
    def process_batch(self, x_0, mode):
        x_0 = self.encode_image_to_latent(x_0)
        res = super().process_batch(x_0, mode)
        return res

    def training_step(self, batch, batch_idx):
        res_dict = super().training_step(batch, batch_idx)
        res_dict['x_0_hat'] = self.decode_latent_to_image(res_dict['x_0_hat'])
        return res_dict

class DDIMLDMTextTraining(DDPMLDMTraining):  # adds a text encoder for text-conditional generation; uses DDIM for sampling
    def __init__(
        self, *args,
        text_model,
        text_model_init_weights=None,
        **kwargs
    ):
        super().__init__(
            *args, **kwargs
        )
        self.text_model = text_model
        self.initialize_text_model(text_model_init_weights)  # can be skipped by setting weights=None

    def initialize_text_model(self, text_model_init_weights):
        # loading the text model weights ultimately goes through this init function
        if text_model_init_weights is not None:
            print(f'INFO: initialize text model from {text_model_init_weights}')
            sd = torch.load(text_model_init_weights, map_location='cpu')
            self.text_model.load_state_dict(sd)
        for param in self.text_model.parameters():
            param.requires_grad = False  # the text model is frozen; no gradients flow back into it

    def call_save_hyperparameters(self):
        '''written as a separate function so that subclasses can overwrite it'''
        self.save_hyperparameters(ignore=['unet', 'vae', 'text_model'])

    def encode_text(self, x):
        if isinstance(x, tuple):
            x = list(x)
        return self.text_model.encode(x)

    def process_batch(self, batch, mode):
        x_0 = batch['image']
        text = batch['text']
        processed_batch = super().process_batch(x_0, mode)
        processed_batch['model_kwargs'].update({
            'context': {'text': self.encode_text([text])}
        })
        return processed_batch
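    # Expected batch format for this class: a dict such as
    #
    #   batch = {'image': images, 'text': caption}  # images: (B, 3, H, W), typically in [-1, 1]
    #
    # The text embedding is passed to the UNet via model_kwargs['context']['text'].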
    def sampling(self, image_shape=(1, 4, 64, 64), text='', negative_text=None):
        '''
        Usage:
            sampled = self.sampling(text='a cat on the tree', negative_text='')
            x = sampled['x_pred'][0].permute(1, 2, 0).detach().cpu().numpy()
            x = x / 2 + 0.5
            plt.imshow(x)
            y = sampled['x_hist'][0, 10].permute(1, 2, 0).detach().cpu().numpy()
            y = y / 2 + 0.5
            plt.imshow(y)
        '''
        from diffusers import DDIMScheduler  # DDIM sampling
        scheduler = DDIMScheduler(**self.beta_schedule_args)
        scheduler.set_timesteps(self.ddim_sampling_steps)
        x_t = torch.randn(*image_shape, device=self.device)
        do_cfg = self.guidance_scale > 1. and negative_text is not None
        if do_cfg:
            # classifier-free guidance: encode the positive and negative prompt together
            context = {'text': self.encode_text([text, negative_text])}
        else:
            context = {'text': self.encode_text([text])}
        x_hist = []
        timesteps = scheduler.timesteps
        for i, t in enumerate(timesteps):
            if do_cfg:
                # duplicate the latent so one forward pass covers both prompts
                model_input = torch.cat([x_t] * 2)
            else:
                model_input = x_t
            t_ = torch.full((model_input.shape[0],), t, device=x_t.device, dtype=torch.long)
            model_output = self.unet(model_input, t_, context)
            if do_cfg:
                model_output_positive, model_output_negative = model_output.chunk(2)
                model_output = model_output_negative + self.guidance_scale * (model_output_positive - model_output_negative)
            x_hist.append(
                self.decode_latent_to_image(self.predict_x_0_from_x_t(model_output, t_[:x_t.shape[0]], x_t))
            )
            x_t = scheduler.step(model_output, t, x_t).prev_sample
        return {
            'x_pred': self.decode_latent_to_image(x_t),
            'x_hist': torch.stack(x_hist, dim=1),
        }
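    # Classifier-free guidance (the do_cfg branch above) combines the conditional and
    # negative predictions as
    #   eps = eps_neg + guidance_scale * (eps_pos - eps_neg),
    # so guidance_scale == 1 reduces to the plain conditional prediction. A hedged
    # usage sketch, assuming a trained model instance `model`:
    #
    #   out = model.sampling(image_shape=(1, 4, 64, 64),
    #                        text='a cat on the tree', negative_text='')
    #   image = out['x_pred']  # decoded image tensor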