import logging
import math
from functools import partial
from typing import Callable

import numpy as np
import torch
from einops import rearrange

log = logging.getLogger()


def normalize_to_neg1_1(x):
    return x * 2 - 1


def unnormalize_to_0_1(x):
    return (x + 1) * 0.5


def stopgrad(x):
    return x.detach()


def adaptive_l2_loss(error, gamma=0, c=1e-3):
    """
    Adaptive L2 loss: sg(w) * ||Δ||_2^2, where w = 1 / (||Δ||^2 + c)^p, p = 1 - γ

    Args:
        error: Tensor of shape (B, L, C)
        gamma: Power used in the original ||Δ||^{2γ} loss
        c: Small constant for numerical stability

    Returns:
        Scalar loss
    """
    # Per-sample squared error, averaged over the non-batch dimensions.
    delta_sq = torch.mean(error ** 2, dim=(1, 2), keepdim=False)
    p = 1.0 - gamma
    # Adaptive weight; detached so it rescales the loss without receiving gradients.
    w = 1.0 / (delta_sq + c).pow(p)
    loss = delta_sq
    return (stopgrad(w) * loss).mean()
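

# Example (shapes are illustrative): for `error` of shape (B, L, C), `delta_sq`
# has shape (B,); with gamma=0 each sample is re-weighted by the detached factor
# 1 / (delta_sq + c) before the batch mean is taken.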


def cosine_annealing(start, end, step, total_steps):
    cos_inner = math.pi * step / total_steps
    return end + 0.5 * (start - end) * (1 + math.cos(cos_inner))
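

# Example: cosine_annealing(1e-3, 1e-5, step=0, total_steps=1000) returns 1e-3 and
# step=1000 returns 1e-5; values in between follow a half-cosine decay.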


class MeanFlow:
    def __init__(
        self,
        steps=1,
        flow_ratio=0.75,
        time_dist=['lognorm', -0.4, 1.0],
        w=0.3,
        k=0.9,
        cfg_uncond='u',
        jvp_api='autograd',
    ):
        super().__init__()
        self.flow_ratio = flow_ratio
        self.time_dist = time_dist
        # w and k are the mixing weights used to build the guided target velocity
        # (classifier-free guidance) in `loss`.
        self.w = w
        self.k = k
        self.steps = steps

        self.cfg_uncond = cfg_uncond
        self.jvp_api = jvp_api
        assert jvp_api in ['funtorch', 'autograd'], "jvp_api must be 'funtorch' or 'autograd'"
        if jvp_api == 'funtorch':
            # torch.func.jvp takes no create_graph argument.
            self.jvp_fn = torch.func.jvp
            self.create_graph = False
        elif jvp_api == 'autograd':
            # torch.autograd.functional.jvp needs create_graph=True so the primal
            # output stays differentiable for training.
            self.jvp_fn = torch.autograd.functional.jvp
            self.create_graph = True
        log.info(f'MeanFlow initialized with {steps} steps')

    def sample_t_r(self, batch_size, device):
        if self.time_dist[0] == 'uniform':
            samples = np.random.rand(batch_size, 2).astype(np.float32)
        elif self.time_dist[0] == 'lognorm':
            # Logit-normal time distribution: sigmoid of a Gaussian sample.
            mu, sigma = self.time_dist[-2], self.time_dist[-1]
            normal_samples = np.random.randn(batch_size, 2).astype(np.float32) * sigma + mu
            samples = 1 / (1 + np.exp(-normal_samples))
        else:
            raise ValueError(f"Unknown time_dist: {self.time_dist[0]}")

        # Order each pair so that t >= r.
        t_np = np.maximum(samples[:, 0], samples[:, 1])
        r_np = np.minimum(samples[:, 0], samples[:, 1])

        # For a flow_ratio fraction of the batch, collapse r to t so those samples
        # train on the instantaneous (flow-matching) target.
        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        r_np[indices] = t_np[indices]

        t = torch.tensor(t_np, device=device)
        r = torch.tensor(r_np, device=device)
        return t, r
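
    # Example (illustrative): with flow_ratio=0.75 and batch_size=8, exactly
    # int(0.75 * 8) = 6 of the (t, r) pairs are collapsed to r == t; for those the
    # MeanFlow target in `loss` reduces to plain (guided) flow matching.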

    def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
        return self.run_t0_to_t1(fn, x1)

    @torch.no_grad()
    def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        return self.run_t0_to_t1(fn, x0)

    def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        # Integrate from t = 1 (prior) down to t = 0 (data) in `self.steps` uniform
        # steps, using the predicted average velocity over each interval.
        steps = torch.linspace(1, 0, self.steps + 1).to(device=x0.device, dtype=x0.dtype)
        for ti, t in enumerate(steps[:-1]):
            t = t.expand(x0.shape[0])
            next_t = steps[ti + 1].expand(x0.shape[0])
            u_flow = fn(t=t, r=next_t, x=x0)
            dt = (t - next_t).mean()
            x0 = x0 - dt * u_flow
        return x0
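
    # Note: with steps=1 this is a single update from t = 1 to t = 0, i.e. the
    # one-step sample x = e - u(e, r=0, t=1) from pure noise e.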

    def loss(self,
             fn: Callable,
             x0: torch.Tensor,
             text_f: torch.Tensor,
             text_f_c: torch.Tensor,
             text_f_undrop: torch.Tensor,
             text_f_c_undrop: torch.Tensor,
             empty_string_feat: torch.Tensor,
             empty_string_feat_c: torch.Tensor):
        if isinstance(fn, torch.nn.parallel.DistributedDataParallel):
            fn = fn.module
        batch_size = x0.shape[0]
        device = x0.device
        e = torch.randn_like(x0)
        t, r = self.sample_t_r(batch_size, device)
        t_ = rearrange(t, "b -> b 1 1")
        r_ = rearrange(r, "b -> b 1 1")
        # Linear interpolant between data x0 (t = 0) and noise e (t = 1), and its
        # instantaneous velocity v.
        z = (1 - t_) * x0 + t_ * e
        v = e - x0

        if self.w is not None:
            # Build the guided target velocity from unconditional (empty text) and
            # conditional (undropped text) predictions; no gradients are needed here.
            with torch.no_grad():
                u_text_f = empty_string_feat.expand(batch_size, -1, -1)
                u_text_f_c = empty_string_feat_c.expand(batch_size, -1)
                u_t = fn(latent=z,
                         text_f=u_text_f,
                         text_f_c=u_text_f_c,
                         r=t,
                         t=t)
                u_t_c = fn(latent=z,
                           text_f=text_f_undrop,
                           text_f_c=text_f_c_undrop,
                           r=t,
                           t=t)
            v_hat = self.w * v + self.k * u_t_c + (1 - self.w - self.k) * u_t
        else:
            v_hat = v

        # JVP of the model along the tangent (dz/dt, dr/dt, dt/dt) = (v_hat, 0, 1)
        # gives the total derivative du/dt used in the MeanFlow identity.
        model_partial = partial(fn, text_f=text_f, text_f_c=text_f_c)
        jvp_args = (
            lambda z_f, r_f, t_f: model_partial(latent=z_f, r=r_f, t=t_f),
            (z, r, t),
            (v_hat, torch.zeros_like(r), torch.ones_like(t)),
        )
        if self.create_graph:
            u, dudt = self.jvp_fn(*jvp_args, create_graph=True)
        else:
            u, dudt = self.jvp_fn(*jvp_args)
        # MeanFlow target u_tgt = v_hat - (t - r) * du/dt, regressed with a stop-gradient.
        u_tgt = v_hat - (t_ - r_) * dudt
        error = u - stopgrad(u_tgt)
        loss = adaptive_l2_loss(error)
        return loss, r, t


if __name__ == '__main__':
    pass
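
    # Minimal smoke test (a sketch, not the project's training or sampling code):
    # `dummy_velocity` is a stand-in that matches the keyword interface
    # fn(t=..., r=..., x=...) expected by run_t0_to_t1; the shapes and step count
    # below are illustrative assumptions, not project defaults.
    def dummy_velocity(t, r, x):
        # A trivial "average velocity" field that pushes the state toward zero.
        return -x

    mf = MeanFlow(steps=4)
    noise = torch.randn(2, 16, 8)  # assumed (batch, length, channels) latent layout
    sample = mf.to_data(dummy_velocity, noise)
    print(sample.shape)  # torch.Size([2, 16, 8])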