# File: MeanAudio / meanaudio/model/mean_flow.py
# Last commit: 3a1da90 by junxiliu, "add needed model with proper LFS tracking"
import logging
from typing import Callable, Optional
import torch
from torchdiffeq import odeint
import torch.nn as nn
log = logging.getLogger()
import torch
import torch.nn.functional as F
from einops import rearrange
from functools import partial
import numpy as np
import math
def normalize_to_neg1_1(x):
    """Affinely map values from the [0, 1] range onto [-1, 1]."""
    doubled = x * 2
    return doubled - 1
def unnormalize_to_0_1(x):
    """Affinely map values from the [-1, 1] range back onto [0, 1]."""
    shifted = x + 1
    return shifted * 0.5
def stopgrad(x):
    """Return ``x`` cut off from the autograd graph (stop-gradient)."""
    detached = x.detach()
    return detached
def adaptive_l2_loss(error, gamma=0, c=1e-3):
    """
    Adaptive L2 loss: sg(w) * ||Δ||^2, where w = 1 / (||Δ||^2 + c)^p, p = 1 - γ.

    ``||Δ||^2`` is the mean squared error of each sample, taken over every
    non-batch dimension, so inputs such as (B, T, C) or (B, C, W, H) are both
    supported. The weight ``w`` is detached, so it rescales the gradient but
    contributes no gradient of its own.

    Args:
        error: Tensor whose first dimension is the batch dimension.
        gamma: Power used in the original ||Δ||^{2γ} loss.
        c: Small constant for stability.
    Returns:
        Tensor of shape (B,) holding one loss value per sample; no reduction
        over the batch is performed here.
    """
    delta_sq = error.pow(2)
    # Reduce over all non-batch dims. Skipped for 1-D input because an empty
    # ``dim`` tuple would make torch.mean reduce over the batch as well.
    reduce_dims = tuple(range(1, error.dim()))
    if reduce_dims:
        delta_sq = delta_sq.mean(dim=reduce_dims)
    p = 1.0 - gamma
    w = 1.0 / (delta_sq + c).pow(p)
    # Stop-gradient on the weight: only ||Δ||^2 is differentiated.
    return w.detach() * delta_sq
def cosine_annealing(start, end, step, total_steps):
    """Interpolate from ``start`` to ``end`` along a half-cosine curve.

    Returns ``start`` at ``step == 0`` and ``end`` at ``step == total_steps``.
    """
    phase = math.pi * step / total_steps
    half_span = 0.5 * (start - end)
    return end + half_span * (1 + math.cos(phase))
