yaoxunji committed on
Commit
4f1a5d6
·
verified ·
1 Parent(s): c130bcf

Upload 2 files

Browse files
Files changed (2) hide show
  1. models/gense.py +174 -0
  2. models/gense_wavlm.py +176 -0
models/gense.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from components.semantic_extractor.ssl_model import get_ssl_model
8
+ from components.simcodec.model import SimCodec
9
+ from transformers import GPT2Config, GPT2LMHeadModel
10
+
11
class N2S(nn.Module):
    """GPT-2 language model over discrete semantic tokens.

    ``generate`` quantises a waveform into semantic tokens with the SSL
    encoder (``self.xlsr``) and the quantiser (``self.km``), appends a BOS
    token and samples as many new tokens as there are prompt semantic
    tokens. Presumably "N2S" means noisy-to-semantic (TODO confirm).

    Token ids 0, 1 and 2 are reserved for pad, BOS and EOS, so semantic
    tokens are shifted by ``shift_num`` before being fed to the LM.
    """

    def __init__(self, hps):
        """Build the SSL tokenizer and the GPT-2 language model.

        Args:
            hps: Hyper-parameter object. Reads ``hps.ssl_model`` (keyword
                arguments for ``get_ssl_model``) and the ``hps.model`` keys
                ``n2s_vocab_size``, ``hidden_size``, ``num_hidden_layers``
                and ``num_attention_heads``.
        """
        super().__init__()
        self.hps = hps
        self.xlsr, self.km = get_ssl_model(**hps.ssl_model)
        # Reserved token ids; real tokens start at ``shift_num``.
        self.bos = 1
        self.eos = 2
        self.pad = 0
        self.shift_num = 3

        self.lm_conf = GPT2Config(
            vocab_size=self.hps.model['n2s_vocab_size'],
            n_embd=self.hps.model['hidden_size'],
            n_layer=self.hps.model['num_hidden_layers'],
            n_head=self.hps.model['num_attention_heads'],
            activation_function='gelu_new',
            n_positions=2048,
            n_ctx=2048,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-05,
            initializer_range=0.02,
            summary_type='mean',
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            bos_token_id=self.bos,
            eos_token_id=self.eos,
        )
        self.lm = GPT2LMHeadModel(self.lm_conf)

    def extract_semantic(self, wavs, num_frames):
        """Convert waveforms into discrete semantic tokens.

        Args:
            wavs: Waveform tensor; 100 zeros are appended along its last
                dimension before feature extraction.
            num_frames: Unused by this encoder. Kept so the signature
                matches existing callers; it is no longer modified in place.

        Returns:
            Whatever ``self.km`` returns for the flattened features
            (presumably a ``(batch, time)`` tensor of token ids - TODO
            confirm against ``get_ssl_model``).
        """
        wavs = F.pad(wavs, (0, 100), "constant", 0)
        outputs = self.xlsr.extract_features(wavs, padding_mask=None)
        # Each layer result is a 3-tuple; only its first element is used.
        hidden, _, _ = outputs['layer_results'][5]
        # Swap the first two axes so the result unpacks as (b, t, d).
        hidden = hidden.transpose(0, 1)
        b, t, d = hidden.shape
        return self.km(hidden.reshape(-1, d), b=b, t=t)

    @torch.no_grad()
    def inference(self, token_gen, pos_gen):
        """Sample a continuation of ``token_gen`` one token at a time.

        Runs under ``torch.no_grad()``: the result is a tensor of sampled
        integer ids, so tracking gradients across the loop would only cost
        memory.

        Args:
            token_gen: ``(batch, prompt_len)`` prompt token ids.
            pos_gen: ``(batch, prompt_len)`` position ids for the prompt.

        Returns:
            1-D tensor with the ``prompt_len - 1`` tokens sampled for the
            FIRST batch element only (still shifted by ``shift_num``).
        """
        prompt_len = token_gen.shape[1]
        predict_len = prompt_len - 1
        batch_size = pos_gen.shape[0]
        device = token_gen.device

        for step in tqdm(range(predict_len)):
            logits = self.lm(
                input_ids=token_gen,
                attention_mask=None,
                position_ids=pos_gen,
            )['logits']
            # Only the last position is sampled; ban the reserved ids there.
            next_logits = logits[:, -1, :]
            next_logits[:, :self.shift_num] = -1e5
            probs = next_logits.softmax(dim=-1)

            sampler = torch.distributions.categorical.Categorical(probs=probs)
            next_token = sampler.sample().unsqueeze(1)
            token_gen = torch.cat([token_gen, next_token], dim=1)

            # Generated tokens get positions 0, 1, 2, ... (restarting from 0).
            # NOTE(review): presumably mirrors the training layout - confirm.
            next_pos = torch.full(
                (batch_size, 1), step, dtype=torch.long, device=device
            )
            pos_gen = torch.cat([pos_gen, next_pos], dim=1)

        return token_gen[:, prompt_len:][0]

    def generate(self, mix):
        """Tokenize ``mix`` and sample a semantic-token continuation.

        Args:
            mix: Waveform tensor with a singleton dimension at index 1
                (it is squeezed away before feature extraction).

        Returns:
            Tuple ``(token_s, clean_s)``: the unshifted semantic tokens of
            the input and the unshifted sampled tokens (first batch element).
        """
        mix = mix.squeeze(1)
        device = mix.device
        num_frame = torch.tensor([mix.shape[1]], dtype=torch.long, device=device)
        token_s = self.extract_semantic(mix, num_frames=num_frame)
        batch_size, num_tokens = token_s.shape[0], token_s.shape[1]

        # Shift out of place so ``token_s`` can be returned untouched.
        prompt = token_s + self.shift_num
        bos = torch.full((batch_size, 1), self.bos, dtype=torch.long, device=device)
        token_gen = torch.cat([prompt, bos], dim=1)

        # Positions 0..num_tokens cover the prompt plus BOS, same for each row.
        pos_gen = torch.arange(num_tokens + 1, device=device)
        pos_gen = pos_gen.unsqueeze(0).repeat(batch_size, 1)

        clean_s = self.inference(token_gen, pos_gen) - self.shift_num
        return token_s, clean_s
