|
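"""Flow-prediction transformer networks for MeanAudio.

Defines FluxAudio (conditioned on a single timestep t) and MeanAudio (a
two-timestep variant with an extra embedder for r), together with small and
large presets and a name-based factory, get_mean_audio().
"""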
import logging |
|
from dataclasses import dataclass |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import sys |
|
from pathlib import Path |
|
sys.path.append(str(Path(__file__).parent.parent.parent)) |
|
|
|
from meanaudio.ext.rotary_embeddings import compute_rope_rotations |
|
from meanaudio.model.embeddings import TimestepEmbedder |
|
from meanaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP |
|
from meanaudio.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) |
|
|
|
log = logging.getLogger() |
|
|
|
|
|
@dataclass |
|
class PreprocessedConditions: |
|
text_f: torch.Tensor |
|
text_f_c: torch.Tensor |
|
|
|
|
|
class FluxAudio(nn.Module): |
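    """Flow-prediction transformer over audio latents.

    Text and latent tokens pass through joint (MMDiT-style) blocks, then the
    latent stream alone through fused single-stream blocks; global
    conditioning is the timestep embedding plus a pooled text feature.
    """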
|
|
|
|
|
def __init__(self, |
|
*, |
|
latent_dim: int, |
|
text_dim: int, |
|
text_c_dim: int, |
|
hidden_dim: int, |
|
depth: int, |
|
fused_depth: int, |
|
num_heads: int, |
|
mlp_ratio: float = 4.0, |
|
latent_seq_len: int, |
|
text_seq_len: int = 77, |
|
latent_mean: Optional[torch.Tensor] = None, |
|
latent_std: Optional[torch.Tensor] = None, |
|
empty_string_feat: Optional[torch.Tensor] = None, |
|
empty_string_feat_c: Optional[torch.Tensor] = None, |
|
use_rope: bool = False) -> None: |
|
super().__init__() |
|
|
|
self.latent_dim = latent_dim |
|
self._latent_seq_len = latent_seq_len |
|
self._text_seq_len = text_seq_len |
|
self.hidden_dim = hidden_dim |
|
self.num_heads = num_heads |
|
self.use_rope = use_rope |
|
self.mm_depth = depth - fused_depth |
|
|
|
self.audio_input_proj = nn.Sequential( |
|
ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), |
|
nn.SELU(), |
|
ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), |
|
) |
|
|
|
self.text_input_proj = nn.Sequential( |
|
nn.Linear(text_dim, hidden_dim), |
|
MLP(hidden_dim, hidden_dim * 4), |
|
) |
|
|
|
        self.text_cond_proj = nn.Sequential(
            nn.Linear(text_c_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )
|
|
|
self.final_layer = FinalBlock(hidden_dim, latent_dim) |
|
|
|
self.t_embed = TimestepEmbedder(hidden_dim, |
|
frequency_embedding_size=256, |
|
max_period=10000) |
|
|
|
self.joint_blocks = nn.ModuleList([ |
|
JointBlock(hidden_dim, |
|
num_heads, |
|
mlp_ratio=mlp_ratio, |
|
pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) |
|
]) |
|
|
|
self.fused_blocks = nn.ModuleList([ |
|
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1) |
|
for i in range(fused_depth) |
|
]) |
|
|
|
if latent_mean is None: |
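            # Placeholder statistics: NaN-filled so normalize()/unnormalize()
            # propagate NaNs until real latent statistics are loaded.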
|
|
|
|
|
assert latent_std is None |
|
latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) |
|
latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) |
|
else: |
|
assert latent_std is not None |
|
assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}' |
|
|
|
if empty_string_feat is None: |
|
empty_string_feat = torch.zeros((text_seq_len, text_dim)) |
|
        if empty_string_feat_c is None:
            empty_string_feat_c = torch.zeros(text_c_dim)
|
|
|
        assert empty_string_feat.shape[-1] == text_dim, f'{empty_string_feat.shape[-1]=} != {text_dim=}'
        assert empty_string_feat_c.shape[-1] == text_c_dim, f'{empty_string_feat_c.shape[-1]=} != {text_c_dim=}'
|
|
|
self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) |
|
self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) |
|
|
|
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) |
|
self.empty_string_feat_c = nn.Parameter(empty_string_feat_c, requires_grad=False) |
|
|
|
|
|
self.initialize_weights() |
|
if self.use_rope: |
|
log.info("Network: Enabling RoPE embeddings") |
|
self.initialize_rotations() |
|
else: |
|
log.info("Network: RoPE embedding disabled") |
|
self.latent_rot = None |
|
self.text_rot = None |
|
|
|
def initialize_rotations(self): |
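        # Precompute RoPE rotation tables for the current sequence lengths;
        # the per-head dimension is hidden_dim // num_heads and the base
        # period (10000) matches the timestep embedder.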
|
base_freq = 1.0 |
|
latent_rot = compute_rope_rotations(self._latent_seq_len, |
|
self.hidden_dim // self.num_heads, |
|
10000, |
|
freq_scaling=base_freq, |
|
device=self.device) |
|
text_rot = compute_rope_rotations(self._text_seq_len, |
|
self.hidden_dim // self.num_heads, |
|
10000, |
|
freq_scaling=base_freq, |
|
device=self.device) |
|
|
|
self.latent_rot = nn.Buffer(latent_rot, persistent=False) |
|
self.text_rot = nn.Buffer(text_rot, persistent=False) |
|
|
|
def update_seq_lengths(self, latent_seq_len: int) -> None: |
|
self._latent_seq_len = latent_seq_len |
|
if self.use_rope: |
|
self.initialize_rotations() |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
torch.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02) |
|
|
|
|
|
for block in self.joint_blocks: |
|
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) |
|
for block in self.fused_blocks: |
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(self.final_layer.conv.weight, 0) |
|
nn.init.constant_(self.final_layer.conv.bias, 0) |
|
|
|
def normalize(self, x: torch.Tensor) -> torch.Tensor: |
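        # In-place: mutates and returns the caller's tensor.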
|
|
|
return x.sub_(self.latent_mean).div_(self.latent_std) |
|
|
|
def unnormalize(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
return x.mul_(self.latent_std).add_(self.latent_mean) |
|
|
|
def preprocess_conditions(self, text_f: torch.Tensor, text_f_c: torch.Tensor) -> PreprocessedConditions: |
|
""" |
|
cache computations that do not depend on the latent/time step |
|
i.e., the features are reused over steps during inference |
|
""" |
|
assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' |
|
|
|
bs = text_f.shape[0] |
|
|
|
|
|
|
|
text_f_c = self.text_cond_proj(text_f_c) |
|
text_f = self.text_input_proj(text_f) |
|
|
|
return PreprocessedConditions(text_f=text_f, |
|
text_f_c=text_f_c) |
|
|
|
def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, |
|
conditions: PreprocessedConditions) -> torch.Tensor: |
|
""" |
|
for non-cacheable computations |
|
""" |
|
assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' |
|
|
|
text_f = conditions.text_f |
|
text_f_c = conditions.text_f_c |
|
|
|
latent = self.audio_input_proj(latent) |
|
|
|
global_c = self.t_embed(t).unsqueeze(1) + text_f_c.unsqueeze(1) |
|
|
|
extended_c = global_c |
|
|
|
for block in self.joint_blocks: |
|
latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot, self.text_rot) |
|
|
|
for block in self.fused_blocks: |
|
latent = block(latent, extended_c, self.latent_rot) |
|
|
|
flow = self.final_layer(latent, extended_c) |
|
return flow |
|
|
|
    def forward(self, latent: torch.Tensor, text_f: torch.Tensor, text_f_c: torch.Tensor,
                t: torch.Tensor) -> torch.Tensor:
        """
        latent: (B, N, C)
        text_f: (B, T, D)
        text_f_c: (B, D_c)
        t: (B,)
        """
|
conditions = self.preprocess_conditions(text_f, text_f_c) |
|
flow = self.predict_flow(latent, t, conditions) |
|
return flow |
|
|
|
def get_empty_string_sequence(self, bs: int) -> tuple[torch.Tensor, torch.Tensor]: |
|
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1), \ |
|
self.empty_string_feat_c.unsqueeze(0).expand(bs, -1) |
|
|
|
def get_empty_conditions( |
|
self, |
|
bs: int, |
|
*, |
|
        negative_text_features: Optional[tuple[torch.Tensor, torch.Tensor]] = None) -> PreprocessedConditions:
|
if negative_text_features is not None: |
|
empty_string_feat, empty_string_feat_c = negative_text_features |
|
else: |
|
empty_string_feat, empty_string_feat_c = self.get_empty_string_sequence(1) |
|
|
|
conditions = self.preprocess_conditions(empty_string_feat, |
|
empty_string_feat_c) |
|
|
|
if negative_text_features is None: |
|
conditions.text_f = conditions.text_f.expand(bs, -1, -1) |
|
|
|
conditions.text_f_c = conditions.text_f_c.expand(bs, -1) |
|
|
|
return conditions |
|
|
|
def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions, |
|
empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor: |
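        # Classifier-free guidance: w * v_cond + (1 - w) * v_uncond, where w
        # is cfg_strength; for w < 1 only the conditional flow is evaluated.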
|
t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype) |
|
|
|
if cfg_strength < 1.0: |
|
return self.predict_flow(latent, t, conditions) |
|
else: |
|
return (cfg_strength * self.predict_flow(latent, t, conditions) + |
|
(1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions)) |
|
|
|
def load_weights(self, src_dict) -> None: |
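        # Drop entries that are rebuilt at construction time (timestep
        # frequencies and RoPE tables) so strict loading succeeds.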
|
if 't_embed.freqs' in src_dict: |
|
del src_dict['t_embed.freqs'] |
|
if 'latent_rot' in src_dict: |
|
del src_dict['latent_rot'] |
|
if 'text_rot' in src_dict: |
|
del src_dict['text_rot'] |
|
|
|
if 'empty_string_feat_c' not in src_dict.keys(): |
|
src_dict['empty_string_feat_c'] = src_dict['empty_string_feat'].mean(dim=0) |
|
self.load_state_dict(src_dict, strict=True) |
|
|
|
@property |
|
def device(self) -> torch.device: |
|
return self.latent_mean.device |
|
|
|
@property |
|
def latent_seq_len(self) -> int: |
|
return self._latent_seq_len |
|
|
|
|
|
class MeanAudio(nn.Module): |
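    """Two-timestep variant of FluxAudio.

    Architecturally identical except for a second TimestepEmbedder (r_embed),
    so the predicted flow is conditioned on the timestep pair (t, r) rather
    than on t alone.
    """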
|
|
|
|
|
def __init__(self, |
|
*, |
|
latent_dim: int, |
|
text_dim: int, |
|
text_c_dim: int, |
|
hidden_dim: int, |
|
depth: int, |
|
fused_depth: int, |
|
num_heads: int, |
|
mlp_ratio: float = 4.0, |
|
latent_seq_len: int, |
|
text_seq_len: int = 77, |
|
latent_mean: Optional[torch.Tensor] = None, |
|
latent_std: Optional[torch.Tensor] = None, |
|
empty_string_feat: Optional[torch.Tensor] = None, |
|
empty_string_feat_c: Optional[torch.Tensor] = None, |
|
use_rope: bool = False) -> None: |
|
super().__init__() |
|
|
|
self.latent_dim = latent_dim |
|
self._latent_seq_len = latent_seq_len |
|
self._text_seq_len = text_seq_len |
|
self.hidden_dim = hidden_dim |
|
self.num_heads = num_heads |
|
self.use_rope = use_rope |
|
self.mm_depth = depth - fused_depth |
|
|
|
self.audio_input_proj = nn.Sequential( |
|
ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), |
|
nn.SELU(), |
|
ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), |
|
) |
|
|
|
self.text_input_proj = nn.Sequential( |
|
nn.Linear(text_dim, hidden_dim), |
|
MLP(hidden_dim, hidden_dim * 4), |
|
) |
|
|
|
        self.text_cond_proj = nn.Sequential(
            nn.Linear(text_c_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )
|
|
|
self.final_layer = FinalBlock(hidden_dim, latent_dim) |
|
|
|
self.t_embed = TimestepEmbedder(hidden_dim, |
|
frequency_embedding_size=256, |
|
max_period=10000) |
|
|
|
self.r_embed = TimestepEmbedder(hidden_dim, |
|
frequency_embedding_size=256, |
|
max_period=10000) |
|
self.joint_blocks = nn.ModuleList([ |
|
JointBlock(hidden_dim, |
|
num_heads, |
|
mlp_ratio=mlp_ratio, |
|
pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) |
|
]) |
|
|
|
self.fused_blocks = nn.ModuleList([ |
|
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1) |
|
for i in range(fused_depth) |
|
]) |
|
|
|
if latent_mean is None: |
|
|
|
|
|
assert latent_std is None |
|
latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) |
|
latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) |
|
else: |
|
assert latent_std is not None |
|
assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}' |
|
|
|
if empty_string_feat is None: |
|
empty_string_feat = torch.zeros((text_seq_len, text_dim)) |
|
        if empty_string_feat_c is None:
            empty_string_feat_c = torch.zeros(text_c_dim)
|
|
|
        assert empty_string_feat.shape[-1] == text_dim, f'{empty_string_feat.shape[-1]=} != {text_dim=}'
        assert empty_string_feat_c.shape[-1] == text_c_dim, f'{empty_string_feat_c.shape[-1]=} != {text_c_dim=}'
|
|
|
self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) |
|
self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) |
|
|
|
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) |
|
self.empty_string_feat_c = nn.Parameter(empty_string_feat_c, requires_grad=False) |
|
|
|
|
|
self.initialize_weights() |
|
if self.use_rope: |
|
log.info("Network: Enabling RoPE embeddings") |
|
self.initialize_rotations() |
|
else: |
|
log.info("Network: RoPE embedding disabled") |
|
self.latent_rot = None |
|
self.text_rot = None |
|
|
|
def initialize_rotations(self): |
|
base_freq = 1.0 |
|
latent_rot = compute_rope_rotations(self._latent_seq_len, |
|
self.hidden_dim // self.num_heads, |
|
10000, |
|
freq_scaling=base_freq, |
|
device=self.device) |
|
text_rot = compute_rope_rotations(self._text_seq_len, |
|
self.hidden_dim // self.num_heads, |
|
10000, |
|
freq_scaling=base_freq, |
|
device=self.device) |
|
|
|
self.latent_rot = nn.Buffer(latent_rot, persistent=False) |
|
self.text_rot = nn.Buffer(text_rot, persistent=False) |
|
|
|
def update_seq_lengths(self, latent_seq_len: int) -> None: |
|
self._latent_seq_len = latent_seq_len |
|
if self.use_rope: |
|
self.initialize_rotations() |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
torch.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
        # Give the r embedder the same initialization as the t embedder.
        nn.init.normal_(self.r_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.r_embed.mlp[2].weight, std=0.02)
|
|
|
|
|
for block in self.joint_blocks: |
|
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) |
|
for block in self.fused_blocks: |
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(self.final_layer.conv.weight, 0) |
|
nn.init.constant_(self.final_layer.conv.bias, 0) |
|
|
|
def normalize(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
return x.sub_(self.latent_mean).div_(self.latent_std) |
|
|
|
def unnormalize(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
return x.mul_(self.latent_std).add_(self.latent_mean) |
|
|
|
def preprocess_conditions(self, text_f: torch.Tensor, text_f_c: torch.Tensor) -> PreprocessedConditions: |
|
""" |
|
cache computations that do not depend on the latent/time step |
|
i.e., the features are reused over steps during inference |
|
""" |
|
assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' |
|
|
|
bs = text_f.shape[0] |
|
|
|
|
|
|
|
text_f_c = self.text_cond_proj(text_f_c) |
|
text_f = self.text_input_proj(text_f) |
|
|
|
return PreprocessedConditions(text_f=text_f, |
|
text_f_c=text_f_c) |
|
|
|
    def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, r: torch.Tensor,
                     conditions: PreprocessedConditions) -> torch.Tensor:
        """
        Latent- and timestep-dependent computations that cannot be cached
        across solver steps (cf. preprocess_conditions).
        """
|
|
|
|
|
assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' |
|
|
|
text_f = conditions.text_f |
|
text_f_c = conditions.text_f_c |
|
|
|
latent = self.audio_input_proj(latent) |
|
|
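        # Global conditioning: the sum of both timestep embeddings and the
        # pooled text feature, broadcast over the sequence dimension.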
|
global_c = self.t_embed(t).unsqueeze(1) + self.r_embed(r).unsqueeze(1) + text_f_c.unsqueeze(1) |
|
|
|
extended_c = global_c |
|
|
|
for block in self.joint_blocks: |
|
latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot, self.text_rot) |
|
|
|
for block in self.fused_blocks: |
|
latent = block(latent, extended_c, self.latent_rot) |
|
|
|
flow = self.final_layer(latent, extended_c) |
|
return flow |
|
|
|
    def forward(self, latent: torch.Tensor, text_f: torch.Tensor, text_f_c: torch.Tensor,
                r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        latent: (B, N, C)
        text_f: (B, T, D)
        text_f_c: (B, D_c)
        r: (B,)
        t: (B,)
        """
|
|
|
|
|
conditions = self.preprocess_conditions(text_f, text_f_c) |
|
|
|
        flow = self.predict_flow(latent, t, r, conditions)
|
return flow |
|
|
|
def get_empty_string_sequence(self, bs: int) -> tuple[torch.Tensor, torch.Tensor]: |
|
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1), \ |
|
self.empty_string_feat_c.unsqueeze(0).expand(bs, -1) |
|
|
|
def get_empty_conditions( |
|
self, |
|
bs: int, |
|
*, |
|
        negative_text_features: Optional[tuple[torch.Tensor, torch.Tensor]] = None) -> PreprocessedConditions:
|
if negative_text_features is not None: |
|
empty_string_feat, empty_string_feat_c = negative_text_features |
|
else: |
|
empty_string_feat, empty_string_feat_c = self.get_empty_string_sequence(1) |
|
|
|
conditions = self.preprocess_conditions(empty_string_feat, |
|
empty_string_feat_c) |
|
if negative_text_features is None: |
|
conditions.text_f = conditions.text_f.expand(bs, -1, -1) |
|
|
|
conditions.text_f_c = conditions.text_f_c.expand(bs, -1) |
|
|
|
return conditions |
|
|
|
def ode_wrapper(self, t: torch.Tensor, r: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions, |
|
empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor: |
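        # Same classifier-free guidance mix as FluxAudio.ode_wrapper; both
        # branches share the broadcast (t, r) pair.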
|
t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype) |
|
r = r * torch.ones(len(latent), device=latent.device, dtype=latent.dtype) |
|
|
|
        if cfg_strength < 1.0:
            return self.predict_flow(latent, t, r, conditions)
        else:
            return (cfg_strength * self.predict_flow(latent, t, r, conditions) +
                    (1 - cfg_strength) * self.predict_flow(latent, t, r, empty_conditions))
|
|
|
|
|
    def load_weights(self, src_dict) -> None:

        def remove_prefix(storage):
            # Strip the "ema_model." prefix so EMA checkpoints load into the
            # plain module.
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in storage.items():
                name = k.replace("ema_model.", "")
                new_state_dict[name] = v
            return new_state_dict

        src_dict = remove_prefix(src_dict)
|
if 't_embed.freqs' in src_dict: |
|
del src_dict['t_embed.freqs'] |
|
if 'r_embed.freqs' in src_dict: |
|
del src_dict['r_embed.freqs'] |
|
if 'latent_rot' in src_dict: |
|
del src_dict['latent_rot'] |
|
if 'text_rot' in src_dict: |
|
del src_dict['text_rot'] |
|
|
|
if 'empty_string_feat_c' not in src_dict.keys(): |
|
src_dict['empty_string_feat_c'] = src_dict['empty_string_feat'].mean(dim=0) |
|
if '_extra_state' in src_dict: |
|
del src_dict['_extra_state'] |
|
self.load_state_dict(src_dict, strict=True) |
|
|
|
@property |
|
def device(self) -> torch.device: |
|
return self.latent_mean.device |
|
|
|
@property |
|
def latent_seq_len(self) -> int: |
|
return self._latent_seq_len |
|
|
|
|
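# Presets: the attention head dimension is fixed at 64, so hidden_dim = 64 * num_heads.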
|
def fluxaudio_s(**kwargs) -> FluxAudio: |
|
num_heads = 7 |
|
return FluxAudio(latent_dim=20, |
|
text_dim=1024, |
|
hidden_dim=64 * num_heads, |
|
depth=12, |
|
fused_depth=8, |
|
num_heads=num_heads, |
|
latent_seq_len=312, |
|
**kwargs) |
|
def meanaudio_s(**kwargs) -> MeanAudio: |
|
num_heads = 7 |
|
return MeanAudio(latent_dim=20, |
|
text_dim=1024, |
|
hidden_dim=64 * num_heads, |
|
depth=12, |
|
fused_depth=8, |
|
num_heads=num_heads, |
|
latent_seq_len=312, |
|
**kwargs) |
|
def meanaudio_l(**kwargs) -> MeanAudio: |
|
num_heads = 14 |
|
return MeanAudio(latent_dim=20, |
|
text_dim=1024, |
|
hidden_dim=64 * num_heads, |
|
depth=12, |
|
fused_depth=8, |
|
num_heads=num_heads, |
|
latent_seq_len=312, |
|
**kwargs) |
|
|
|
def get_mean_audio(name: str, **kwargs) -> MeanAudio: |
|
if name == 'meanaudio_s_ac' or name == 'meanaudio_s_full': |
|
return meanaudio_s(**kwargs) |
|
elif name == 'meanaudio_l_full': |
|
return meanaudio_l(**kwargs) |
|
elif name == 'fluxaudio_s_full': |
|
return fluxaudio_s(**kwargs) |
|
else: |
|
raise ValueError(f'Unknown model name: {name}') |
|
|
|
|
|
if __name__ == '__main__': |
|
from meanaudio.model.utils.sample_utils import log_normal_sample |
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", |
|
handlers=[ |
|
|
|
logging.StreamHandler() |
|
] |
|
) |
|
|
|
    network: MeanAudio = get_mean_audio('meanaudio_s_full',
                                        use_rope=False,
                                        text_c_dim=512)
|
|
|
x = torch.randn(256, 312, 20) |
|
print(x.shape) |
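    # Minimal forward-pass sketch; the shapes are assumptions taken from the
    # constructor defaults above (text_seq_len=77, text_dim=1024,
    # text_c_dim=512). Any (B,) tensors work for r and t.
    text_f = torch.randn(2, 77, 1024)
    text_f_c = torch.randn(2, 512)
    r = torch.rand(2)
    t = torch.rand(2)
    flow = network(x[:2], text_f, text_f_c, r, t)
    print(flow.shape)  # expected: torch.Size([2, 312, 20])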
|
print('Finish') |
|
|
|
|