# Partially adapted from https://github.com/haidog-yaqub/MeanFlow
class MeanFlow():
    """Mean-flow training objective and few-step sampler.

    ``loss`` builds the regression target ``v_hat - (t - r) * du/dt`` using a
    Jacobian-vector product of the network, and ``run_t0_to_t1`` integrates
    the predicted average velocity from time 1 down to time 0 in
    ``steps`` equal steps.

    Args:
        steps: Number of sampling steps used by ``run_t0_to_t1``.
        flow_ratio: Fraction of each training batch for which ``r`` is set
            equal to ``t``.
        time_dist: ``('uniform',)`` or ``('lognorm', mu, sigma)``; the
            distribution from which the two time values are drawn.
        w, k: Mixing coefficients of the guided target
            ``w * v + k * u_cond + (1 - w - k) * u_uncond``. Passing
            ``w=None`` disables the mixing and uses ``v`` directly.
        cfg_uncond: Stored on the instance; not read by the code in this class.
        jvp_api: ``'autograd'`` (``torch.autograd.functional.jvp`` with
            ``create_graph=True``) or ``'funtorch'`` / ``'functorch'``
            (``torch.func.jvp``).

    Raises:
        ValueError: If ``jvp_api`` is not one of the supported names.
    """

    def __init__(
        self,
        steps=1,
        flow_ratio=0.75,
        time_dist=('lognorm', -0.4, 1.0),
        w=0.3,
        k=0.9,
        cfg_uncond='u',
        jvp_api='autograd',
    ):
        super().__init__()

        self.flow_ratio = flow_ratio
        # Copy so the instance never shares state with the caller's sequence.
        self.time_dist = list(time_dist)
        self.w = w
        self.k = k
        self.steps = steps

        self.cfg_uncond = cfg_uncond
        self.jvp_api = jvp_api

        # Explicit raise instead of assert: asserts vanish under ``python -O``,
        # which would leave ``jvp_fn`` undefined for a bad value.
        # 'funtorch' is the historical spelling; 'functorch' is also accepted.
        if jvp_api in ('funtorch', 'functorch'):
            self.jvp_fn = torch.func.jvp
            self.create_graph = False
        elif jvp_api == 'autograd':
            self.jvp_fn = torch.autograd.functional.jvp
            self.create_graph = True
        else:
            raise ValueError(
                f"jvp_api must be 'funtorch' or 'autograd', got {jvp_api!r}")

        log.info(f'MeanFlow initialized with {steps} steps')

    def sample_t_r(self, batch_size, device):
        """Draw a pair of times ``(t, r)`` with ``r <= t`` for each sample.

        Two values per sample are drawn from the configured distribution;
        the larger becomes ``t`` and the smaller ``r``. For a random
        ``int(flow_ratio * batch_size)`` samples, ``r`` is then set to ``t``.

        Returns:
            Tuple ``(t, r)`` of 1-D tensors of length ``batch_size`` on
            ``device``.

        Raises:
            ValueError: If ``time_dist[0]`` is not 'uniform' or 'lognorm'.
        """
        dist_name = self.time_dist[0]
        if dist_name == 'uniform':
            samples = np.random.rand(batch_size, 2).astype(np.float32)
        elif dist_name == 'lognorm':
            mu, sigma = self.time_dist[-2], self.time_dist[-1]
            normal_samples = np.random.randn(batch_size, 2).astype(np.float32) * sigma + mu
            # Sigmoid squashes the Gaussian draws into (0, 1).
            samples = 1 / (1 + np.exp(-normal_samples))
        else:
            raise ValueError(
                f"time_dist[0] must be 'uniform' or 'lognorm', got {dist_name!r}")

        t_np = np.maximum(samples[:, 0], samples[:, 1])
        r_np = np.minimum(samples[:, 0], samples[:, 1])

        # we don't use self.flow ratio if we use scheduler
        # !TODO: implement flow ratio scheduler
        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        r_np[indices] = t_np[indices]

        t = torch.tensor(t_np, device=device)
        r = torch.tensor(r_np, device=device)
        return t, r

    def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
        """Run ``run_t0_to_t1`` on ``x1``.

        This performs the same integration as ``to_data`` but, unlike
        ``to_data``, is not wrapped in ``torch.no_grad()``.
        """
        return self.run_t0_to_t1(fn, x1)

    @torch.no_grad()
    def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        """Run ``run_t0_to_t1`` on ``x0`` with gradient tracking disabled."""
        return self.run_t0_to_t1(fn, x0)

    def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        """Integrate from time 1 to time 0 in ``self.steps`` equal steps.

        At each step ``fn(t=..., r=..., x=...)`` is called with the current
        time ``t`` and the next time ``r`` (both expanded to the batch size),
        and the state is updated as ``x <- x - (t - r) * fn(...)``.

        Args:
            fn: Callable accepting keyword arguments ``t``, ``r`` and ``x``.
            x0: Starting state; its first dimension is the batch dimension.

        Returns:
            The state after the final step (``x0`` unchanged if
            ``self.steps`` is 0).
        """
        batch = x0.shape[0]
        steps = torch.linspace(1, 0, self.steps + 1).to(device=x0.device, dtype=x0.dtype)

        for ti, cur in enumerate(steps[:-1]):
            t = cur.expand(batch)
            next_t = steps[ti + 1].expand(batch)
            u_flow = fn(t=t, r=next_t, x=x0)
            dt = (t - next_t).mean()
            x0 = x0 - dt * u_flow

        return x0

    def loss(self,
             fn: Callable,
             x0: torch.Tensor,
             text_f: torch.Tensor,
             text_f_c: torch.Tensor,
             text_f_undrop: torch.Tensor,
             text_f_c_undrop: torch.Tensor,
             empty_string_feat: torch.Tensor,
             empty_string_feat_c: torch.Tensor):
        """Compute the per-sample mean-flow training loss.

        Args:
            fn: Network called as ``fn(latent=, text_f=, text_f_c=, r=, t=)``.
                A ``DistributedDataParallel`` wrapper is unwrapped to its
                ``.module`` before use.
            x0: Data latents. NOTE(review): the ``b 1 1`` broadcast below
                assumes a 3-D tensor with batch first — confirm with caller.
            text_f, text_f_c: Conditioning passed to the differentiated
                forward pass.
            text_f_undrop, text_f_c_undrop: Conditioning passed to the
                detached conditional reference pass.
            empty_string_feat, empty_string_feat_c: Conditioning for the
                detached unconditional reference pass; expanded to the batch
                size here.

        Returns:
            Tuple ``(loss, r, t)`` where ``loss`` is whatever
            ``adaptive_l2_loss`` returns for the error tensor and ``r``,
            ``t`` are the sampled times.
        """
        if isinstance(fn, torch.nn.parallel.DistributedDataParallel):
            fn = fn.module
        batch_size = x0.shape[0]
        device = x0.device

        e = torch.randn_like(x0)
        t, r = self.sample_t_r(batch_size, device)

        t_ = rearrange(t, "b -> b 1 1 ")
        r_ = rearrange(r, "b -> b 1 1 ")

        # Linear interpolation between data (t = 0) and noise (t = 1).
        z = (1 - t_) * x0 + t_ * e  # r < t
        v = e - x0

        if self.w is not None:
            u_text_f = empty_string_feat.expand(batch_size, -1, -1)
            u_text_f_c = empty_string_feat_c.expand(batch_size, -1)
            # Reference predictions at r == t; detached so they act as
            # constants in the target.
            u_t = fn(latent=z,
                     text_f=u_text_f,
                     text_f_c=u_text_f_c,
                     r=t,
                     t=t).detach().requires_grad_(False)
            u_t_c = fn(latent=z,
                       text_f=text_f_undrop,
                       text_f_c=text_f_c_undrop,
                       r=t,
                       t=t).detach().requires_grad_(False)
            v_hat = self.w * v + self.k * u_t_c + (1 - self.w - self.k) * u_t
        else:
            v_hat = v

        model_partial = partial(fn, text_f=text_f, text_f_c=text_f_c)
        # JVP along the tangent (dz, dr, dt) = (v_hat, 0, 1): returns the
        # network output u and its directional derivative along that tangent.
        jvp_args = (
            lambda z_f, r_f, t_f: model_partial(latent=z_f, r=r_f, t=t_f),
            (z, r, t),
            (v_hat, torch.zeros_like(r), torch.ones_like(t)),
        )

        if self.create_graph:
            u, dudt = self.jvp_fn(*jvp_args, create_graph=True)
        else:
            u, dudt = self.jvp_fn(*jvp_args)

        u_tgt = v_hat - (t_ - r_) * dudt

        # The target is treated as a constant; gradients flow through u only.
        error = u - stopgrad(u_tgt)
        loss = adaptive_l2_loss(error)
        return loss, r, t
# Script entry guard: running this module directly does nothing.
if __name__ == '__main__':
    pass