98
+
99
+
100
class S2S(nn.Module):
    """GPT-2 language model that samples codec tokens and decodes them.

    ``generate`` builds the prompt ``[mix_s, clean_s, BOS, mix_a]`` from two
    semantic-token sequences and the codec tokens of the input waveform,
    samples new codec tokens and decodes them with ``SimCodec``.

    Token ids 0, 1 and 2 are reserved for pad, BOS and EOS. Semantic tokens
    are shifted by ``SEMANTIC_SHIFT``; codec tokens by ``shift_num`` (the
    semantic shift plus ``hps.model['semantic_num']``).
    """

    # Shift applied to semantic tokens so they follow the reserved ids 0..2.
    SEMANTIC_SHIFT = 3
    # Position id given to the first generated token.
    # NOTE(review): presumably matches the training layout - confirm.
    GEN_POS_OFFSET = 1000

    def __init__(self, hps):
        """Build the codec, the SSL tokenizer and the GPT-2 language model.

        Args:
            hps: Hyper-parameter object. Reads
                ``hps.path['codec_config_path']``, ``hps.ssl_model`` and the
                ``hps.model`` keys ``semantic_num``, ``s2s_vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and
                ``num_attention_heads``.
        """
        super().__init__()
        self.hps = hps
        self.codec_tokenizer = SimCodec(hps.path['codec_config_path'])
        # Loaded here but not used by any method of this class.
        self.xlsr, self.km = get_ssl_model(**hps.ssl_model)
        self.bos = 1
        self.eos = 2
        self.pad = 0
        self.shift_num = 3 + self.hps.model['semantic_num']
        self.lm_conf = GPT2Config(
            vocab_size=self.hps.model['s2s_vocab_size'],
            n_embd=self.hps.model['hidden_size'],
            n_layer=self.hps.model['num_hidden_layers'],
            n_head=self.hps.model['num_attention_heads'],
            activation_function='gelu_new',
            n_positions=4096,
            n_ctx=4096,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-05,
            initializer_range=0.02,
            summary_type='mean',
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            bos_token_id=self.bos,
            eos_token_id=self.eos,
        )
        self.lm = GPT2LMHeadModel(self.lm_conf)

    @torch.no_grad()
    def inference(self, token_gen, pos_gen):
        """Sample a continuation of ``token_gen`` one token at a time.

        Runs under ``torch.no_grad()``: the result is a tensor of sampled
        integer ids, so tracking gradients across the loop would only cost
        memory.

        Args:
            token_gen: ``(batch, prompt_len)`` prompt token ids.
            pos_gen: ``(batch, prompt_len)`` position ids for the prompt.

        Returns:
            1-D tensor with the ``(prompt_len - 1) // 2`` tokens sampled for
            the FIRST batch element only (still shifted by ``shift_num``).
        """
        prompt_len = token_gen.shape[1]
        # NOTE(review): half the prompt (without BOS) is generated; this
        # presumably equals the number of codec tokens - confirm.
        predict_len = (prompt_len - 1) // 2
        batch_size = pos_gen.shape[0]
        device = token_gen.device

        for step in tqdm(range(predict_len)):
            logits = self.lm(
                input_ids=token_gen,
                attention_mask=None,
                position_ids=pos_gen,
            )['logits']
            # Only the last position is sampled; ban every id below
            # ``shift_num`` (reserved and semantic ids) there.
            next_logits = logits[:, -1, :]
            next_logits[:, :self.shift_num] = -1e5
            probs = next_logits.softmax(dim=-1)
            sampler = torch.distributions.categorical.Categorical(probs=probs)
            next_token = sampler.sample().unsqueeze(1)
            token_gen = torch.cat([token_gen, next_token], dim=1)

            next_pos = torch.full(
                (batch_size, 1),
                step + self.GEN_POS_OFFSET,
                dtype=torch.long,
                device=device,
            )
            pos_gen = torch.cat([pos_gen, next_pos], dim=1)

        return token_gen[:, prompt_len:][0]

    def generate(self, mix, mix_s, clean_s):
        """Sample codec tokens for ``mix`` and decode them to a waveform.

        The inputs are never modified: all id shifts are done out of place,
        so the same ``mix_s`` / ``clean_s`` tensors can be reused safely.

        Args:
            mix: Waveform tensor passed to the codec tokenizer.
            mix_s: ``(batch, time)`` unshifted semantic tokens.
            clean_s: Unshifted semantic tokens, ``(batch, time)`` or 1-D
                (a 1-D tensor is treated as a batch of one).

        Returns:
            The decoded waveform on CPU with its leading dimension squeezed.
        """
        device = mix.device
        mix_a = self.codec_tokenizer(mix).squeeze(-1)
        if clean_s.dim() == 1:
            clean_s = clean_s.unsqueeze(0)

        # Out-of-place shifts: the original ``+=`` silently changed the
        # caller's tensors on every call.
        mix_s = mix_s + self.SEMANTIC_SHIFT
        clean_s = clean_s + self.SEMANTIC_SHIFT
        mix_a = mix_a + self.shift_num

        batch_size = mix_s.shape[0]
        bos = torch.full((batch_size, 1), self.bos, dtype=torch.long, device=device)
        token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)

        # Positions count up over [mix_s, clean_s, BOS], then restart at 0
        # for the codec tokens of the input.
        semantic_len = mix_s.shape[1] + clean_s.shape[1] + 1
        pos_gen = torch.cat([
            torch.arange(semantic_len, device=device),
            torch.arange(mix_a.shape[1], device=device),
        ])
        pos_gen = pos_gen.unsqueeze(0).repeat(batch_size, 1)

        pre_a = self.inference(token_gen, pos_gen) - self.shift_num
        gen_wav = self.codec_tokenizer.decode(
            pre_a.unsqueeze(0).unsqueeze(2)
        ).squeeze(0).cpu()

        return gen_wav
