# File size: 6,469 Bytes (revision 4f1a5d6)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from components.semantic_extractor.ssl_model import get_ssl_model
from components.simcodec.model import SimCodec
from transformers import GPT2Config, GPT2LMHeadModel
class N2S(nn.Module):
    """GPT-2 language model that samples a token sequence from a waveform.

    The input waveform is turned into discrete tokens by the `wavlm` / `km`
    pair returned by `get_ssl_model`; those tokens, followed by a BOS id,
    are used as the prompt from which new tokens are sampled.

    Token ids 0, 1 and 2 are reserved for pad, BOS and EOS, so content
    tokens are offset by `shift_num` (3) before they are fed to the LM and
    un-shifted again on the way out.
    """

    def __init__(self, hps):
        """Build the SSL feature extractor and the GPT-2 LM from `hps`.

        Args:
            hps: hyper-parameter object. `hps.ssl_model` is expanded as
                keyword arguments to `get_ssl_model`; `hps.model` must
                provide 'n2s_vocab_size', 'hidden_size',
                'num_hidden_layers' and 'num_attention_heads'.
        """
        super().__init__()
        self.hps = hps
        self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
        self.bos = 1
        self.eos = 2
        self.pad = 0
        # Number of reserved ids (pad/bos/eos) placed before content tokens.
        self.shift_num = 3
        self.lm_conf = GPT2Config(
            vocab_size=self.hps.model['n2s_vocab_size'],
            n_embd=self.hps.model['hidden_size'],
            n_layer=self.hps.model['num_hidden_layers'],
            n_head=self.hps.model['num_attention_heads'],
            activation_function='gelu_new',
            n_positions=2048,
            n_ctx=2048,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-05,
            initializer_range=0.02,
            summary_type='mean',
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            bos_token_id=self.bos,
            eos_token_id=self.eos,
        )
        self.lm = GPT2LMHeadModel(self.lm_conf)

    def extract_semantic(self, wavs, num_frames):
        """Tokenize waveforms with the SSL model and the `km` quantizer.

        Args:
            wavs: waveform tensor; 100 zeros are appended on its last
                dimension before feature extraction.
            num_frames: length passed to the extractor as `input_length`
                (after adding the 100 padded samples). The caller's object
                is left untouched.

        Returns:
            Whatever `self.km` returns for the flattened features
            (`generate` treats it as a 2-D integer tensor, batch first).
        """
        pad_samples = 100
        wavs = F.pad(wavs, (0, pad_samples), "constant", 0)
        # Out-of-place add: `num_frames += 100` would silently modify a
        # tensor owned by the caller.
        num_frames = num_frames + pad_samples
        features = self.wavlm.extract_features(
            wavs,
            output_layer=6,
            ret_layer_results=False,
            input_length=num_frames,
        )[0]
        b, t, d = features.shape
        return self.km(features.reshape(-1, d), b=b, t=t)

    def inference(self, token_gen, pos_gen):
        """Sample `prompt_length - 1` new tokens autoregressively.

        Args:
            token_gen: (batch, length) prompt token ids.
            pos_gen: (batch, length) position ids matching `token_gen`.

        Returns:
            1-D tensor with the sampled (still shifted) tokens of the FIRST
            batch row only; the prompt is stripped.
        """
        prompt_length = token_gen.shape[1]
        predict_len = prompt_length - 1
        # Sampled ids are not differentiable, so the autograd graph of each
        # forward pass would only cost memory.
        with torch.no_grad():
            for step in tqdm(range(predict_len)):
                lm_outputs = self.lm(
                    input_ids=token_gen,
                    attention_mask=None,
                    position_ids=pos_gen,
                )
                next_logits = lm_outputs['logits'][:, -1, :]
                # Never sample the reserved pad/bos/eos ids.
                next_logits[:, :self.shift_num] = -1e5
                probs = next_logits.softmax(dim=-1)
                dist = torch.distributions.categorical.Categorical(probs=probs)
                samples = dist.sample().unsqueeze(1).to(token_gen.device)
                token_gen = torch.cat([token_gen, samples], dim=1)
                # Generated tokens use their own position counter that
                # restarts at 0 (the step index), not the prompt's.
                next_pos = torch.full(
                    (pos_gen.shape[0], 1),
                    step,
                    dtype=torch.long,
                    device=token_gen.device,
                )
                pos_gen = torch.cat([pos_gen, next_pos], dim=1)
        return token_gen[:, prompt_length:][0]

    def generate(self, mix):
        """Tokenize `mix` and sample an equally long token sequence from it.

        Args:
            mix: waveform tensor; dimension 1 is squeezed away before
                feature extraction (presumably (batch, 1, samples) -
                confirm against the caller).

        Returns:
            Tuple `(token_s, clean_s)`: the un-shifted tokens extracted from
            `mix` (2-D) and the un-shifted sampled tokens (1-D, first batch
            row only).
        """
        mix = mix.squeeze(1)
        num_frame = torch.tensor([mix.shape[1]], dtype=torch.long, device=mix.device)
        token_s = self.extract_semantic(mix, num_frames=num_frame)
        batch, length = token_s.shape[0], token_s.shape[1]

        # Prompt layout: [shifted input tokens, BOS].
        bos = torch.full((batch, 1), self.bos, dtype=torch.long, device=mix.device)
        token_gen = torch.cat([token_s + self.shift_num, bos], dim=1)
        pos_gen = torch.arange(length + 1, device=mix.device).unsqueeze(0).repeat(batch, 1)

        clean_s = self.inference(token_gen, pos_gen) - self.shift_num
        return token_s, clean_s
