# meanaudio/model/networks.py
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from meanaudio.ext.rotary_embeddings import compute_rope_rotations
from meanaudio.model.embeddings import TimestepEmbedder
from meanaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from meanaudio.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
log = logging.getLogger()


@dataclass
class PreprocessedConditions:
text_f: torch.Tensor
    text_f_c: torch.Tensor


class FluxAudio(nn.Module):
    """Flux-style latent transformer for TTA with a single timestep embedding."""
def __init__(self,
*,
latent_dim: int,
text_dim: int,
text_c_dim: int,
hidden_dim: int,
depth: int,
fused_depth: int,
num_heads: int,
mlp_ratio: float = 4.0,
latent_seq_len: int,
text_seq_len: int = 77,
latent_mean: Optional[torch.Tensor] = None,
latent_std: Optional[torch.Tensor] = None,
empty_string_feat: Optional[torch.Tensor] = None,
empty_string_feat_c: Optional[torch.Tensor] = None,
use_rope: bool = False) -> None:
super().__init__()
self.latent_dim = latent_dim
self._latent_seq_len = latent_seq_len
self._text_seq_len = text_seq_len
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.use_rope = use_rope
self.mm_depth = depth - fused_depth
self.audio_input_proj = nn.Sequential(
ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
nn.SELU(),
ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
)
self.text_input_proj = nn.Sequential(
nn.Linear(text_dim, hidden_dim),
MLP(hidden_dim, hidden_dim * 4),
)
self.text_cond_proj = nn.Sequential(
nn.Linear(text_c_dim, hidden_dim),
MLP(hidden_dim, hidden_dim*4)
)
self.final_layer = FinalBlock(hidden_dim, latent_dim)
self.t_embed = TimestepEmbedder(hidden_dim,
frequency_embedding_size=256,
max_period=10000)
self.joint_blocks = nn.ModuleList([
JointBlock(hidden_dim,
num_heads,
mlp_ratio=mlp_ratio,
                       pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)
            # the last joint block is pre-only: its text stream is consumed but not
            # propagated further, since the fused blocks below operate on the latent only
])
self.fused_blocks = nn.ModuleList([
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for _ in range(fused_depth)
])
if latent_mean is None:
            # placeholder values that are not meant to be used;
            # if mean/std are not provided here, they are expected to be loaded later from a checkpoint
assert latent_std is None
latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
else:
assert latent_std is not None
assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}'
if empty_string_feat is None:
empty_string_feat = torch.zeros((text_seq_len, text_dim))
if empty_string_feat_c is None:
            empty_string_feat_c = torch.zeros(text_c_dim)
        assert empty_string_feat.shape[-1] == text_dim, f'{empty_string_feat.shape[-1]=} != {text_dim=}'
        assert empty_string_feat_c.shape[-1] == text_c_dim, f'{empty_string_feat_c.shape[-1]=} != {text_c_dim=}'
self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
self.empty_string_feat_c = nn.Parameter(empty_string_feat_c, requires_grad=False)
self.initialize_weights()
if self.use_rope:
log.info("Network: Enabling RoPE embeddings")
self.initialize_rotations()
else:
log.info("Network: RoPE embedding disabled")
self.latent_rot = None
self.text_rot = None
def initialize_rotations(self):
base_freq = 1.0
latent_rot = compute_rope_rotations(self._latent_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
text_rot = compute_rope_rotations(self._text_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
self.latent_rot = nn.Buffer(latent_rot, persistent=False) # will not be saved into state dict
self.text_rot = nn.Buffer(text_rot, persistent=False)
def update_seq_lengths(self, latent_seq_len: int) -> None:
self._latent_seq_len = latent_seq_len
if self.use_rope:
self.initialize_rotations() # after changing seq_len we need to re-initialize RoPE to match new seq_len
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.joint_blocks:
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) # the linear layer -> 6 coefficients
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
for block in self.fused_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.conv.weight, 0)
nn.init.constant_(self.final_layer.conv.bias, 0)
    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # in-place version of (x - self.latent_mean) / self.latent_std; mutates x
        return x.sub_(self.latent_mean).div_(self.latent_std)

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        # in-place version of x * self.latent_std + self.latent_mean; mutates x
        return x.mul_(self.latent_std).add_(self.latent_mean)
def preprocess_conditions(self, text_f: torch.Tensor, text_f_c: torch.Tensor) -> PreprocessedConditions:
"""
cache computations that do not depend on the latent/time step
i.e., the features are reused over steps during inference
"""
assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
bs = text_f.shape[0]
# get global and local text features
# NOTE here the order of projection has been changed so global and local features are projected seperately
text_f_c = self.text_cond_proj(text_f_c) # (B, D)
text_f = self.text_input_proj(text_f) # (B, VN, D)
return PreprocessedConditions(text_f=text_f,
text_f_c=text_f_c)
def predict_flow(self, latent: torch.Tensor, t: torch.Tensor,
conditions: PreprocessedConditions) -> torch.Tensor:
"""
for non-cacheable computations
"""
assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'
text_f = conditions.text_f
text_f_c = conditions.text_f_c
latent = self.audio_input_proj(latent) # (B, N, D)
global_c = self.t_embed(t).unsqueeze(1) + text_f_c.unsqueeze(1) # (B, 1, D)
        extended_c = global_c  # both streams currently share the same global conditioning
for block in self.joint_blocks:
latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot, self.text_rot) # (B, N, D)
for block in self.fused_blocks:
latent = block(latent, extended_c, self.latent_rot)
        flow = self.final_layer(latent, extended_c)  # (B, N, out_dim)
return flow
def forward(self, latent: torch.Tensor, text_f: torch.Tensor, text_f_c: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
latent: (B, N, C)
        text_f: (B, T, D)
        text_f_c: (B, D)
        t: (B,)
"""
        conditions = self.preprocess_conditions(text_f, text_f_c)  # cacheable operations
        flow = self.predict_flow(latent, t, conditions)  # non-cacheable operations
return flow
def get_empty_string_sequence(self, bs: int) -> tuple[torch.Tensor, torch.Tensor]:
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1), \
self.empty_string_feat_c.unsqueeze(0).expand(bs, -1) # (b, d)
def get_empty_conditions(
self,
bs: int,
*,
negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
if negative_text_features is not None:
empty_string_feat, empty_string_feat_c = negative_text_features
else:
empty_string_feat, empty_string_feat_c = self.get_empty_string_sequence(1)
conditions = self.preprocess_conditions(empty_string_feat,
empty_string_feat_c) # use encoder's empty features
if negative_text_features is None:
conditions.text_f = conditions.text_f.expand(bs, -1, -1)
conditions.text_f_c = conditions.text_f_c.expand(bs, -1)
return conditions
def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions,
empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor:
t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)
if cfg_strength < 1.0:
return self.predict_flow(latent, t, conditions)
else:
return (cfg_strength * self.predict_flow(latent, t, conditions) +
(1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions))
def load_weights(self, src_dict) -> None:
if 't_embed.freqs' in src_dict:
del src_dict['t_embed.freqs']
if 'latent_rot' in src_dict:
del src_dict['latent_rot']
if 'text_rot' in src_dict:
del src_dict['text_rot']
if 'empty_string_feat_c' not in src_dict.keys(): # FIXME: issue of version mismatch here
src_dict['empty_string_feat_c'] = src_dict['empty_string_feat'].mean(dim=0)
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return self.latent_mean.device
@property
def latent_seq_len(self) -> int:
return self._latent_seq_len
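

# A sampling sketch for FluxAudio (an illustration under assumptions, not code
# from the original repo): ode_wrapper is shaped to be driven by an external
# ODE solver, with the text conditions cached once via preprocess_conditions.
# The integration direction (t: 0 -> 1), the default step count, and the
# cfg_strength value below are all assumptions.
def euler_sample_fluxaudio(model: FluxAudio,
                           text_f: torch.Tensor,
                           text_f_c: torch.Tensor,
                           steps: int = 25,
                           cfg_strength: float = 4.5) -> torch.Tensor:
    """Fixed-step Euler integration of ode_wrapper; assumes the model's latent
    mean/std statistics have been loaded (otherwise unnormalize yields NaNs)."""
    bs = text_f.shape[0]
    cond = model.preprocess_conditions(text_f, text_f_c)  # cached across all steps
    empty = model.get_empty_conditions(bs)
    x = torch.randn(bs, model.latent_seq_len, model.latent_dim, device=model.device)
    for i in range(steps):
        t = torch.tensor(i / steps, device=model.device)
        x = x + (1.0 / steps) * model.ode_wrapper(t, x, cond, empty, cfg_strength)
    return model.unnormalize(x)

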
class MeanAudio(nn.Module):
    """Flux-style latent transformer for TTA with dual timestep embeddings (t and r)."""
def __init__(self,
*,
latent_dim: int,
text_dim: int,
text_c_dim: int,
hidden_dim: int,
depth: int,
fused_depth: int,
num_heads: int,
mlp_ratio: float = 4.0,
latent_seq_len: int,
text_seq_len: int = 77,
latent_mean: Optional[torch.Tensor] = None,
latent_std: Optional[torch.Tensor] = None,
empty_string_feat: Optional[torch.Tensor] = None,
empty_string_feat_c: Optional[torch.Tensor] = None,
use_rope: bool = False) -> None:
super().__init__()
self.latent_dim = latent_dim
self._latent_seq_len = latent_seq_len
self._text_seq_len = text_seq_len
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.use_rope = use_rope
self.mm_depth = depth - fused_depth
self.audio_input_proj = nn.Sequential(
ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
nn.SELU(),
ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
)
self.text_input_proj = nn.Sequential(
nn.Linear(text_dim, hidden_dim),
MLP(hidden_dim, hidden_dim * 4),
)
self.text_cond_proj = nn.Sequential(
nn.Linear(text_c_dim, hidden_dim),
MLP(hidden_dim, hidden_dim*4)
)
self.final_layer = FinalBlock(hidden_dim, latent_dim)
self.t_embed = TimestepEmbedder(hidden_dim,
frequency_embedding_size=256,
max_period=10000)
        # second timestep embedder for the auxiliary time r (used with r <= t)
self.r_embed = TimestepEmbedder(hidden_dim,
frequency_embedding_size=256,
max_period=10000)
self.joint_blocks = nn.ModuleList([
JointBlock(hidden_dim,
num_heads,
mlp_ratio=mlp_ratio,
                       pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)
            # the last joint block is pre-only: its text stream is consumed but not
            # propagated further, since the fused blocks below operate on the latent only
])
self.fused_blocks = nn.ModuleList([
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for _ in range(fused_depth)
])
if latent_mean is None:
            # placeholder values that are not meant to be used;
            # if mean/std are not provided here, they are expected to be loaded later from a checkpoint
assert latent_std is None
latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
else:
assert latent_std is not None
assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}'
if empty_string_feat is None:
empty_string_feat = torch.zeros((text_seq_len, text_dim))
if empty_string_feat_c is None:
            empty_string_feat_c = torch.zeros(text_c_dim)
        assert empty_string_feat.shape[-1] == text_dim, f'{empty_string_feat.shape[-1]=} != {text_dim=}'
        assert empty_string_feat_c.shape[-1] == text_c_dim, f'{empty_string_feat_c.shape[-1]=} != {text_c_dim=}'
self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) # (1, 1, d)
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
self.empty_string_feat_c = nn.Parameter(empty_string_feat_c, requires_grad=False)
self.initialize_weights()
if self.use_rope:
log.info("Network: Enabling RoPE embeddings")
self.initialize_rotations()
else:
log.info("Network: RoPE embedding disabled")
self.latent_rot = None
self.text_rot = None
def initialize_rotations(self):
base_freq = 1.0
latent_rot = compute_rope_rotations(self._latent_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
text_rot = compute_rope_rotations(self._text_seq_len,
self.hidden_dim // self.num_heads,
10000,
freq_scaling=base_freq,
device=self.device)
self.latent_rot = nn.Buffer(latent_rot, persistent=False) # will not be saved into state dict
self.text_rot = nn.Buffer(text_rot, persistent=False)
def update_seq_lengths(self, latent_seq_len: int) -> None:
self._latent_seq_len = latent_seq_len
if self.use_rope:
self.initialize_rotations() # after changing seq_len we need to re-initialize RoPE to match new seq_len
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
        # NOTE: r_embed.mlp is left on the xavier init applied by _basic_init above
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.joint_blocks:
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) # the linear layer -> 6 coefficients
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
for block in self.fused_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.conv.weight, 0)
nn.init.constant_(self.final_layer.conv.bias, 0)
    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # in-place version of (x - self.latent_mean) / self.latent_std; mutates x
        return x.sub_(self.latent_mean).div_(self.latent_std)

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        # in-place version of x * self.latent_std + self.latent_mean; mutates x
        return x.mul_(self.latent_std).add_(self.latent_mean)
def preprocess_conditions(self, text_f: torch.Tensor, text_f_c: torch.Tensor) -> PreprocessedConditions:
"""
cache computations that do not depend on the latent/time step
i.e., the features are reused over steps during inference
"""
assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
bs = text_f.shape[0]
# get global and local text features
# NOTE here the order of projection has been changed so global and local features are projected seperately
text_f_c = self.text_cond_proj(text_f_c) # (B, D)
text_f = self.text_input_proj(text_f) # (B, VN, D)
return PreprocessedConditions(text_f=text_f,
text_f_c=text_f_c)
    def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, r: torch.Tensor,
                     conditions: PreprocessedConditions) -> torch.Tensor:
        """
        for non-cacheable computations; expects r <= t elementwise
        """
        # assert (r <= t).all(), 'r should be smaller than or equal to t'  # kept disabled
        assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'
text_f = conditions.text_f
text_f_c = conditions.text_f_c
latent = self.audio_input_proj(latent) # (B, N, D)
        # t and r currently share the same embedding scheme; their embeddings are summed
        global_c = self.t_embed(t).unsqueeze(1) + self.r_embed(r).unsqueeze(1) + text_f_c.unsqueeze(1)  # (B, 1, D)
        extended_c = global_c  # !TODO add fine-grained control
for block in self.joint_blocks:
latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot, self.text_rot) # (B, N, D)
for block in self.fused_blocks:
latent = block(latent, extended_c, self.latent_rot)
        flow = self.final_layer(latent, extended_c)  # (B, N, out_dim)
return flow
    def forward(self, latent: torch.Tensor, text_f: torch.Tensor, text_f_c: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        latent: (B, N, C)
        text_f: (B, T, D)
        text_f_c: (B, D)
        r: (B,)
        t: (B,)
        """
        conditions = self.preprocess_conditions(text_f, text_f_c)  # cacheable operations
        flow = self.predict_flow(latent, t, r, conditions)  # non-cacheable operations
return flow
def get_empty_string_sequence(self, bs: int) -> tuple[torch.Tensor, torch.Tensor]:
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1), \
self.empty_string_feat_c.unsqueeze(0).expand(bs, -1) # (b, d)
def get_empty_conditions(
self,
bs: int,
*,
negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
if negative_text_features is not None:
empty_string_feat, empty_string_feat_c = negative_text_features
else:
empty_string_feat, empty_string_feat_c = self.get_empty_string_sequence(1)
conditions = self.preprocess_conditions(empty_string_feat,
empty_string_feat_c) # use encoder's empty features
if negative_text_features is None:
conditions.text_f = conditions.text_f.expand(bs, -1, -1)
conditions.text_f_c = conditions.text_f_c.expand(bs, -1)
return conditions
    def ode_wrapper(self, t: torch.Tensor, r: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions,
                    empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor:
        t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)
        r = r * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)
        if cfg_strength < 1.0:
            return self.predict_flow(latent, t, r, conditions)
        else:
            return (cfg_strength * self.predict_flow(latent, t, r, conditions) +
                    (1 - cfg_strength) * self.predict_flow(latent, t, r, empty_conditions))
    def load_weights(self, src_dict) -> None:
        def remove_prefix(state_dict):
            from collections import OrderedDict
            # strip the 'ema_model.' prefix left by EMA checkpoints
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_state_dict[k.replace('ema_model.', '')] = v
            return new_state_dict
        src_dict = remove_prefix(src_dict)
if 't_embed.freqs' in src_dict:
del src_dict['t_embed.freqs']
if 'r_embed.freqs' in src_dict:
del src_dict['r_embed.freqs']
if 'latent_rot' in src_dict:
del src_dict['latent_rot']
if 'text_rot' in src_dict:
del src_dict['text_rot']
if 'empty_string_feat_c' not in src_dict.keys(): # FIXME: issue of version mismatch here
src_dict['empty_string_feat_c'] = src_dict['empty_string_feat'].mean(dim=0)
if '_extra_state' in src_dict:
del src_dict['_extra_state']
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return self.latent_mean.device
@property
def latent_seq_len(self) -> int:
return self._latent_seq_len
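

# A sampling sketch for MeanAudio (an assumption-laden illustration, not code
# from the original repo): assuming the MeanFlow convention in which the
# network predicts the average velocity u over [r, t] and noise sits at t=1,
# the update is z_r = z_t - (t - r) * u(z_t, r, t), and num_steps=1 gives
# one-step generation. The time convention and cfg_strength default are assumptions.
def meanflow_sample(model: MeanAudio,
                    text_f: torch.Tensor,
                    text_f_c: torch.Tensor,
                    num_steps: int = 1,
                    cfg_strength: float = 4.5) -> torch.Tensor:
    """Steps from t=1 (noise) down to t=0 (data); assumes the latent mean/std
    statistics have been loaded, otherwise unnormalize yields NaNs."""
    bs = text_f.shape[0]
    cond = model.preprocess_conditions(text_f, text_f_c)  # cached across steps
    empty = model.get_empty_conditions(bs)
    z = torch.randn(bs, model.latent_seq_len, model.latent_dim, device=model.device)
    ts = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):  # r = t_next <= t = t_cur
        u = model.ode_wrapper(t_cur, t_next, z, cond, empty, cfg_strength)
        z = z - (t_cur - t_next) * u
    return model.unnormalize(z)

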
def fluxaudio_s(**kwargs) -> FluxAudio:
num_heads = 7
return FluxAudio(latent_dim=20,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=312, # for 10s audio
**kwargs)


def meanaudio_s(**kwargs) -> MeanAudio:
num_heads = 7
return MeanAudio(latent_dim=20,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=312, # for 10s audio
**kwargs)


def meanaudio_l(**kwargs) -> MeanAudio:
num_heads = 14
return MeanAudio(latent_dim=20,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=312, # for 10s audio
**kwargs)


def get_mean_audio(name: str, **kwargs) -> MeanAudio:
if name == 'meanaudio_s_ac' or name == 'meanaudio_s_full':
return meanaudio_s(**kwargs)
elif name == 'meanaudio_l_full':
return meanaudio_l(**kwargs)
elif name == 'fluxaudio_s_full':
return fluxaudio_s(**kwargs)
else:
raise ValueError(f'Unknown model name: {name}')


if __name__ == '__main__':
from meanaudio.model.utils.sample_utils import log_normal_sample
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
# logging.FileHandler("main.log"),
logging.StreamHandler()
]
)
    network: MeanAudio = get_mean_audio('meanaudio_s_full',  # 'meanaudio_s' is not a registered name
use_rope=False,
text_c_dim=512)
x = torch.randn(256, 312, 20)
print(x.shape)
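    # A hedged smoke test (not in the original script): run one forward pass with
    # random text features. text_dim=1024 and text_seq_len=77 come from meanaudio_s's
    # defaults, text_c_dim=512 from the constructor call above, and r is drawn
    # below t since MeanAudio expects r <= t. The original imports log_normal_sample
    # but never calls it; as its signature is not shown here, t is drawn uniformly instead.
    bs = 2
    latent = torch.randn(bs, network.latent_seq_len, 20)
    text_f = torch.randn(bs, 77, 1024)
    text_f_c = torch.randn(bs, 512)
    t = torch.rand(bs)
    r = torch.rand(bs) * t  # ensures r <= t elementwise
    with torch.no_grad():
        flow = network(latent, text_f, text_f_c, r, t)
    print(flow.shape)  # expected: torch.Size([2, 312, 20])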
print('Finish')