models/gense_wavlm.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from components.semantic_extractor.ssl_model import get_ssl_model
8
+ from components.simcodec.model import SimCodec
9
+ from transformers import GPT2Config, GPT2LMHeadModel
10
+
11
class N2S(nn.Module):
    """GPT-2 language model over discrete semantic tokens (WavLM variant).

    ``generate`` quantises a waveform into semantic tokens with the SSL
    encoder (``self.wavlm``) and the quantiser (``self.km``), appends a BOS
    token and samples as many new tokens as there are prompt semantic
    tokens. Presumably "N2S" means noisy-to-semantic (TODO confirm).

    Token ids 0, 1 and 2 are reserved for pad, BOS and EOS, so semantic
    tokens are shifted by ``shift_num`` before being fed to the LM.
    """

    def __init__(self, hps):
        """Build the SSL tokenizer and the GPT-2 language model.

        Args:
            hps: Hyper-parameter object. Reads ``hps.ssl_model`` (keyword
                arguments for ``get_ssl_model``) and the ``hps.model`` keys
                ``n2s_vocab_size``, ``hidden_size``, ``num_hidden_layers``
                and ``num_attention_heads``.
        """
        super().__init__()
        self.hps = hps
        self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
        # Reserved token ids; real tokens start at ``shift_num``.
        self.bos = 1
        self.eos = 2
        self.pad = 0
        self.shift_num = 3

        self.lm_conf = GPT2Config(
            vocab_size=self.hps.model['n2s_vocab_size'],
            n_embd=self.hps.model['hidden_size'],
            n_layer=self.hps.model['num_hidden_layers'],
            n_head=self.hps.model['num_attention_heads'],
            activation_function='gelu_new',
            n_positions=2048,
            n_ctx=2048,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-05,
            initializer_range=0.02,
            summary_type='mean',
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            bos_token_id=self.bos,
            eos_token_id=self.eos,
        )
        self.lm = GPT2LMHeadModel(self.lm_conf)

    def extract_semantic(self, wavs, num_frames):
        """Convert waveforms into discrete semantic tokens.

        Args:
            wavs: Waveform tensor; 100 zeros are appended along its last
                dimension before feature extraction.
            num_frames: Length of each waveform before padding. The padded
                length is computed out of place, so the caller's value is
                left unchanged.

        Returns:
            Whatever ``self.km`` returns for the flattened features
            (presumably a ``(batch, time)`` tensor of token ids - TODO
            confirm against ``get_ssl_model``).
        """
        pad_len = 100
        wavs = F.pad(wavs, (0, pad_len), "constant", 0)
        padded_frames = num_frames + pad_len
        # ``extract_features`` returns a sequence; its first item is the
        # feature tensor, unpacked below as (b, t, d).
        features = self.wavlm.extract_features(
            wavs,
            output_layer=6,
            ret_layer_results=False,
            input_length=padded_frames,
        )[0]
        b, t, d = features.shape
        return self.km(features.reshape(-1, d), b=b, t=t)

    @torch.no_grad()
    def inference(self, token_gen, pos_gen):
        """Sample a continuation of ``token_gen`` one token at a time.

        Runs under ``torch.no_grad()``: the result is a tensor of sampled
        integer ids, so tracking gradients across the loop would only cost
        memory.

        Args:
            token_gen: ``(batch, prompt_len)`` prompt token ids.
            pos_gen: ``(batch, prompt_len)`` position ids for the prompt.

        Returns:
            1-D tensor with the ``prompt_len - 1`` tokens sampled for the
            FIRST batch element only (still shifted by ``shift_num``).
        """
        prompt_len = token_gen.shape[1]
        predict_len = prompt_len - 1
        batch_size = pos_gen.shape[0]
        device = token_gen.device

        for step in tqdm(range(predict_len)):
            logits = self.lm(
                input_ids=token_gen,
                attention_mask=None,
                position_ids=pos_gen,
            )['logits']
            # Only the last position is sampled; ban the reserved ids there.
            next_logits = logits[:, -1, :]
            next_logits[:, :self.shift_num] = -1e5
            probs = next_logits.softmax(dim=-1)

            sampler = torch.distributions.categorical.Categorical(probs=probs)
            next_token = sampler.sample().unsqueeze(1)
            token_gen = torch.cat([token_gen, next_token], dim=1)

            # Generated tokens get positions 0, 1, 2, ... (restarting from 0).
            # NOTE(review): presumably mirrors the training layout - confirm.
            next_pos = torch.full(
                (batch_size, 1), step, dtype=torch.long, device=device
            )
            pos_gen = torch.cat([pos_gen, next_pos], dim=1)

        return token_gen[:, prompt_len:][0]

    def generate(self, mix):
        """Tokenize ``mix`` and sample a semantic-token continuation.

        Args:
            mix: Waveform tensor with a singleton dimension at index 1
                (it is squeezed away before feature extraction).

        Returns:
            Tuple ``(token_s, clean_s)``: the unshifted semantic tokens of
            the input and the unshifted sampled tokens (first batch element).
        """
        mix = mix.squeeze(1)
        device = mix.device
        num_frame = torch.tensor([mix.shape[1]], dtype=torch.long, device=device)
        token_s = self.extract_semantic(mix, num_frames=num_frame)
        batch_size, num_tokens = token_s.shape[0], token_s.shape[1]

        # Shift out of place so ``token_s`` can be returned untouched.
        prompt = token_s + self.shift_num
        bos = torch.full((batch_size, 1), self.bos, dtype=torch.long, device=device)
        token_gen = torch.cat([prompt, bos], dim=1)

        # Positions 0..num_tokens cover the prompt plus BOS, same for each row.
        pos_gen = torch.arange(num_tokens + 1, device=device)
        pos_gen = pos_gen.unsqueeze(0).repeat(batch_size, 1)

        clean_s = self.inference(token_gen, pos_gen) - self.shift_num
        return token_s, clean_s