class S2S(nn.Module):
    """GPT-2 language model that samples codec tokens and decodes a waveform.

    The prompt is `[mix_s, clean_s, BOS, mix_a]`, where `mix_s` / `clean_s`
    are semantic token sequences and `mix_a` are the codec tokens of the
    input waveform produced by `SimCodec`.

    Vocabulary layout, as implied by the offsets used here: ids 0-2 are
    pad/BOS/EOS, semantic tokens are offset by `SEMANTIC_SHIFT`, and codec
    tokens are offset by `shift_num = SEMANTIC_SHIFT + semantic_num`.
    """

    # Number of reserved ids (pad/bos/eos) placed before the semantic tokens.
    SEMANTIC_SHIFT = 3

    def __init__(self, hps):
        """Build the codec, the SSL feature extractor and the GPT-2 LM.

        Args:
            hps: hyper-parameter object. `hps.path['codec_config_path']` is
                given to `SimCodec`; `hps.ssl_model` is expanded as keyword
                arguments to `get_ssl_model`; `hps.model` must provide
                'semantic_num', 's2s_vocab_size', 'hidden_size',
                'num_hidden_layers' and 'num_attention_heads'.
        """
        super().__init__()
        self.hps = hps
        self.codec_tokenizer = SimCodec(hps.path['codec_config_path'])
        self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
        self.bos = 1
        self.eos = 2
        self.pad = 0
        # Codec tokens live after the reserved ids and the semantic vocabulary.
        self.shift_num = self.SEMANTIC_SHIFT + self.hps.model['semantic_num']
        self.lm_conf = GPT2Config(
            vocab_size=self.hps.model['s2s_vocab_size'],
            n_embd=self.hps.model['hidden_size'],
            n_layer=self.hps.model['num_hidden_layers'],
            n_head=self.hps.model['num_attention_heads'],
            activation_function='gelu_new',
            n_positions=4096,
            n_ctx=4096,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-05,
            initializer_range=0.02,
            summary_type='mean',
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            bos_token_id=self.bos,
            eos_token_id=self.eos,
        )
        self.lm = GPT2LMHeadModel(self.lm_conf)

    def inference(self, token_gen, pos_gen):
        """Sample `(prompt_length - 1) // 2` new tokens autoregressively.

        Args:
            token_gen: (batch, length) prompt token ids.
            pos_gen: (batch, length) position ids matching `token_gen`.

        Returns:
            1-D tensor with the sampled (still shifted) tokens of the FIRST
            batch row only; the prompt is stripped.
        """
        prompt_length = token_gen.shape[1]
        predict_len = (prompt_length - 1) // 2
        # Sampled ids are not differentiable, so the autograd graph of each
        # forward pass would only cost memory.
        with torch.no_grad():
            for step in tqdm(range(predict_len)):
                lm_outputs = self.lm(
                    input_ids=token_gen,
                    attention_mask=None,
                    position_ids=pos_gen,
                )
                next_logits = lm_outputs['logits'][:, -1, :]
                # Only codec tokens may be sampled: mask every id below
                # `shift_num` (reserved ids and semantic tokens).
                next_logits[:, :self.shift_num] = -1e5
                probs = next_logits.softmax(dim=-1)
                dist = torch.distributions.categorical.Categorical(probs=probs)
                samples = dist.sample().unsqueeze(1).to(token_gen.device)
                token_gen = torch.cat([token_gen, samples], dim=1)
                # Generated tokens use position ids starting at 1000.
                next_pos = torch.full(
                    (pos_gen.shape[0], 1),
                    step + 1000,
                    dtype=torch.long,
                    device=token_gen.device,
                )
                pos_gen = torch.cat([pos_gen, next_pos], dim=1)
        return token_gen[:, prompt_length:][0]

    def generate(self, mix, mix_s, clean_s):
        """Sample codec tokens for `mix` and decode them to a waveform.

        Args:
            mix: waveform tensor handed to the codec tokenizer.
            mix_s: 2-D un-shifted semantic tokens of `mix`.
            clean_s: un-shifted semantic tokens to condition on; a 1-D
                tensor is treated as a batch of one.

        Neither `mix_s` nor `clean_s` is modified.

        Returns:
            The output of `codec_tokenizer.decode` for the sampled tokens,
            with its leading dimension squeezed, moved to the CPU.
        """
        mix_a = self.codec_tokenizer(mix).squeeze(-1)
        if clean_s.dim() == 1:
            clean_s = clean_s.unsqueeze(0)

        # Out-of-place shifts: `+=` would modify the caller's tensors (the
        # unsqueezed view shares storage), so a second call with the same
        # inputs would shift them twice.
        mix_s = mix_s + self.SEMANTIC_SHIFT
        clean_s = clean_s + self.SEMANTIC_SHIFT
        mix_a = mix_a + self.shift_num

        batch = mix_s.shape[0]
        device = mix.device
        bos = torch.full((batch, 1), self.bos, dtype=torch.long, device=device)
        token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)

        # Positions count 0..N over [mix_s, clean_s, BOS] and restart at 0
        # for the codec tokens of the input.
        prefix_len = mix_s.shape[1] + clean_s.shape[1] + 1
        pos_row = torch.cat([
            torch.arange(prefix_len, device=device),
            torch.arange(mix_a.shape[1], device=device),
        ])
        pos_gen = pos_row.unsqueeze(0).repeat(batch, 1)

        pre_a = self.inference(token_gen, pos_gen) - self.shift_num
        gen_wav = self.codec_tokenizer.decode(pre_a.unsqueeze(0).unsqueeze(2)).squeeze(0).cpu()
        return gen_wav