# Feature-extraction utilities: text encoders (CLIP / T5 / LAION-CLAP) and the
# mel-spectrogram + audio-VAE ("tod") pipeline used by MeanAudio.
from typing import Literal, Optional
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from transformers import T5EncoderModel, AutoTokenizer
from meanaudio.ext.autoencoder import AutoEncoderModule
from meanaudio.ext.mel_converter import get_mel_converter
from meanaudio.model.utils.distributions import DiagonalGaussianDistribution
import laion_clap
import logging
def patch_clip(clip_model):
    """Monkey-patch an open_clip model so ``encode_text`` returns the full
    per-token hidden states instead of the pooled sentence embedding.

    Reference implementation being overridden:
    https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    """

    def _encode_text_hidden_states(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()
        # (B, n_ctx) token ids -> (B, n_ctx, d_model), plus positional embeddings
        hidden = self.token_embedding(text).to(cast_dtype)
        hidden = hidden + self.positional_embedding.to(cast_dtype)
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = self.ln_final(hidden)  # (B, n_ctx, transformer.width)
        if normalize:
            return F.normalize(hidden, dim=-1)
        return hidden

    # Bind the replacement as a method on this particular instance.
    clip_model.encode_text = _encode_text_hidden_states.__get__(clip_model)
    return clip_model
class FeaturesUtils(nn.Module):
    """Aggregates the frozen feature extractors used by the model.

    Depending on the constructor flags this module holds:
      * a text encoder + tokenizer (CLIP, T5, or T5 combined with LAION-CLAP),
      * a mel-spectrogram converter and an audio VAE ("tod") with vocoder.

    The module is pinned to eval mode (see ``train``) and all feature
    extraction runs under ``torch.inference_mode``.
    """

    def __init__(
        self,
        *,
        tod_vae_ckpt: Optional[str] = None,
        bigvgan_vocoder_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        # BUGFIX: these two were written as `encoder_name=Literal[...]`, i.e. the
        # typing object itself was the *default value*, which would crash e.g.
        # get_mel_converter(mode) if the caller omitted them. They are now proper
        # annotations; the concrete defaults ('t5', '16k') follow the example in
        # the __main__ section — NOTE(review): confirm against callers.
        encoder_name: Literal['clip', 't5', 't5_clap', 't5_clap_cat'] = 't5',
        mode: Literal['16k', '44k'] = '16k',
        need_vae_encoder: bool = True,
    ):
        super().__init__()

        if enable_conditions:
            self.encoder_name = encoder_name
            if encoder_name == 'clip':
                self.text_encoder = create_model_from_pretrained(
                    'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False)
                self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])
                # Patched so encode_text returns per-token hidden states.
                self.text_encoder = patch_clip(self.text_encoder)
                self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'
            elif encoder_name == 't5':
                logging.getLogger(__name__).info('FeatureUtils: Loading google/flan-t5-large ... ')
                self.tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large')
                self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large').eval()
            elif encoder_name in ('t5_clap', 't5_clap_cat'):
                self.tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large', revision="main")
                self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large',
                                                                   revision="main").eval()
                self.laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False,
                                                               amodel='HTSAT-base').eval()
                self._clap_ckpt_path = "./weights/music_speech_audioset_epoch_15_esc_89.98.pt"
                self.laion_clap_model.load_ckpt(self._clap_ckpt_path, verbose=False)
            else:
                # BUGFIX: the message previously listed only ['clip', 't5'] even
                # though four encoder names are accepted above.
                raise ValueError(
                    f"Encoder {encoder_name} is not allowed, select from "
                    f"['clip', 't5', 't5_clap', 't5_clap_cat']")
        else:
            self.text_encoder = None
            self.tokenizer = None

        if tod_vae_ckpt is not None:
            self.mel_converter = get_mel_converter(mode)
            self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                         vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                         mode=mode,
                                         need_vae_encoder=need_vae_encoder)
        else:
            self.tod = None

    def compile(self):
        """torch.compile the heavy entry points (text encoder, decode, vocode)."""
        if self.text_encoder is not None:
            # ONLY for the CLIP text encoder, which exposes encode_text.
            self.text_encoder.encode_text = torch.compile(self.text_encoder.encode_text)
        self.decode = torch.compile(self.decode)
        self.vocode = torch.compile(self.vocode)

    def train(self, mode: bool = True) -> 'FeaturesUtils':
        """Ignore the requested mode and keep the module permanently in eval.

        BUGFIX: `mode` previously had no default, so the conventional
        `model.train()` call raised TypeError; the `-> None` annotation was
        also wrong — nn.Module.train returns the module itself.
        """
        return super().train(False)

    @torch.inference_mode()
    def encode_text(self, text: list[str]):
        """Encode a batch of prompts.

        Returns ``(text_features, text_features_c)``: the per-token hidden
        states and a pooled/global embedding. For 'clip' the pooled value is
        None (this code path computes no pooled variant).
        """
        assert self.text_encoder is not None, 'Text encoder is not loaded'
        assert self.tokenizer is not None, 'Tokenizer is not loaded'
        # x: (B, L)
        if self.encoder_name == 'clip':
            tokens = self.tokenizer(text).to(self.device)
            text_features = self.text_encoder.encode_text(tokens, normalize=True)
            # BUGFIX: this branch previously left text_features_c undefined and
            # the shared return below raised NameError.
            text_features_c = None
        elif self.encoder_name == 't5':
            text_features = self._encode_t5(text)
            text_features_c = text_features.mean(dim=1)
        elif self.encoder_name in ('t5_clap', 't5_clap_cat'):
            text_features = self._encode_t5(text)
            text_features_c = self.laion_clap_model.get_text_embedding(text, use_tensor=True)
            if self.encoder_name == 't5_clap_cat':
                # Concatenate mean-pooled T5 features with the CLAP embedding.
                text_features_c = torch.cat([text_features.mean(dim=-2), text_features_c], dim=-1)
        return text_features, text_features_c

    def _encode_t5(self, text: list[str]) -> torch.Tensor:
        """Tokenize (fixed length 77) and run the T5 encoder; returns (B, 77, D)."""
        tokens = self.tokenizer(
            text,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # BUGFIX: was hard-coded .cuda(); use the module's own device so CPU
        # or non-default-GPU setups work.
        input_ids = tokens.input_ids.to(self.device)
        attention_mask = tokens.attention_mask.to(self.device)
        return self.text_encoder(input_ids=input_ids,
                                 attention_mask=attention_mask)[0]

    @torch.inference_mode()
    def encode_audio(self, x) -> DiagonalGaussianDistribution:
        """Waveform batch (B, L) -> mel -> posterior distribution over VAE latents."""
        assert self.tod is not None, 'VAE is not loaded'
        mel = self.mel_converter(x)
        return self.tod.encode(mel)

    @torch.inference_mode()
    def vocode(self, mel: torch.Tensor) -> torch.Tensor:
        """Mel spectrogram -> waveform via the vocoder inside the tod module."""
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.vocode(mel)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Latents (B, T, C) -> mel; transposed to (B, C, T) for the VAE."""
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.decode(z.transpose(1, 2))

    @property
    def device(self):
        # Device of the module's parameters (assumes at least one parameter).
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
if __name__ == '__main__':
    # Quick sanity check: load the LAION-CLAP checkpoint and print its top-level
    # keys. NOTE: weights_only=False unpickles arbitrary objects — only use on
    # trusted local checkpoints.
    clap_ckpt = "./weights/music_speech_audioset_epoch_15_esc_89.98.pt"
    weights = torch.load(clap_ckpt, weights_only=False)
    print(weights.keys())
|