100
+
101
+
102
class S2S(nn.Module):
    """GPT-2 language model that samples codec tokens and decodes them.

    ``generate`` builds the prompt ``[mix_s, clean_s, BOS, mix_a]`` from two
    semantic-token sequences and the codec tokens of the input waveform,
    samples new codec tokens and decodes them with ``SimCodec``.

    Token ids 0, 1 and 2 are reserved for pad, BOS and EOS. Semantic tokens
    are shifted by ``SEMANTIC_SHIFT``; codec tokens by ``shift_num`` (the
    semantic shift plus ``hps.model['semantic_num']``).
    """

    # Shift applied to semantic tokens so they follow the reserved ids 0..2.
    SEMANTIC_SHIFT = 3
    # Position id given to the first generated token.
    # NOTE(review): presumably matches the training layout - confirm.
    GEN_POS_OFFSET = 1000

    def __init__(self, hps):
        """Build the codec, the SSL tokenizer and the GPT-2 language model.

        Args:
            hps: Hyper-parameter object. Reads
                ``hps.path['codec_config_path']``, ``hps.ssl_model`` and the
                ``hps.model`` keys ``semantic_num``, ``s2s_vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and
                ``num_attention_heads``.
        """
        super().__init__()
        self.hps = hps
        self.codec_tokenizer = SimCodec(hps.path['codec_config_path'])
        # Loaded here but not used by any method of this class.
        self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
        self.bos = 1
        self.eos = 2
        self.pad = 0
        self.shift_num = 3 + self.hps.model['semantic_num']
        self.lm_conf = GPT2Config(
            vocab_size=self.hps.model['s2s_vocab_size'],
            n_embd=self.hps.model['hidden_size'],
            n_layer=self.hps.model['num_hidden_layers'],
            n_head=self.hps.model['num_attention_heads'],
            activation_function='gelu_new',
            n_positions=4096,
            n_ctx=4096,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-05,
            initializer_range=0.02,
            summary_type='mean',
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            bos_token_id=self.bos,
            eos_token_id=self.eos,
        )
        self.lm = GPT2LMHeadModel(self.lm_conf)

    @torch.no_grad()
    def inference(self, token_gen, pos_gen):
        """Sample a continuation of ``token_gen`` one token at a time.

        Runs under ``torch.no_grad()``: the result is a tensor of sampled
        integer ids, so tracking gradients across the loop would only cost
        memory.

        Args:
            token_gen: ``(batch, prompt_len)`` prompt token ids.
            pos_gen: ``(batch, prompt_len)`` position ids for the prompt.

        Returns:
            1-D tensor with the ``(prompt_len - 1) // 2`` tokens sampled for
            the FIRST batch element only (still shifted by ``shift_num``).
        """
        prompt_len = token_gen.shape[1]
        # NOTE(review): half the prompt (without BOS) is generated; this
        # presumably equals the number of codec tokens - confirm.
        predict_len = (prompt_len - 1) // 2
        batch_size = pos_gen.shape[0]
        device = token_gen.device

        for step in tqdm(range(predict_len)):
            logits = self.lm(
                input_ids=token_gen,
                attention_mask=None,
                position_ids=pos_gen,
            )['logits']
            # Only the last position is sampled; ban every id below
            # ``shift_num`` (reserved and semantic ids) there.
            next_logits = logits[:, -1, :]
            next_logits[:, :self.shift_num] = -1e5
            probs = next_logits.softmax(dim=-1)
            sampler = torch.distributions.categorical.Categorical(probs=probs)
            next_token = sampler.sample().unsqueeze(1)
            token_gen = torch.cat([token_gen, next_token], dim=1)

            next_pos = torch.full(
                (batch_size, 1),
                step + self.GEN_POS_OFFSET,
                dtype=torch.long,
                device=device,
            )
            pos_gen = torch.cat([pos_gen, next_pos], dim=1)

        return token_gen[:, prompt_len:][0]

    def generate(self, mix, mix_s, clean_s):
        """Sample codec tokens for ``mix`` and decode them to a waveform.

        The inputs are never modified: all id shifts are done out of place,
        so the same ``mix_s`` / ``clean_s`` tensors can be reused safely.

        Args:
            mix: Waveform tensor passed to the codec tokenizer.
            mix_s: ``(batch, time)`` unshifted semantic tokens.
            clean_s: Unshifted semantic tokens, ``(batch, time)`` or 1-D
                (a 1-D tensor is treated as a batch of one).

        Returns:
            The decoded waveform on CPU with its leading dimension squeezed.
        """
        device = mix.device
        mix_a = self.codec_tokenizer(mix).squeeze(-1)
        if clean_s.dim() == 1:
            clean_s = clean_s.unsqueeze(0)

        # Out-of-place shifts: the original ``+=`` silently changed the
        # caller's tensors on every call.
        mix_s = mix_s + self.SEMANTIC_SHIFT
        clean_s = clean_s + self.SEMANTIC_SHIFT
        mix_a = mix_a + self.shift_num

        batch_size = mix_s.shape[0]
        bos = torch.full((batch_size, 1), self.bos, dtype=torch.long, device=device)
        token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)

        # Positions count up over [mix_s, clean_s, BOS], then restart at 0
        # for the codec tokens of the input.
        semantic_len = mix_s.shape[1] + clean_s.shape[1] + 1
        pos_gen = torch.cat([
            torch.arange(semantic_len, device=device),
            torch.arange(mix_a.shape[1], device=device),
        ])
        pos_gen = pos_gen.unsqueeze(0).repeat(batch_size, 1)

        pre_a = self.inference(token_gen, pos_gen) - self.shift_num
        gen_wav = self.codec_tokenizer.decode(
            pre_a.unsqueeze(0).unsqueeze(2)
        ).squeeze(0).cpu()

        return gen_wav