gense / models /gense_wavlm.py
yaoxunji's picture
Upload 2 files
4f1a5d6 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from components.semantic_extractor.ssl_model import get_ssl_model
from components.simcodec.model import SimCodec
from transformers import GPT2Config, GPT2LMHeadModel
class N2S(nn.Module):
def __init__(self, hps):
super().__init__()
self.hps = hps
self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
self.bos = 1
self.eos = 2
self.pad = 0
self.shift_num = 3
self.lm_conf = GPT2Config(
vocab_size=self.hps.model['n2s_vocab_size'],
n_embd=self.hps.model['hidden_size'],
n_layer=self.hps.model['num_hidden_layers'],
n_head=self.hps.model['num_attention_heads'],
activation_function='gelu_new',
n_positions=2048,
n_ctx=2048,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-05,
initializer_range=0.02,
summary_type='mean',
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
bos_token_id=self.bos,
eos_token_id=self.eos,
)
self.lm = GPT2LMHeadModel(self.lm_conf)
def extract_semantic(self, wavs, num_frames):
padding_size = (0, 100)
wavs = F.pad(wavs, padding_size, "constant", 0)
num_frames += 100
features = self.wavlm.extract_features(
wavs,
output_layer=6,
ret_layer_results=False,
input_length=num_frames
)[0]
b, t, d = features.shape
tokens = self.km(features.reshape(-1, d), b=b, t=t)
return tokens
def inference(self, token_gen, pos_gen):
predict_len = (token_gen.shape[1] - 1)
truck_length = token_gen.shape[1]
for j in tqdm(range(predict_len)):
lm_outputs = self.lm(
input_ids=token_gen,
attention_mask=None,
position_ids=pos_gen
)
logits = lm_outputs['logits']
logits[:, :, 0:self.shift_num] = -1e5
probs = logits[:, -1, :].softmax(dim=-1)
dist = torch.distributions.categorical.Categorical(probs=probs)
samples = dist.sample().unsqueeze(1).to(token_gen.device)
token_gen = torch.cat([token_gen, samples], dim=1)
pos_pad = torch.ones(pos_gen.shape[0]) * j
pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1)
return token_gen[:,truck_length:][0]
def generate(self, mix):
mix = mix.squeeze(1)
num_frame = torch.LongTensor([mix.shape[1]]).to(mix.device)
token_s = self.extract_semantic(mix, num_frames=num_frame)
token_s += 3
bos = torch.ones(token_s.shape[0],1).long().to(mix.device)
token_gen = torch.cat([token_s, bos], dim=1)
pos_gen_id = torch.from_numpy(np.asarray(list(range(token_s.shape[1] + 1)))).to(mix.device)
pos_gen = []
for i in range(token_s.shape[0]):
pos_gen.append(pos_gen_id.unsqueeze(0))
pos_gen = torch.cat(pos_gen, dim=0)
clean_s = self.inference(token_gen, pos_gen) - self.shift_num
token_s -= self.shift_num
return token_s, clean_s
class S2S(nn.Module):
def __init__(self, hps):
super().__init__()
self.hps = hps
self.codec_tokenizer = SimCodec(hps.path['codec_config_path'])
self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
self.bos = 1
self.eos = 2
self.pad = 0
self.shift_num = 3 + self.hps.model['semantic_num']
self.lm_conf = GPT2Config(
vocab_size=self.hps.model['s2s_vocab_size'],
n_embd=self.hps.model['hidden_size'],
n_layer=self.hps.model['num_hidden_layers'],
n_head=self.hps.model['num_attention_heads'],
activation_function='gelu_new',
n_positions=4096,
n_ctx=4096,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-05,
initializer_range=0.02,
summary_type='mean',
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
bos_token_id=self.bos,
eos_token_id=self.eos,
)
self.lm = GPT2LMHeadModel(self.lm_conf)
def inference(self, token_gen, pos_gen):
predict_len = int((token_gen.shape[1] - 1) / 2)
truck_length = token_gen.shape[1]
for j in tqdm(range(predict_len)):
lm_outputs = self.lm(
input_ids=token_gen,
attention_mask=None,
position_ids=pos_gen
)
logits = lm_outputs['logits']
logits[:, :, 0:self.shift_num] = -1e5
probs = logits[:, -1, :].softmax(dim=-1)
dist = torch.distributions.categorical.Categorical(probs=probs)
samples = dist.sample().unsqueeze(1).to(token_gen.device)
token_gen = torch.cat([token_gen, samples], dim=1)
pos_pad = torch.ones(pos_gen.shape[0]) * (j + 1000)
pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1)
return token_gen[:,truck_length:][0]
def generate(self, mix, mix_s, clean_s):
mix_a = self.codec_tokenizer(mix).squeeze(-1)
if len(clean_s.shape) == 1:
clean_s = clean_s.unsqueeze(0)
mix_s += 3
clean_s += 3
mix_a += self.shift_num
bos = torch.ones(mix_s.shape[0],1).long().to(mix.device)
token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)
pos_gen_id = torch.from_numpy(np.asarray(list(range(mix_s.shape[1] + clean_s.shape[1] + 1)) + list(range(mix_a.shape[1])))).to(mix.device)
pos_gen = []
for i in range(mix_s.shape[0]):
pos_gen.append(pos_gen_id.unsqueeze(0))
pos_gen = torch.cat(pos_gen, dim=0)
pre_a = self.inference(token_gen, pos_gen) - self.shift_num
gen_wav = self.codec_tokenizer.decode(pre_a.unsqueeze(0).unsqueeze(2)).squeeze(0).cpu()
return gen_wav