yaoxunji committed on
Commit a726cc5 · verified · 1 Parent(s): 4f1a5d6

Upload 6 files
components/semantic_extractor/WavLM.py ADDED
@@ -0,0 +1,873 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import sys,os
15
+ sys.path.append(os.path.dirname(sys.path[0]))
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import LayerNorm
22
+ from .modules import (
23
+ Fp32GroupNorm,
24
+ Fp32LayerNorm,
25
+ GradMultiply,
26
+ MultiheadAttention,
27
+ SamePad,
28
+ init_bert_params,
29
+ get_activation_fn,
30
+ TransposeLast,
31
+ GLU_Linear,
32
+ )
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ def compute_mask_indices(
38
+ shape: Tuple[int, int],
39
+ padding_mask: Optional[torch.Tensor],
40
+ mask_prob: float,
41
+ mask_length: int,
42
+ mask_type: str = "static",
43
+ mask_other: float = 0.0,
44
+ min_masks: int = 0,
45
+ no_overlap: bool = False,
46
+ min_space: int = 0,
47
+ ) -> np.ndarray:
48
+ """
49
+ Computes random mask spans for a given shape
50
+
51
+ Args:
52
+ shape: the shape for which to compute masks.
53
+ should be of size 2 where first element is batch size and 2nd is timesteps
54
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
55
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
56
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
57
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
58
+ mask_type: how to compute mask lengths
59
+ static = fixed size
60
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
61
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
62
+ poisson = sample from poisson distribution with lambda = mask length
63
+ min_masks: minimum number of masked spans
64
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
65
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
66
+ """
67
+
68
+ bsz, all_sz = shape
69
+ mask = np.full((bsz, all_sz), False)
70
+
71
+ all_num_mask = int(
72
+ # add a random number for probabilistic rounding
73
+ mask_prob * all_sz / float(mask_length)
74
+ + np.random.rand()
75
+ )
76
+
77
+ all_num_mask = max(min_masks, all_num_mask)
78
+
79
+ mask_idcs = []
80
+ for i in range(bsz):
81
+ if padding_mask is not None:
82
+ sz = all_sz - padding_mask[i].long().sum().item()
83
+ num_mask = int(
84
+ # add a random number for probabilistic rounding
85
+ mask_prob * sz / float(mask_length)
86
+ + np.random.rand()
87
+ )
88
+ num_mask = max(min_masks, num_mask)
89
+ else:
90
+ sz = all_sz
91
+ num_mask = all_num_mask
92
+
93
+ if mask_type == "static":
94
+ lengths = np.full(num_mask, mask_length)
95
+ elif mask_type == "uniform":
96
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
97
+ elif mask_type == "normal":
98
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
99
+ lengths = [max(1, int(round(x))) for x in lengths]
100
+ elif mask_type == "poisson":
101
+ lengths = np.random.poisson(mask_length, size=num_mask)
102
+ lengths = [int(round(x)) for x in lengths]
103
+ else:
104
+ raise Exception("unknown mask selection " + mask_type)
105
+
106
+ if sum(lengths) == 0:
107
+ lengths[0] = min(mask_length, sz - 1)
108
+
109
+ if no_overlap:
110
+ mask_idc = []
111
+
112
+ def arrange(s, e, length, keep_length):
113
+ span_start = np.random.randint(s, e - length)
114
+ mask_idc.extend(span_start + i for i in range(length))
115
+
116
+ new_parts = []
117
+ if span_start - s - min_space >= keep_length:
118
+ new_parts.append((s, span_start - min_space + 1))
119
+ if e - span_start - keep_length - min_space > keep_length:
120
+ new_parts.append((span_start + length + min_space, e))
121
+ return new_parts
122
+
123
+ parts = [(0, sz)]
124
+ min_length = min(lengths)
125
+ for length in sorted(lengths, reverse=True):
126
+ lens = np.fromiter(
127
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
128
+ int,
129
+ )
130
+ l_sum = np.sum(lens)
131
+ if l_sum == 0:
132
+ break
133
+ probs = lens / np.sum(lens)
134
+ c = np.random.choice(len(parts), p=probs)
135
+ s, e = parts.pop(c)
136
+ parts.extend(arrange(s, e, length, min_length))
137
+ mask_idc = np.asarray(mask_idc)
138
+ else:
139
+ min_len = min(lengths)
140
+ if sz - min_len <= num_mask:
141
+ min_len = sz - num_mask - 1
142
+
143
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
144
+
145
+ mask_idc = np.asarray(
146
+ [
147
+ mask_idc[j] + offset
148
+ for j in range(len(mask_idc))
149
+ for offset in range(lengths[j])
150
+ ]
151
+ )
152
+
153
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
154
+
155
+ min_len = min([len(m) for m in mask_idcs])
156
+ for i, mask_idc in enumerate(mask_idcs):
157
+ if len(mask_idc) > min_len:
158
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
159
+ mask[i, mask_idc] = True
160
+
161
+ return mask
162
+
163
+
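+ # Illustrative sketch of the span masking documented above (numbers follow the
+ # defaults used further below, not a shipped recipe):
+ #
+ #   mask = compute_mask_indices((4, 500), None, mask_prob=0.65, mask_length=10)
+ #   mask.shape        # (4, 500) boolean numpy array; True marks masked frames
+ #   mask.sum(axis=1)  # identical per row, since rows are trimmed to the minimum count
+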
164
+ class WavLMConfig:
165
+ def __init__(self, cfg=None):
166
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
167
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
168
+
169
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
170
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
171
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
172
+ self.activation_fn: str = "gelu" # activation function to use
173
+
174
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
175
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
176
+ self.conv_bias: bool = False # include bias in conv encoder
177
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
178
+
179
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
180
+
181
+ # dropouts
182
+ self.dropout: float = 0.1 # dropout probability for the transformer
183
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
184
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
185
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
186
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
187
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
188
+
189
+ # masking
190
+ self.mask_length: int = 10 # mask length
191
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
192
+ self.mask_selection: str = "static" # how to choose mask length
193
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
194
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
195
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
196
+
197
+ # channel masking
198
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
199
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
200
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
201
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
202
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
203
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
204
+
205
+ # positional embeddings
206
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
207
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
208
+
209
+ # relative position embedding
210
+ self.relative_position_embedding: bool = False # apply relative position embedding
211
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
212
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
213
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
214
+
215
+ if cfg is not None:
216
+ self.update(cfg)
217
+
218
+ def update(self, cfg: dict):
219
+ self.__dict__.update(cfg)
220
+
221
+
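+ # Illustrative sketch: the defaults above describe the Base-size encoder, and
+ # `update` simply overwrites attributes from a dict, e.g. for a Large-style
+ # encoder (hypothetical values, not read from a checkpoint here):
+ #
+ #   cfg = WavLMConfig({"encoder_layers": 24, "encoder_embed_dim": 1024,
+ #                      "encoder_ffn_embed_dim": 4096, "encoder_attention_heads": 16})
+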
222
+ class WavLM(nn.Module):
223
+ def __init__(
224
+ self,
225
+ cfg: WavLMConfig,
226
+ ) -> None:
227
+ super().__init__()
228
+ logger.info(f"WavLM Config: {cfg.__dict__}")
229
+
230
+ self.cfg = cfg
231
+ feature_enc_layers = eval(cfg.conv_feature_layers)
232
+ self.embed = feature_enc_layers[-1][0]
233
+
234
+ self.feature_extractor = ConvFeatureExtractionModel(
235
+ conv_layers=feature_enc_layers,
236
+ dropout=0.0,
237
+ mode=cfg.extractor_mode,
238
+ conv_bias=cfg.conv_bias,
239
+ )
240
+
241
+ self.post_extract_proj = (
242
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
243
+ if self.embed != cfg.encoder_embed_dim
244
+ else None
245
+ )
246
+
247
+ self.mask_prob = cfg.mask_prob
248
+ self.mask_selection = cfg.mask_selection
249
+ self.mask_other = cfg.mask_other
250
+ self.mask_length = cfg.mask_length
251
+ self.no_mask_overlap = cfg.no_mask_overlap
252
+ self.mask_min_space = cfg.mask_min_space
253
+
254
+ self.mask_channel_prob = cfg.mask_channel_prob
255
+ self.mask_channel_selection = cfg.mask_channel_selection
256
+ self.mask_channel_other = cfg.mask_channel_other
257
+ self.mask_channel_length = cfg.mask_channel_length
258
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
259
+ self.mask_channel_min_space = cfg.mask_channel_min_space
260
+
261
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
262
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
263
+
264
+ self.feature_grad_mult = cfg.feature_grad_mult
265
+
266
+ self.mask_emb = nn.Parameter(
267
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
268
+ )
269
+
270
+ self.encoder = TransformerEncoder(cfg)
271
+ self.layer_norm = LayerNorm(self.embed)
272
+
273
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
274
+ """
275
+ Computes the output length of the convolutional layers
276
+ """
277
+
278
+ def _conv_out_length(input_length, kernel_size, stride):
279
+ return torch.floor((input_length - kernel_size) / stride + 1)
280
+
281
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
282
+
283
+ out_lengths_list = []
284
+ for i in range(len(conv_cfg_list)):
285
+ input_lengths = _conv_out_length(
286
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
287
+ )
288
+ out_lengths_list.append(input_lengths)
289
+
290
+ return input_lengths.to(torch.long), out_lengths_list
291
+
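+ # Note on the arithmetic above: with the default conv stack
+ # "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" the strides multiply to
+ # 5 * 2**6 = 320 samples per frame, i.e. a 20 ms hop (50 frames per second)
+ # at 16 kHz; the method applies floor((L - kernel) / stride + 1) layer by
+ # layer and also returns the per-layer lengths for masking in the extractor.
+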
292
+ def apply_mask(self, x, padding_mask):
293
+ B, T, C = x.shape
294
+ if self.mask_prob > 0:
295
+ mask_indices = compute_mask_indices(
296
+ (B, T),
297
+ padding_mask,
298
+ self.mask_prob,
299
+ self.mask_length,
300
+ self.mask_selection,
301
+ self.mask_other,
302
+ min_masks=2,
303
+ no_overlap=self.no_mask_overlap,
304
+ min_space=self.mask_min_space,
305
+ )
306
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
307
+ x[mask_indices] = self.mask_emb
308
+ else:
309
+ mask_indices = None
310
+
311
+ if self.mask_channel_prob > 0:
312
+ mask_channel_indices = compute_mask_indices(
313
+ (B, C),
314
+ None,
315
+ self.mask_channel_prob,
316
+ self.mask_channel_length,
317
+ self.mask_channel_selection,
318
+ self.mask_channel_other,
319
+ no_overlap=self.no_mask_channel_overlap,
320
+ min_space=self.mask_channel_min_space,
321
+ )
322
+ mask_channel_indices = (
323
+ torch.from_numpy(mask_channel_indices)
324
+ .to(x.device)
325
+ .unsqueeze(1)
326
+ .expand(-1, T, -1)
327
+ )
328
+ x[mask_channel_indices] = 0
329
+
330
+ return x, mask_indices
331
+
332
+ def forward_padding_mask(
333
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
334
+ ) -> torch.Tensor:
335
+ extra = padding_mask.size(1) % features.size(1)
336
+ if extra > 0:
337
+ padding_mask = padding_mask[:, :-extra]
338
+ padding_mask = padding_mask.view(
339
+ padding_mask.size(0), features.size(1), -1
340
+ )
341
+ padding_mask = padding_mask.all(-1)
342
+ return padding_mask
343
+
344
+ def sequence_mask(self, sequence_length, max_len=None):
345
+ """Create a sequence mask for filtering padding in a sequence tensor.
346
+ Args:
347
+ sequence_length (torch.tensor): Sequence lengths.
348
+ max_len (int, Optional): Maximum sequence length. Defaults to None.
349
+ Shapes:
350
+ - mask: :math:`[B, T_max]`
351
+ """
352
+ if max_len is None:
353
+ max_len = sequence_length.data.max()
354
+ seq_range = torch.arange(max_len,
355
+ dtype=sequence_length.dtype,
356
+ device=sequence_length.device)
357
+ # B x T_max
358
+ mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
359
+ return mask
360
+
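+ # Illustrative example: sequence_mask(torch.tensor([3, 5])) gives a [2, 5]
+ # boolean tensor
+ #   [[True, True, True, False, False],
+ #    [True, True, True, True,  True ]]
+ # where True marks valid frames; its inverse (~mask) is used as the
+ # transformer padding mask in extract_features below.
+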
361
+ def extract_features(
362
+ self,
363
+ source: torch.Tensor,
364
+ padding_mask: Optional[torch.Tensor] = None,
365
+ mask: bool = False,
366
+ ret_conv: bool = False,
367
+ output_layer: Optional[int] = None,
368
+ ret_layer_results: bool = False,
369
+ input_length: Optional[torch.Tensor] = None
370
+ ):
371
+ out_lengths_list = None
372
+ if input_length is not None:
373
+ out_conv_lengths, out_lengths_list = self._get_feat_extract_output_lengths(input_length)
374
+ else:
375
+ out_conv_lengths, out_lengths_list = self._get_feat_extract_output_lengths(torch.tensor([source.shape[-1] for _ in range(source.shape[0])]).to(source.device))
376
+
377
+ if self.feature_grad_mult > 0:
378
+ features = self.feature_extractor(source, input_lengths=input_length, out_lengths_list=out_lengths_list)
379
+ if self.feature_grad_mult != 1.0:
380
+ features = GradMultiply.apply(features, self.feature_grad_mult)
381
+ else:
382
+ with torch.no_grad():
383
+ features = self.feature_extractor(source)
384
+
385
+ features = features.transpose(1, 2)
386
+ features = self.layer_norm(features)
387
+
388
+ # if padding_mask is not None:
389
+ # padding_mask = self.forward_padding_mask(features, padding_mask)
390
+
391
+ if self.post_extract_proj is not None:
392
+ features *= self.sequence_mask(out_conv_lengths).unsqueeze(-1)
393
+ features = self.post_extract_proj(features)
394
+ features *= self.sequence_mask(out_conv_lengths).unsqueeze(-1)
395
+
396
+
397
+ features = self.dropout_input(features)
398
+ # return features
399
+
400
+ if mask:
401
+ x, mask_indices = self.apply_mask(
402
+ features, padding_mask
403
+ )
404
+ else:
405
+ x = features
406
+
407
+ # feature: (B, T, D), float
408
+ # target: (B, T), long
409
+ # x: (B, T, D), float
410
+ # padding_mask: (B, T), bool
411
+ # mask_indices: (B, T), bool
412
+ if source.shape[0] == 1:
413
+ padding_mask = None
414
+ else:
415
+ padding_mask = ~self.sequence_mask(out_conv_lengths)
416
+
417
+ x, layer_results = self.encoder(
418
+ x,
419
+ padding_mask=padding_mask,
420
+ layer=None if output_layer is None else output_layer - 1
421
+ )
422
+
423
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
424
+
425
+ feature = res["features"] if ret_conv else res["x"]
426
+ if ret_layer_results:
427
+ feature = (feature, res["layer_results"])
428
+ return feature, res["padding_mask"]
429
+
430
+
431
+ def long_term_modeling(
432
+ self,
433
+ source: torch.Tensor,
434
+ padding_mask: Optional[torch.Tensor] = None,
435
+ mask: bool = False,
436
+ ret_conv: bool = False,
437
+ output_layer: Optional[int] = None,
438
+ ret_layer_results: bool = False,
439
+ ):
440
+
441
+ features = source.transpose(1, 2)
442
+ features = self.layer_norm(features)
443
+
444
+ if padding_mask is not None:
445
+ padding_mask = self.forward_padding_mask(features, padding_mask)
446
+
447
+ if self.post_extract_proj is not None:
448
+ features = self.post_extract_proj(features)
449
+
450
+ features = self.dropout_input(features)
451
+
452
+ if mask:
453
+ x, mask_indices = self.apply_mask(
454
+ features, padding_mask
455
+ )
456
+ else:
457
+ x = features
458
+
459
+ # feature: (B, T, D), float
460
+ # target: (B, T), long
461
+ # x: (B, T, D), float
462
+ # padding_mask: (B, T), bool
463
+ # mask_indices: (B, T), bool
464
+ x, layer_results = self.encoder(
465
+ x,
466
+ padding_mask=padding_mask,
467
+ layer=None if output_layer is None else output_layer - 1
468
+ )
469
+
470
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
471
+
472
+ feature = res["features"] if ret_conv else res["x"]
473
+ if ret_layer_results:
474
+ feature = (feature, res["layer_results"])
475
+ return feature, res["padding_mask"]
476
+
477
+
478
+
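+ # End-to-end usage sketch (assumes a standard WavLM checkpoint dict with
+ # "cfg" and "model" keys, as in the upstream microsoft/unilm release; the
+ # file name below is a placeholder):
+ #
+ #   ckpt = torch.load("WavLM-Base+.pt", map_location="cpu")
+ #   cfg = WavLMConfig(ckpt["cfg"])
+ #   model = WavLM(cfg)
+ #   model.load_state_dict(ckpt["model"])
+ #   model.eval()
+ #   wav = torch.randn(1, 16000)                              # 1 s of 16 kHz audio
+ #   feats, _ = model.extract_features(wav, output_layer=6)   # (1, ~49, 768)
+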
479
+ class ConvFeatureExtractionModel(nn.Module):
480
+ def __init__(
481
+ self,
482
+ conv_layers: List[Tuple[int, int, int]],
483
+ dropout: float = 0.0,
484
+ mode: str = "default",
485
+ conv_bias: bool = False,
486
+ conv_type: str = "default"
487
+ ):
488
+ super().__init__()
489
+
490
+ assert mode in {"default", "layer_norm"}
491
+
492
+ def block(
493
+ n_in,
494
+ n_out,
495
+ k,
496
+ stride,
497
+ is_layer_norm=False,
498
+ is_group_norm=False,
499
+ conv_bias=False,
500
+ ):
501
+ def make_conv():
502
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
503
+ nn.init.kaiming_normal_(conv.weight)
504
+ return conv
505
+
506
+ assert not (
507
+ is_layer_norm and is_group_norm
508
+ ), "layer norm and group norm are exclusive"
509
+
510
+ if is_layer_norm:
511
+ return nn.Sequential(
512
+ make_conv(),
513
+ nn.Dropout(p=dropout),
514
+ nn.Sequential(
515
+ TransposeLast(),
516
+ Fp32LayerNorm(dim, elementwise_affine=True),
517
+ TransposeLast(),
518
+ ),
519
+ nn.GELU(),
520
+ )
521
+ elif is_group_norm:
522
+ return nn.Sequential(
523
+ make_conv(),
524
+ nn.Dropout(p=dropout),
525
+ Fp32GroupNorm(dim, dim, affine=True),
526
+ nn.GELU(),
527
+ )
528
+ else:
529
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
530
+
531
+ self.conv_type = conv_type
532
+ if self.conv_type == "default":
533
+ in_d = 1
534
+ self.conv_layers = nn.ModuleList()
535
+ for i, cl in enumerate(conv_layers):
536
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
537
+ (dim, k, stride) = cl
538
+
539
+ self.conv_layers.append(
540
+ block(
541
+ in_d,
542
+ dim,
543
+ k,
544
+ stride,
545
+ is_layer_norm=mode == "layer_norm",
546
+ is_group_norm=mode == "default" and i == 0,
547
+ conv_bias=conv_bias,
548
+ )
549
+ )
550
+ in_d = dim
551
+ elif self.conv_type == "conv2d":
552
+ in_d = 1
553
+ self.conv_layers = nn.ModuleList()
554
+ for i, cl in enumerate(conv_layers):
555
+ assert len(cl) == 3
556
+ (dim, k, stride) = cl
557
+
558
+ self.conv_layers.append(
559
+ torch.nn.Conv2d(in_d, dim, k, stride)
560
+ )
561
+ self.conv_layers.append(torch.nn.ReLU())
562
+ in_d = dim
563
+ elif self.conv_type == "custom":
564
+ in_d = 1
565
+ idim = 80
566
+ self.conv_layers = nn.ModuleList()
567
+ for i, cl in enumerate(conv_layers):
568
+ assert len(cl) == 3
569
+ (dim, k, stride) = cl
570
+ self.conv_layers.append(
571
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
572
+ )
573
+ self.conv_layers.append(
574
+ torch.nn.LayerNorm([dim, idim])
575
+ )
576
+ self.conv_layers.append(torch.nn.ReLU())
577
+ in_d = dim
578
+ if (i + 1) % 2 == 0:
579
+ self.conv_layers.append(
580
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
581
+ )
582
+ idim = int(math.ceil(idim / 2))
583
+ else:
584
+ pass
585
+
586
+ def sequence_mask(self, sequence_length, max_len=None):
587
+ """Create a sequence mask for filtering padding in a sequence tensor.
588
+ Args:
589
+ sequence_length (torch.tensor): Sequence lengths.
590
+ max_len (int, Optional): Maximum sequence length. Defaults to None.
591
+ Shapes:
592
+ - mask: :math:`[B, T_max]`
593
+ """
594
+ if max_len is None:
595
+ max_len = sequence_length.data.max()
596
+ seq_range = torch.arange(max_len,
597
+ dtype=sequence_length.dtype,
598
+ device=sequence_length.device)
599
+ # B x T_max
600
+ mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
601
+ return mask
602
+
603
+ def forward(self, x, mask=None, input_lengths=None, out_lengths_list=None):
604
+
605
+ # BxT -> BxCxT
606
+ x = x.unsqueeze(1)
607
+ # if self.conv_type == "custom":
608
+ # for conv in self.conv_layers:
609
+ # if isinstance(conv, nn.LayerNorm):
610
+ # x = x.transpose(1, 2)
611
+ # x = conv(x).transpose(1, 2)
612
+ # else:
613
+ # x = conv(x)
614
+ # x = x.transpose(2, 3).contiguous()
615
+ # x = x.view(x.size(0), -1, x.size(-1))
616
+ # else:
617
+
618
+ for idx, conv in enumerate(self.conv_layers):
619
+ x = conv(x)
620
+ # if idx == 0:
621
+ # x = conv(x * self.sequence_mask(input_lengths).unsqueeze(1))
622
+ # else:
623
+ # if len(out_lengths_list[idx-1]) == 1:
624
+ # x = conv(x * self.sequence_mask(out_lengths_list[idx-1]))
625
+ # else:
626
+ # x = conv(x * self.sequence_mask(out_lengths_list[idx-1]).unsqueeze(1))
627
+ # if len(out_lengths_list[idx-1]) == 1:
628
+ # x *= self.sequence_mask(out_lengths_list[idx].unsqueeze(0))
629
+ # else:
630
+ # x *= self.sequence_mask(out_lengths_list[idx].unsqueeze(1))
631
+ # if self.conv_type == "conv2d":
632
+ # b, c, t, f = x.size()
633
+ # x = x.transpose(2, 3).contiguous().view(b, c * f, t)
634
+ return x
635
+
636
+
637
+ class TransformerEncoder(nn.Module):
638
+ def __init__(self, args):
639
+ super().__init__()
640
+
641
+ self.dropout = args.dropout
642
+ self.embedding_dim = args.encoder_embed_dim
643
+
644
+ self.pos_conv = nn.Conv1d(
645
+ self.embedding_dim,
646
+ self.embedding_dim,
647
+ kernel_size=args.conv_pos,
648
+ padding=args.conv_pos // 2,
649
+ groups=args.conv_pos_groups,
650
+ )
651
+ dropout = 0
652
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
653
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
654
+ nn.init.constant_(self.pos_conv.bias, 0)
655
+
656
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
657
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
658
+
659
+ if hasattr(args, "relative_position_embedding"):
660
+ self.relative_position_embedding = args.relative_position_embedding
661
+ self.num_buckets = args.num_buckets
662
+ self.max_distance = args.max_distance
663
+ else:
664
+ self.relative_position_embedding = False
665
+ self.num_buckets = 0
666
+ self.max_distance = 0
667
+
668
+ self.layers = nn.ModuleList(
669
+ [
670
+ TransformerSentenceEncoderLayer(
671
+ embedding_dim=self.embedding_dim,
672
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
673
+ num_attention_heads=args.encoder_attention_heads,
674
+ dropout=self.dropout,
675
+ attention_dropout=args.attention_dropout,
676
+ activation_dropout=args.activation_dropout,
677
+ activation_fn=args.activation_fn,
678
+ layer_norm_first=args.layer_norm_first,
679
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
680
+ num_buckets=self.num_buckets,
681
+ max_distance=self.max_distance,
682
+ gru_rel_pos=args.gru_rel_pos,
683
+ )
684
+ for i in range(args.encoder_layers)
685
+ ]
686
+ )
687
+
688
+ self.layer_norm_first = args.layer_norm_first
689
+ self.layer_norm = LayerNorm(self.embedding_dim)
690
+ self.layerdrop = args.encoder_layerdrop
691
+
692
+ self.apply(init_bert_params)
693
+
694
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
695
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
696
+
697
+ if self.layer_norm_first and layer is None:
698
+ x = self.layer_norm(x)
699
+
700
+ return x, layer_results
701
+
702
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
703
+
704
+ if padding_mask is not None:
705
+ x[padding_mask] = 0
706
+
707
+ y = x.transpose(1, 2).clone()
708
+ x_conv = self.pos_conv(y)
709
+ x_conv = x_conv.transpose(1, 2)
710
+ x += x_conv
711
+
712
+ if not self.layer_norm_first:
713
+ x = self.layer_norm(x)
714
+
715
+ x = F.dropout(x, p=self.dropout, training=self.training)
716
+
717
+ # B x T x C -> T x B x C
718
+ x = x.transpose(0, 1)
719
+
720
+ layer_results = []
721
+ z = None
722
+ if tgt_layer is not None:
723
+ layer_results.append((x, z))
724
+ r = None
725
+ pos_bias = None
726
+ for i, layer in enumerate(self.layers):
727
+ dropout_probability = np.random.random()
728
+ if not self.training or (dropout_probability > self.layerdrop):
729
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
730
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
731
+ if tgt_layer is not None:
732
+ layer_results.append((x, z))
733
+ if i == tgt_layer:
734
+ r = x
735
+ break
736
+
737
+ if r is not None:
738
+ x = r
739
+
740
+ # T x B x C -> B x T x C
741
+ x = x.transpose(0, 1)
742
+
743
+ return x, layer_results
744
+
745
+
746
+ class TransformerSentenceEncoderLayer(nn.Module):
747
+ """
748
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
749
+ models.
750
+ """
751
+
752
+ def __init__(
753
+ self,
754
+ embedding_dim: float = 768,
755
+ ffn_embedding_dim: float = 3072,
756
+ num_attention_heads: float = 8,
757
+ dropout: float = 0.1,
758
+ attention_dropout: float = 0.1,
759
+ activation_dropout: float = 0.1,
760
+ activation_fn: str = "relu",
761
+ layer_norm_first: bool = False,
762
+ has_relative_attention_bias: bool = False,
763
+ num_buckets: int = 0,
764
+ max_distance: int = 0,
765
+ rescale_init: bool = False,
766
+ gru_rel_pos: bool = False,
767
+ ) -> None:
768
+
769
+ super().__init__()
770
+ # Initialize parameters
771
+ self.embedding_dim = embedding_dim
772
+ self.dropout = dropout
773
+ self.activation_dropout = activation_dropout
774
+
775
+ # Initialize blocks
776
+ self.activation_name = activation_fn
777
+ self.activation_fn = get_activation_fn(activation_fn)
778
+ self.self_attn = MultiheadAttention(
779
+ self.embedding_dim,
780
+ num_attention_heads,
781
+ dropout=attention_dropout,
782
+ self_attention=True,
783
+ has_relative_attention_bias=has_relative_attention_bias,
784
+ num_buckets=num_buckets,
785
+ max_distance=max_distance,
786
+ rescale_init=rescale_init,
787
+ gru_rel_pos=gru_rel_pos,
788
+ )
789
+
790
+ self.dropout1 = nn.Dropout(dropout)
791
+ self.dropout2 = nn.Dropout(self.activation_dropout)
792
+ self.dropout3 = nn.Dropout(dropout)
793
+
794
+ self.layer_norm_first = layer_norm_first
795
+
796
+ # layer norm associated with the self attention layer
797
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
798
+
799
+ if self.activation_name == "glu":
800
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
801
+ else:
802
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
803
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
804
+
805
+ # layer norm associated with the position wise feed-forward NN
806
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
807
+
808
+ def forward(
809
+ self,
810
+ x: torch.Tensor,
811
+ self_attn_mask: torch.Tensor = None,
812
+ self_attn_padding_mask: torch.Tensor = None,
813
+ need_weights: bool = False,
814
+ pos_bias=None
815
+ ):
816
+ """
817
+ LayerNorm is applied either before or after the self-attention/ffn
818
+ modules, similar to the original Transformer implementation.
819
+ """
820
+ residual = x
821
+
822
+ if self.layer_norm_first:
823
+ x = self.self_attn_layer_norm(x)
824
+ x, attn, pos_bias = self.self_attn(
825
+ query=x,
826
+ key=x,
827
+ value=x,
828
+ key_padding_mask=self_attn_padding_mask,
829
+ need_weights=False,
830
+ attn_mask=self_attn_mask,
831
+ position_bias=pos_bias
832
+ )
833
+ x = self.dropout1(x)
834
+ x = residual + x
835
+
836
+ residual = x
837
+ x = self.final_layer_norm(x)
838
+ if self.activation_name == "glu":
839
+ x = self.fc1(x)
840
+ else:
841
+ x = self.activation_fn(self.fc1(x))
842
+ x = self.dropout2(x)
843
+ x = self.fc2(x)
844
+ x = self.dropout3(x)
845
+ x = residual + x
846
+ else:
847
+ x, attn, pos_bias = self.self_attn(
848
+ query=x,
849
+ key=x,
850
+ value=x,
851
+ key_padding_mask=self_attn_padding_mask,
852
+ need_weights=need_weights,
853
+ attn_mask=self_attn_mask,
854
+ position_bias=pos_bias
855
+ )
856
+
857
+ x = self.dropout1(x)
858
+ x = residual + x
859
+
860
+ x = self.self_attn_layer_norm(x)
861
+
862
+ residual = x
863
+ if self.activation_name == "glu":
864
+ x = self.fc1(x)
865
+ else:
866
+ x = self.activation_fn(self.fc1(x))
867
+ x = self.dropout2(x)
868
+ x = self.fc2(x)
869
+ x = self.dropout3(x)
870
+ x = residual + x
871
+ x = self.final_layer_norm(x)
872
+
873
+ return x, attn, pos_bias
components/semantic_extractor/modules.py ADDED
@@ -0,0 +1,825 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+ class TransposeLast(nn.Module):
19
+ def __init__(self, deconstruct_idx=None):
20
+ super().__init__()
21
+ self.deconstruct_idx = deconstruct_idx
22
+
23
+ def forward(self, x):
24
+ if self.deconstruct_idx is not None:
25
+ x = x[self.deconstruct_idx]
26
+ return x.transpose(-2, -1)
27
+
28
+
29
+ class Fp32LayerNorm(nn.LayerNorm):
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def forward(self, input):
34
+ output = F.layer_norm(
35
+ input.float(),
36
+ self.normalized_shape,
37
+ self.weight.float() if self.weight is not None else None,
38
+ self.bias.float() if self.bias is not None else None,
39
+ self.eps,
40
+ )
41
+ return output.type_as(input)
42
+
43
+
44
+ class Fp32GroupNorm(nn.GroupNorm):
45
+ def __init__(self, *args, **kwargs):
46
+ super().__init__(*args, **kwargs)
47
+
48
+ def forward(self, input):
49
+ output = F.group_norm(
50
+ input.float(),
51
+ self.num_groups,
52
+ self.weight.float() if self.weight is not None else None,
53
+ self.bias.float() if self.bias is not None else None,
54
+ self.eps,
55
+ )
56
+ return output.type_as(input)
57
+
58
+
59
+ class GradMultiply(torch.autograd.Function):
60
+ @staticmethod
61
+ def forward(ctx, x, scale):
62
+ ctx.scale = scale
63
+ res = x.new(x)
64
+ return res
65
+
66
+ @staticmethod
67
+ def backward(ctx, grad):
68
+ return grad * ctx.scale, None
69
+
70
+
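+ # GradMultiply is an identity in the forward pass and scales the gradient by
+ # `scale` in the backward pass; WavLM.py uses it to damp the feature
+ # extractor's gradients, e.g. features = GradMultiply.apply(features, 0.1)
+ # when cfg.feature_grad_mult = 0.1.
+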
71
+ class SamePad(nn.Module):
72
+ def __init__(self, kernel_size, causal=False):
73
+ super().__init__()
74
+ if causal:
75
+ self.remove = kernel_size - 1
76
+ else:
77
+ self.remove = 1 if kernel_size % 2 == 0 else 0
78
+
79
+ def forward(self, x):
80
+ if self.remove > 0:
81
+ x = x[:, :, : -self.remove]
82
+ return x
83
+
84
+
85
+ class Swish(nn.Module):
86
+ """Swish function
87
+ """
88
+
89
+ def __init__(self):
90
+ """Construct an MultiHeadedAttention object."""
91
+ super(Swish, self).__init__()
92
+ self.act = torch.nn.Sigmoid()
93
+
94
+ def forward(self, x):
95
+ return x * self.act(x)
96
+
97
+
98
+ class GLU_Linear(nn.Module):
99
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
100
+ super(GLU_Linear, self).__init__()
101
+
102
+ self.glu_type = glu_type
103
+ self.output_dim = output_dim
104
+
105
+ if glu_type == "sigmoid":
106
+ self.glu_act = torch.nn.Sigmoid()
107
+ elif glu_type == "swish":
108
+ self.glu_act = Swish()
109
+ elif glu_type == "relu":
110
+ self.glu_act = torch.nn.ReLU()
111
+ elif glu_type == "gelu":
112
+ self.glu_act = torch.nn.GELU()
113
+
114
+ if bias_in_glu:
115
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
116
+ else:
117
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
118
+
119
+ def forward(self, x):
120
+ # the input is assumed to have the channel (feature) dimension last; for the 1D-conv case the caller needs to transpose to that layout first
121
+ x = self.linear(x)
122
+
123
+ if self.glu_type == "bilinear":
124
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
125
+ else:
126
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
127
+
128
+ return x
129
+
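+ # GLU_Linear packs both halves of the gated linear unit into a single Linear
+ # with 2 * output_dim outputs and computes out = x_a * act(x_b) after the
+ # split; e.g. GLU_Linear(768, 3072, "swish") is what the encoder layer builds
+ # when activation_fn == "glu".
+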
130
+ def gelu_accurate(x):
131
+ if not hasattr(gelu_accurate, "_a"):
132
+ gelu_accurate._a = math.sqrt(2 / math.pi)
133
+ return (
134
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
135
+ )
136
+
137
+
138
+ def gelu(x: torch.Tensor) -> torch.Tensor:
139
+ return torch.nn.functional.gelu(x.float()).type_as(x)
140
+
141
+
142
+ def get_activation_fn(activation: str):
143
+ """Returns the activation function corresponding to `activation`"""
144
+
145
+ if activation == "relu":
146
+ return F.relu
147
+ elif activation == "gelu":
148
+ return gelu
149
+ elif activation == "gelu_fast":
150
+ warnings.warn(
151
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
152
+ )
153
+ return gelu_accurate
154
+ elif activation == "gelu_accurate":
155
+ return gelu_accurate
156
+ elif activation == "tanh":
157
+ return torch.tanh
158
+ elif activation == "linear":
159
+ return lambda x: x
160
+ elif activation == "glu":
161
+ return lambda x: x
162
+ else:
163
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
164
+
165
+
166
+ def init_bert_params(module):
167
+ """
168
+ Initialize the weights specific to the BERT Model.
169
+ This overrides the default initializations depending on the specified arguments.
170
+ 1. If normal_init_linear_weights is set then weights of linear
171
+ layer will be initialized using the normal distribution and
172
+ bias will be set to the specified value.
173
+ 2. If normal_init_embed_weights is set then weights of embedding
174
+ layer will be initialized using the normal distribution.
175
+ 3. If normal_init_proj_weights is set then weights of
176
+ in_project_weight for MultiHeadAttention initialized using
177
+ the normal distribution (to be validated).
178
+ """
179
+
180
+ def normal_(data):
181
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
182
+ # so that the RNG is consistent with and without FSDP
183
+ data.copy_(
184
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
185
+ )
186
+
187
+ if isinstance(module, nn.Linear):
188
+ normal_(module.weight.data)
189
+ if module.bias is not None:
190
+ module.bias.data.zero_()
191
+ if isinstance(module, nn.Embedding):
192
+ normal_(module.weight.data)
193
+ if module.padding_idx is not None:
194
+ module.weight.data[module.padding_idx].zero_()
195
+ if isinstance(module, MultiheadAttention):
196
+ normal_(module.q_proj.weight.data)
197
+ normal_(module.k_proj.weight.data)
198
+ normal_(module.v_proj.weight.data)
199
+
200
+
201
+ def quant_noise(module, p, block_size):
202
+ """
203
+ Wraps modules and applies quantization noise to the weights for
204
+ subsequent quantization with Iterative Product Quantization as
205
+ described in "Training with Quantization Noise for Extreme Model Compression"
206
+
207
+ Args:
208
+ - module: nn.Module
209
+ - p: amount of Quantization Noise
210
+ - block_size: size of the blocks for subsequent quantization with iPQ
211
+
212
+ Remarks:
213
+ - Module weights must have the right sizes wrt the block size
214
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
215
+ - For more detail on how to quantize by blocks with convolutional weights,
216
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
217
+ - We implement the simplest form of noise here as stated in the paper
218
+ which consists in randomly dropping blocks
219
+ """
220
+
221
+ # if no quantization noise, don't register hook
222
+ if p <= 0:
223
+ return module
224
+
225
+ # supported modules
226
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
227
+
228
+ # test whether module.weight has the right sizes wrt block_size
229
+ is_conv = module.weight.ndim == 4
230
+
231
+ # 2D matrix
232
+ if not is_conv:
233
+ assert (
234
+ module.weight.size(1) % block_size == 0
235
+ ), "Input features must be a multiple of block sizes"
236
+
237
+ # 4D matrix
238
+ else:
239
+ # 1x1 convolutions
240
+ if module.kernel_size == (1, 1):
241
+ assert (
242
+ module.in_channels % block_size == 0
243
+ ), "Input channels must be a multiple of block sizes"
244
+ # regular convolutions
245
+ else:
246
+ k = module.kernel_size[0] * module.kernel_size[1]
247
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
248
+
249
+ def _forward_pre_hook(mod, input):
250
+ # no noise for evaluation
251
+ if mod.training:
252
+ if not is_conv:
253
+ # gather weight and sizes
254
+ weight = mod.weight
255
+ in_features = weight.size(1)
256
+ out_features = weight.size(0)
257
+
258
+ # split weight matrix into blocks and randomly drop selected blocks
259
+ mask = torch.zeros(
260
+ in_features // block_size * out_features, device=weight.device
261
+ )
262
+ mask.bernoulli_(p)
263
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
264
+
265
+ else:
266
+ # gather weight and sizes
267
+ weight = mod.weight
268
+ in_channels = mod.in_channels
269
+ out_channels = mod.out_channels
270
+
271
+ # split weight matrix into blocks and randomly drop selected blocks
272
+ if mod.kernel_size == (1, 1):
273
+ mask = torch.zeros(
274
+ int(in_channels // block_size * out_channels),
275
+ device=weight.device,
276
+ )
277
+ mask.bernoulli_(p)
278
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
279
+ else:
280
+ mask = torch.zeros(
281
+ weight.size(0), weight.size(1), device=weight.device
282
+ )
283
+ mask.bernoulli_(p)
284
+ mask = (
285
+ mask.unsqueeze(2)
286
+ .unsqueeze(3)
287
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
288
+ )
289
+
290
+ # scale weights and apply mask
291
+ mask = mask.to(
292
+ torch.bool
293
+ ) # x.bool() is not currently supported in TorchScript
294
+ s = 1 / (1 - p)
295
+ mod.weight.data = s * weight.masked_fill(mask, 0)
296
+
297
+ module.register_forward_pre_hook(_forward_pre_hook)
298
+ return module
299
+
300
+
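+ # Usage sketch (illustrative): wrap a projection so that, during training
+ # only, random weight blocks are zeroed with probability p and the remainder
+ # rescaled by 1 / (1 - p):
+ #
+ #   proj = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)
+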
301
+ class MultiheadAttention(nn.Module):
302
+ """Multi-headed attention.
303
+
304
+ See "Attention Is All You Need" for more details.
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ embed_dim,
310
+ num_heads,
311
+ kdim=None,
312
+ vdim=None,
313
+ dropout=0.0,
314
+ bias=True,
315
+ add_bias_kv=False,
316
+ add_zero_attn=False,
317
+ self_attention=False,
318
+ encoder_decoder_attention=False,
319
+ q_noise=0.0,
320
+ qn_block_size=8,
321
+ has_relative_attention_bias=False,
322
+ num_buckets=32,
323
+ max_distance=128,
324
+ gru_rel_pos=False,
325
+ rescale_init=False,
326
+ ):
327
+ super().__init__()
328
+ self.embed_dim = embed_dim
329
+ self.kdim = kdim if kdim is not None else embed_dim
330
+ self.vdim = vdim if vdim is not None else embed_dim
331
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
332
+
333
+ self.num_heads = num_heads
334
+ self.dropout_module = nn.Dropout(dropout)
335
+
336
+ self.has_relative_attention_bias = has_relative_attention_bias
337
+ self.num_buckets = num_buckets
338
+ self.max_distance = max_distance
339
+ if self.has_relative_attention_bias:
340
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
341
+
342
+ self.head_dim = embed_dim // num_heads
343
+ self.q_head_dim = self.head_dim
344
+ self.k_head_dim = self.head_dim
345
+ assert (
346
+ self.head_dim * num_heads == self.embed_dim
347
+ ), "embed_dim must be divisible by num_heads"
348
+ self.scaling = self.head_dim ** -0.5
349
+
350
+ self.self_attention = self_attention
351
+ self.encoder_decoder_attention = encoder_decoder_attention
352
+
353
+ assert not self.self_attention or self.qkv_same_dim, (
354
+ "Self-attention requires query, key and " "value to be of the same size"
355
+ )
356
+
357
+ k_bias = True
358
+ if rescale_init:
359
+ k_bias = False
360
+
361
+ k_embed_dim = embed_dim
362
+ q_embed_dim = embed_dim
363
+
364
+ self.k_proj = quant_noise(
365
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
366
+ )
367
+ self.v_proj = quant_noise(
368
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
369
+ )
370
+ self.q_proj = quant_noise(
371
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
372
+ )
373
+
374
+ self.out_proj = quant_noise(
375
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
376
+ )
377
+
378
+ if add_bias_kv:
379
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
380
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
381
+ else:
382
+ self.bias_k = self.bias_v = None
383
+
384
+ self.add_zero_attn = add_zero_attn
385
+
386
+ self.gru_rel_pos = gru_rel_pos
387
+ if self.gru_rel_pos:
388
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
389
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
390
+
391
+ self.reset_parameters()
392
+
393
+ def reset_parameters(self):
394
+ if self.qkv_same_dim:
395
+ # Empirically observed the convergence to be much better with
396
+ # the scaled initialization
397
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
398
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
399
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
400
+ else:
401
+ nn.init.xavier_uniform_(self.k_proj.weight)
402
+ nn.init.xavier_uniform_(self.v_proj.weight)
403
+ nn.init.xavier_uniform_(self.q_proj.weight)
404
+
405
+ nn.init.xavier_uniform_(self.out_proj.weight)
406
+ if self.out_proj.bias is not None:
407
+ nn.init.constant_(self.out_proj.bias, 0.0)
408
+ if self.bias_k is not None:
409
+ nn.init.xavier_normal_(self.bias_k)
410
+ if self.bias_v is not None:
411
+ nn.init.xavier_normal_(self.bias_v)
412
+ if self.has_relative_attention_bias:
413
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
414
+
415
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
416
+ num_buckets = self.num_buckets
417
+ max_distance = self.max_distance
418
+ relative_buckets = 0
419
+
420
+ if bidirectional:
421
+ num_buckets = num_buckets // 2
422
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
423
+ relative_positions = torch.abs(relative_positions)
424
+ else:
425
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
426
+
427
+ max_exact = num_buckets // 2
428
+ is_small = relative_positions < max_exact
429
+
430
+ relative_position_if_large = max_exact + (
431
+ torch.log(relative_positions.float() / max_exact)
432
+ / math.log(max_distance / max_exact)
433
+ * (num_buckets - max_exact)
434
+ ).to(torch.long)
435
+ relative_position_if_large = torch.min(
436
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
437
+ )
438
+
439
+ relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
440
+ return relative_buckets
441
+
442
+ def compute_bias(self, query_length, key_length):
443
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
444
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
445
+ relative_position = memory_position - context_position
446
+ relative_position_bucket = self._relative_positions_bucket(
447
+ relative_position,
448
+ bidirectional=True
449
+ )
450
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
451
+ values = self.relative_attention_bias(relative_position_bucket)
452
+ values = values.permute([2, 0, 1])
453
+ return values
454
+
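+ # Note: compute_bias returns a (num_heads, query_length, key_length) tensor of
+ # learned relative-position biases; forward() below reshapes it to
+ # (bsz * num_heads, tgt_len, src_len) and adds it to the attention logits,
+ # optionally gated when gru_rel_pos is enabled.
+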
455
+ def forward(
456
+ self,
457
+ query,
458
+ key: Optional[Tensor],
459
+ value: Optional[Tensor],
460
+ key_padding_mask: Optional[Tensor] = None,
461
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
462
+ need_weights: bool = True,
463
+ static_kv: bool = False,
464
+ attn_mask: Optional[Tensor] = None,
465
+ before_softmax: bool = False,
466
+ need_head_weights: bool = False,
467
+ position_bias: Optional[Tensor] = None
468
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
469
+ """Input shape: Time x Batch x Channel
470
+
471
+ Args:
472
+ key_padding_mask (ByteTensor, optional): mask to exclude
473
+ keys that are pads, of shape `(batch, src_len)`, where
474
+ padding elements are indicated by 1s.
475
+ need_weights (bool, optional): return the attention weights,
476
+ averaged over heads (default: False).
477
+ attn_mask (ByteTensor, optional): typically used to
478
+ implement causal attention, where the mask prevents the
479
+ attention from looking forward in time (default: None).
480
+ before_softmax (bool, optional): return the raw attention
481
+ weights and values before the attention softmax.
482
+ need_head_weights (bool, optional): return the attention
483
+ weights for each head. Implies *need_weights*. Default:
484
+ return the average attention weights over all heads.
485
+ """
486
+ if need_head_weights:
487
+ need_weights = True
488
+
489
+ is_tpu = query.device.type == "xla"
490
+
491
+ tgt_len, bsz, embed_dim = query.size()
492
+ src_len = tgt_len
493
+ assert embed_dim == self.embed_dim
494
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
495
+ if key is not None:
496
+ src_len, key_bsz, _ = key.size()
497
+ if not torch.jit.is_scripting():
498
+ assert key_bsz == bsz
499
+ assert value is not None
500
+ assert (src_len, bsz) == value.shape[:2]
501
+
502
+ if self.has_relative_attention_bias and position_bias is None:
503
+ position_bias = self.compute_bias(tgt_len, src_len)
504
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
505
+
506
+ if (
507
+ not is_tpu # don't use PyTorch version on TPUs
508
+ and incremental_state is None
509
+ and not static_kv
510
+ # A workaround for quantization to work. Otherwise JIT compilation
511
+ # treats bias in linear module as method.
512
+ and not torch.jit.is_scripting()
513
+ and self.q_head_dim == self.head_dim
514
+ ):
515
+ assert key is not None and value is not None
516
+ assert attn_mask is None
517
+
518
+ attn_mask_rel_pos = None
519
+ if position_bias is not None:
520
+ attn_mask_rel_pos = position_bias
521
+ if self.gru_rel_pos:
522
+ query_layer = query.transpose(0, 1)
523
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
524
+ query_layer = query_layer.view(*new_x_shape)
525
+ query_layer = query_layer.permute(0, 2, 1, 3)
526
+ _B, _H, _L, __ = query_layer.size()
527
+
528
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
529
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
530
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
531
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
532
+
533
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
534
+ k_proj_bias = self.k_proj.bias
535
+ if k_proj_bias is None:
536
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
537
+
538
+ x, attn = F.multi_head_attention_forward(
539
+ query,
540
+ key,
541
+ value,
542
+ self.embed_dim,
543
+ self.num_heads,
544
+ torch.empty([0]),
545
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
546
+ self.bias_k,
547
+ self.bias_v,
548
+ self.add_zero_attn,
549
+ self.dropout_module.p,
550
+ self.out_proj.weight,
551
+ self.out_proj.bias,
552
+ self.training,
553
+ # self.training or self.dropout_module.apply_during_inference,
554
+ key_padding_mask,
555
+ need_weights,
556
+ attn_mask_rel_pos,
557
+ use_separate_proj_weight=True,
558
+ q_proj_weight=self.q_proj.weight,
559
+ k_proj_weight=self.k_proj.weight,
560
+ v_proj_weight=self.v_proj.weight,
561
+ )
562
+ return x, attn, position_bias
563
+
564
+ if incremental_state is not None:
565
+ saved_state = self._get_input_buffer(incremental_state)
566
+ if saved_state is not None and "prev_key" in saved_state:
567
+ # previous time steps are cached - no need to recompute
568
+ # key and value if they are static
569
+ if static_kv:
570
+ assert self.encoder_decoder_attention and not self.self_attention
571
+ key = value = None
572
+ else:
573
+ saved_state = None
574
+
575
+ if self.self_attention:
576
+ q = self.q_proj(query)
577
+ k = self.k_proj(query)
578
+ v = self.v_proj(query)
579
+ elif self.encoder_decoder_attention:
580
+ # encoder-decoder attention
581
+ q = self.q_proj(query)
582
+ if key is None:
583
+ assert value is None
584
+ k = v = None
585
+ else:
586
+ k = self.k_proj(key)
587
+ v = self.v_proj(key)
588
+
589
+ else:
590
+ assert key is not None and value is not None
591
+ q = self.q_proj(query)
592
+ k = self.k_proj(key)
593
+ v = self.v_proj(value)
594
+ q *= self.scaling
595
+
596
+ if self.bias_k is not None:
597
+ assert self.bias_v is not None
598
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
599
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
600
+ if attn_mask is not None:
601
+ attn_mask = torch.cat(
602
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
603
+ )
604
+ if key_padding_mask is not None:
605
+ key_padding_mask = torch.cat(
606
+ [
607
+ key_padding_mask,
608
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
609
+ ],
610
+ dim=1,
611
+ )
612
+
613
+ q = (
614
+ q.contiguous()
615
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
616
+ .transpose(0, 1)
617
+ )
618
+ if k is not None:
619
+ k = (
620
+ k.contiguous()
621
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
622
+ .transpose(0, 1)
623
+ )
624
+ if v is not None:
625
+ v = (
626
+ v.contiguous()
627
+ .view(-1, bsz * self.num_heads, self.head_dim)
628
+ .transpose(0, 1)
629
+ )
630
+
631
+ if saved_state is not None:
632
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
633
+ if "prev_key" in saved_state:
634
+ _prev_key = saved_state["prev_key"]
635
+ assert _prev_key is not None
636
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
637
+ if static_kv:
638
+ k = prev_key
639
+ else:
640
+ assert k is not None
641
+ k = torch.cat([prev_key, k], dim=1)
642
+ src_len = k.size(1)
643
+ if "prev_value" in saved_state:
644
+ _prev_value = saved_state["prev_value"]
645
+ assert _prev_value is not None
646
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
647
+ if static_kv:
648
+ v = prev_value
649
+ else:
650
+ assert v is not None
651
+ v = torch.cat([prev_value, v], dim=1)
652
+ prev_key_padding_mask: Optional[Tensor] = None
653
+ if "prev_key_padding_mask" in saved_state:
654
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
655
+ assert k is not None and v is not None
656
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
657
+ key_padding_mask=key_padding_mask,
658
+ prev_key_padding_mask=prev_key_padding_mask,
659
+ batch_size=bsz,
660
+ src_len=k.size(1),
661
+ static_kv=static_kv,
662
+ )
663
+
664
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
665
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
666
+ saved_state["prev_key_padding_mask"] = key_padding_mask
667
+ # In this branch incremental_state is never None
668
+ assert incremental_state is not None
669
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
670
+ assert k is not None
671
+ assert k.size(1) == src_len
672
+
673
+ # This is part of a workaround to get around fork/join parallelism
674
+ # not supporting Optional types.
675
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
676
+ key_padding_mask = None
677
+
678
+ if key_padding_mask is not None:
679
+ assert key_padding_mask.size(0) == bsz
680
+ assert key_padding_mask.size(1) == src_len
681
+
682
+ if self.add_zero_attn:
683
+ assert v is not None
684
+ src_len += 1
685
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
686
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
687
+ if attn_mask is not None:
688
+ attn_mask = torch.cat(
689
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
690
+ )
691
+ if key_padding_mask is not None:
692
+ key_padding_mask = torch.cat(
693
+ [
694
+ key_padding_mask,
695
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
696
+ key_padding_mask
697
+ ),
698
+ ],
699
+ dim=1,
700
+ )
701
+
702
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
703
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
704
+
705
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
706
+
707
+ if attn_mask is not None:
708
+ attn_mask = attn_mask.unsqueeze(0)
709
+ attn_weights += attn_mask
710
+
711
+ if key_padding_mask is not None:
712
+ # don't attend to padding symbols
713
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
714
+ if not is_tpu:
715
+ attn_weights = attn_weights.masked_fill(
716
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
717
+ float("-inf"),
718
+ )
719
+ else:
720
+ attn_weights = attn_weights.transpose(0, 2)
721
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
724
+
725
+ if before_softmax:
726
+ return attn_weights, v, position_bias
727
+
728
+ if position_bias is not None:
729
+ if self.gru_rel_pos == 1:
730
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
731
+ _B, _H, _L, __ = query_layer.size()
732
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
733
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
734
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
735
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
736
+
737
+ position_bias = position_bias.view(attn_weights.size())
738
+
739
+ attn_weights = attn_weights + position_bias
740
+
741
+ attn_weights_float = F.softmax(
742
+ attn_weights, dim=-1
743
+ )
744
+ attn_weights = attn_weights_float.type_as(attn_weights)
745
+ attn_probs = self.dropout_module(attn_weights)
746
+
747
+ assert v is not None
748
+ attn = torch.bmm(attn_probs, v)
749
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
750
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
751
+ attn = self.out_proj(attn)
752
+ attn_weights: Optional[Tensor] = None
753
+ if need_weights:
754
+ attn_weights = attn_weights_float.view(
755
+ bsz, self.num_heads, tgt_len, src_len
756
+ ).transpose(1, 0)
757
+ if not need_head_weights:
758
+ # average attention weights over heads
759
+ attn_weights = attn_weights.mean(dim=0)
760
+
761
+ return attn, attn_weights, position_bias
762
+
763
+ @staticmethod
764
+ def _append_prev_key_padding_mask(
765
+ key_padding_mask: Optional[Tensor],
766
+ prev_key_padding_mask: Optional[Tensor],
767
+ batch_size: int,
768
+ src_len: int,
769
+ static_kv: bool,
770
+ ) -> Optional[Tensor]:
771
+ # saved key padding masks have shape (bsz, seq_len)
772
+ if prev_key_padding_mask is not None and static_kv:
773
+ new_key_padding_mask = prev_key_padding_mask
774
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
775
+ new_key_padding_mask = torch.cat(
776
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
777
+ )
778
+ # During incremental decoding, as the padding token enters and
779
+ # leaves the frame, there will be a time when prev or current
780
+ # is None
781
+ elif prev_key_padding_mask is not None:
782
+ if src_len > prev_key_padding_mask.size(1):
783
+ filler = torch.zeros(
784
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
785
+ device=prev_key_padding_mask.device,
786
+ )
787
+ new_key_padding_mask = torch.cat(
788
+ [prev_key_padding_mask.float(), filler.float()], dim=1
789
+ )
790
+ else:
791
+ new_key_padding_mask = prev_key_padding_mask.float()
792
+ elif key_padding_mask is not None:
793
+ if src_len > key_padding_mask.size(1):
794
+ filler = torch.zeros(
795
+ (batch_size, src_len - key_padding_mask.size(1)),
796
+ device=key_padding_mask.device,
797
+ )
798
+ new_key_padding_mask = torch.cat(
799
+ [filler.float(), key_padding_mask.float()], dim=1
800
+ )
801
+ else:
802
+ new_key_padding_mask = key_padding_mask.float()
803
+ else:
804
+ new_key_padding_mask = prev_key_padding_mask
805
+ return new_key_padding_mask
806
+
807
+ def _get_input_buffer(
808
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
809
+ ) -> Dict[str, Optional[Tensor]]:
810
+ result = self.get_incremental_state(incremental_state, "attn_state")
811
+ if result is not None:
812
+ return result
813
+ else:
814
+ empty_result: Dict[str, Optional[Tensor]] = {}
815
+ return empty_result
816
+
817
+ def _set_input_buffer(
818
+ self,
819
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
820
+ buffer: Dict[str, Optional[Tensor]],
821
+ ):
822
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
823
+
824
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
825
+ return attn_weights
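
The `gru_rel_pos` branch above is the gated relative position bias used by WavLM. A minimal, self-contained sketch of that gating computation on dummy tensors follows; the sizes and the stand-ins for `self.grep_linear` / `self.grep_a` are illustrative assumptions, not values taken from a checkpoint:

import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 10, 16
q = torch.randn(bsz, num_heads, tgt_len, head_dim)           # query, already split into heads
position_bias = torch.randn(bsz * num_heads, tgt_len, tgt_len)

grep_linear = torch.nn.Linear(head_dim, 8)                   # stand-in for self.grep_linear
grep_a = torch.ones(1, num_heads, 1, 1)                      # stand-in for self.grep_a

gates = torch.sigmoid(grep_linear(q).view(bsz, num_heads, tgt_len, 2, 4).sum(-1))
gate_a, gate_b = gates.chunk(2, dim=-1)                      # each (bsz, num_heads, tgt_len, 1)
gate = gate_a * (gate_b * grep_a - 1.0) + 2.0                # per-query scalar gate
gated_bias = gate.view(bsz * num_heads, -1, 1) * position_bias
print(gated_bias.shape)                                      # torch.Size([8, 10, 10])
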
components/semantic_extractor/ssl_model.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import joblib
4
+ from components.semantic_extractor.WavLM import WavLM, WavLMConfig
5
+
6
+ class ApplyKmeans(nn.Module):
7
+ def __init__(self, km_path, device='cuda'):
8
+ super(ApplyKmeans, self).__init__()
9
+ print(f'Init k-means model from {km_path}')
10
+ self.km_model = joblib.load(km_path)
11
+ self.C_np = self.km_model.cluster_centers_.transpose()
12
+ self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
13
+ self.C = torch.from_numpy(self.C_np).to(device)
14
+ self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)
15
+ self.emb = nn.Embedding(num_embeddings=300, embedding_dim=1024)
16
+ self.emb.weight.data = self.C.transpose(0, 1)  # initialise the embedding table with the k-means centroids
17
+ self.emb.weight.requires_grad = False  # keep the centroid embedding frozen
18
+
19
+ def forward(self, x, b, t):
20
+ if not hasattr(self, 'C'):
21
+ self.C = torch.from_numpy(self.C_np).to(x.device)
22
+ if not hasattr(self, 'Cnorm'):
23
+ self.Cnorm = torch.from_numpy(self.Cnorm_np).to(x.device)
24
+ dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm  # squared Euclidean distance to every centroid
25
+ tokens = dist.argmin(dim=-1).reshape(b, t)
26
+ return tokens
27
+
28
+ def get_ssl_model(ckpt_path, km_path, device='cuda', type='xlsr'):
29
+ if type == 'xlsr':
30
+ print(f'Init xlsr model from {ckpt_path}')
31
+ import fairseq
32
+ import argparse
33
+ task_arg = argparse.Namespace(task='audio_pretraining')
34
+ task = fairseq.tasks.setup_task(task_arg)
35
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path], task=task)
36
+ model = model[0]
37
+ model.eval()
38
+ elif type == 'wavlm':
39
+ print(f'Init wavlm model from {ckpt_path}')
40
+ cpt = torch.load(ckpt_path, map_location="cpu")
41
+ cfg = WavLMConfig(cpt["cfg"])
42
+ model = WavLM(cfg)
43
+ model.load_state_dict(cpt["model"])
44
+ model = model.eval()
45
+ model = model.requires_grad_(False)
46
+ else:
47
+ raise NotImplementedError
48
+ km_model = ApplyKmeans(km_path, device)
49
+ return model, km_model
50
+
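
A minimal usage sketch for the two helpers above. The checkpoint and k-means paths, the 16 kHz mono input, and the call to WavLM's `extract_features()` are assumptions for illustration (and the k-means model is assumed to have been fit on features of the same dimensionality):

import torch
from components.semantic_extractor.ssl_model import get_ssl_model

model, km = get_ssl_model('wavlm_large.pt', 'km_300.bin', device='cpu', type='wavlm')  # hypothetical paths
wav = torch.randn(1, 16000)                        # 1 s of 16 kHz audio
with torch.no_grad():
    feats = model.extract_features(wav)[0]         # (B, T, D) frame-level features
    b, t, _ = feats.shape
    # flatten to (B*T, D), match the centroid dtype/device, then look up cluster ids
    tokens = km(feats.reshape(b * t, -1).to(km.C), b, t)   # (B, T) semantic tokens
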
components/simcodec/__init__.py ADDED
File without changes
components/simcodec/model.py ADDED
@@ -0,0 +1,33 @@
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from components.simcodec.modules import Encoder, Quantizer, Generator
5
+
6
+ class AttrDict(dict):
7
+ def __init__(self, *args, **kwargs):
8
+ super(AttrDict, self).__init__(*args, **kwargs)
9
+ self.__dict__ = self
10
+
11
+ class SimCodec(nn.Module):
12
+ def __init__(self, config_path):
13
+ super(SimCodec, self).__init__()
14
+ self.config_path = config_path
15
+ with open(self.config_path) as f:
16
+ data = f.read()
17
+ json_config = json.loads(data)
18
+ self.h = AttrDict(json_config)
19
+ self.encoder = Encoder(self.h)
20
+ self.quantizer = Quantizer(self.h)
21
+ self.generator = Generator(self.h)
22
+
23
+ def forward(self, x):
24
+ batch_size = x.size(0)
25
+ if len(x.shape) == 3 and x.shape[-1] == 1:
26
+ x = x.squeeze(-1)
27
+ c = self.encoder(x)
28
+ _, _, c = self.quantizer(c)
29
+ c = [code.reshape(batch_size, -1) for code in c]
30
+ return torch.stack(c, -1)
31
+
32
+ def decode(self, x):
33
+ return self.generator(self.quantizer.embed(x))
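
A minimal round-trip sketch for `SimCodec`. The config path, the single-channel `(B, 1, T)` waveform layout, and the need to load trained weights separately are assumptions based on the code above, not a documented interface:

import torch
from components.simcodec.model import SimCodec

codec = SimCodec('components/simcodec/config.json')    # hypothetical config path
# codec.load_state_dict(...)                           # trained weights must be loaded separately
codec.eval()

wav = torch.randn(1, 1, 16000)                         # (B, channel, T), assuming h.channel == 1
with torch.no_grad():
    codes = codec(wav)                                 # (B, T', n_q * n_code_groups) integer codes
    recon = codec.decode(codes)                        # (B, channel, T'') reconstructed waveform
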
components/simcodec/modules.py ADDED
@@ -0,0 +1,295 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils import weight_norm, remove_weight_norm
5
+ from torch.nn import Conv1d, ConvTranspose1d
6
+
7
+ LRELU_SLOPE = 0.1
8
+ alpha = 1.0
9
+
10
+ def get_padding(kernel_size, dilation=1):
11
+ return int((kernel_size*dilation - dilation)/2)
12
+
13
+ def init_weights(m, mean=0.0, std=0.01):
14
+ classname = m.__class__.__name__
15
+ if classname.find("Conv") != -1:
16
+ m.weight.data.normal_(mean, std)
17
+
18
+ class ResBlock1(torch.nn.Module):
19
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
20
+ super(ResBlock1, self).__init__()
21
+ self.h = h
22
+ self.convs1 = nn.ModuleList([
23
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
24
+ padding=get_padding(kernel_size, dilation[0]))),
25
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
26
+ padding=get_padding(kernel_size, dilation[1]))),
27
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
28
+ padding=get_padding(kernel_size, dilation[2])))
29
+ ])
30
+ self.convs1.apply(init_weights)
31
+
32
+ self.convs2 = nn.ModuleList([
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
34
+ padding=get_padding(kernel_size, 1))),
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
36
+ padding=get_padding(kernel_size, 1))),
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
38
+ padding=get_padding(kernel_size, 1)))
39
+ ])
40
+ self.convs2.apply(init_weights)
41
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
42
+ self.activations = nn.ModuleList([nn.LeakyReLU(LRELU_SLOPE) for _ in range(self.num_layers)])
43
+
44
+
45
+ def forward(self, x):
46
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
47
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
48
+ xt = a1(x)
49
+ xt = c1(xt)
50
+ xt = a2(xt)
51
+ xt = c2(xt)
52
+ x = xt + x
53
+ return x
54
+
55
+ def remove_weight_norm(self):
56
+ for l in self.convs1:
57
+ remove_weight_norm(l)
58
+ for l in self.convs2:
59
+ remove_weight_norm(l)
60
+
61
+
62
+ class Encoder(torch.nn.Module):
63
+ def __init__(self, h):
64
+ super(Encoder, self).__init__()
65
+ self.n_filters = h.en_filters
66
+ self.vq_dim = h.vq_dim
67
+ self.num_kernels = len(h.resblock_kernel_sizes)
68
+ self.num_upsamples = len(h.upsample_rates)
69
+ self.upsample_initial_channel = self.n_filters * ( 2**self.num_upsamples )
70
+ self.conv_pre = weight_norm(Conv1d(h.channel, self.n_filters, 7, 1, padding=3))
71
+ self.normalize = nn.ModuleList()
72
+ resblock = ResBlock1
73
+
74
+ self.ups = nn.ModuleList()
75
+ for i, (u, k) in enumerate(list(reversed(list(zip(h.upsample_rates, h.upsample_kernel_sizes))))):
76
+ self.ups.append(weight_norm(
77
+ Conv1d(self.n_filters*(2**i), self.n_filters*(2**(i+1)),
78
+ k, u,
79
+ padding=((k-u)//2)
80
+ )))
81
+ self.resblocks = nn.ModuleList()
82
+ ch = 1
83
+ for i in range(len(self.ups)):
84
+ ch = self.n_filters*(2**(i+1))
85
+ for j, (k, d) in enumerate(
86
+ zip(
87
+ list(reversed(h.resblock_kernel_sizes)),
88
+ list(reversed(h.resblock_dilation_sizes))
89
+ )
90
+ ):
91
+ self.resblocks.append(resblock(h, ch, k, d))
92
+ self.normalize.append(torch.nn.LayerNorm([ch],eps=1e-6,elementwise_affine=True))
93
+
94
+ self.activation_post = nn.LeakyReLU(LRELU_SLOPE)
95
+ self.conv_post = Conv1d(ch, self.vq_dim, 3, 1, padding=1)
96
+ self.ups.apply(init_weights)
97
+ self.conv_post.apply(init_weights)
98
+
99
+ def forward(self, x):
100
+ x = self.conv_pre(x)
101
+ for i in range(self.num_upsamples):
102
+ x = self.ups[i](x)
103
+ xs = None
104
+ for j in range(self.num_kernels):
105
+ if xs is None:
106
+ xs = self.resblocks[i*self.num_kernels+j](x)
107
+ xs = self.normalize[i*self.num_kernels+j](xs.transpose(1,2)).transpose(1,2)
108
+ else:
109
+ xs += self.resblocks[i*self.num_kernels+j](x)
110
+ xs = self.normalize[i*self.num_kernels+j](xs.transpose(1,2)).transpose(1,2)
111
+ x = xs / self.num_kernels
112
+ x = self.activation_post(x)
113
+ x = self.conv_post(x)
114
+ return x
115
+
116
+ def remove_weight_norm(self):
117
+ print('Removing weight norm...')
118
+ for l in self.ups:
119
+ remove_weight_norm(l)
120
+ for l in self.resblocks:
121
+ l.remove_weight_norm()
122
+ remove_weight_norm(self.conv_pre)
123
+
124
+ class Quantizer_module(torch.nn.Module):
125
+ def __init__(self, n_e, e_dim):
126
+ super(Quantizer_module, self).__init__()
127
+ self.embedding = nn.Embedding(n_e, e_dim)
128
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
129
+ self.target = torch.arange(0,n_e)
130
+
131
+ def forward(self, x, idx=0):
132
+ loss = torch.Tensor([0.0])
133
+ d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \
134
+ - 2 * torch.matmul(x, self.embedding.weight.T)
135
+ min_indicies = torch.argmin(d, 1)
136
+ z_q = self.embedding(min_indicies)
137
+ embed_vec = self.embedding.weight
138
+ embed_dis = torch.mm(embed_vec, embed_vec.T) * 3
139
+ self.target = torch.arange(0,embed_vec.shape[0]).to(x.device)
140
+ loss = F.cross_entropy(embed_dis, self.target) * (idx == 0)  # codeword-separation term, applied only at the first residual layer
141
+ return z_q, min_indicies,loss
142
+
143
+ class Quantizer(torch.nn.Module):
144
+ def __init__(self, h):
145
+ super(Quantizer, self).__init__()
146
+ assert h.vq_dim % h.n_code_groups == 0
147
+ self.lm_offset = 0
148
+ self.lm_states = None
149
+ self.vq_dim = h.vq_dim
150
+ self.residul_layer = h.n_q
151
+ self.n_code_groups = h.n_code_groups
152
+ self.quantizer_modules = nn.ModuleList()
153
+ for i in range(self.residul_layer):
154
+ self.quantizer_modules.append(nn.ModuleList([
155
+ Quantizer_module(h.n_codes, self.vq_dim // h.n_code_groups) for _ in range(h.n_code_groups)
156
+ ]))
157
+ self.h = h
158
+ self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1
159
+ self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25
160
+
161
+
162
+ def for_one_step(self, xin, idx):
163
+ xin = xin.transpose(1, 2)
164
+ x = xin.reshape(-1, self.vq_dim)
165
+ x = torch.split(x, self.vq_dim // self.h.n_code_groups, dim=-1)
166
+ min_indicies = []
167
+ z_q = []
168
+ all_losses = []
169
+ for _x, m in zip(x, self.quantizer_modules[idx]):
170
+ _z_q, _min_indicies, _loss = m(_x, idx)
171
+ all_losses.append(_loss)
172
+ z_q.append(_z_q)
173
+ min_indicies.append(_min_indicies)
174
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
175
+ z_q = z_q.transpose(1, 2)
176
+ all_losses = torch.stack(all_losses)
177
+ loss = torch.mean(all_losses)
178
+ return z_q, min_indicies, loss
179
+
180
+
181
+ def forward(self, xin,bw=-1,mask_id=None):
182
+ quantized_out = 0.0
183
+ residual = xin
184
+ all_losses = []
185
+ all_indices = []
186
+ if bw<=0:
187
+ bw = self.residul_layer
188
+ for i in range(bw):
189
+ quantized, indices, e_loss = self.for_one_step(residual, i) #
190
+ if mask_id is not None:
191
+ mask = (
192
+ torch.full([xin.shape[0],xin.shape[2],1], fill_value=i, device=xin.device) < mask_id.unsqueeze(2) + 1
193
+ )
194
+ mask = mask.repeat(1,1,xin.shape[1]).transpose(1,2)
195
+ if mask_id is not None:
196
+ loss = 0.1 * e_loss + self.codebook_loss_lambda * torch.mean((quantized - residual.detach()) ** 2 * mask) \
197
+ + self.commitment_loss_lambda * torch.mean((quantized.detach() - residual) ** 2 * mask )
198
+ else:
199
+ loss = 0.1 * e_loss \
200
+ + self.codebook_loss_lambda * torch.mean((quantized - residual.detach()) ** 2 ) \
201
+ + self.commitment_loss_lambda * torch.mean((quantized.detach() - residual) ** 2 )
202
+
203
+ quantized = residual + (quantized - residual).detach()  # straight-through estimator: forward uses the codeword, backward passes gradients to the residual
204
+ residual = residual - quantized
205
+ if mask_id is not None:
206
+ quantized_out = quantized_out + quantized * mask
207
+ else:
208
+ quantized_out = quantized_out + quantized
209
+ all_indices.extend(indices) #
210
+ all_losses.append(loss)
211
+ all_losses = torch.stack(all_losses)
212
+ loss = torch.mean(all_losses)
213
+ return quantized_out, loss, all_indices
214
+
215
+ def embed(self, x , bw=-1):
216
+ quantized_out = torch.tensor(0.0, device=x.device)
217
+ x = torch.split(x, 1, 2)
218
+ if bw <= 0 or bw > self.residul_layer:
219
+ bw = self.residul_layer
220
+ for i in range(bw):
221
+ ret = []
222
+ for j in range(self.n_code_groups):
223
+ q = x[j+self.n_code_groups*i]
224
+ embed = self.quantizer_modules[i][j]
225
+ q = embed.embedding(q.squeeze(-1))
226
+ ret.append(q)
227
+ ret = torch.cat(ret, -1)
228
+ quantized_out = quantized_out + ret
229
+ return quantized_out.transpose(1, 2)
230
+
231
+
232
+ class Generator(torch.nn.Module):
233
+ def __init__(self, h):
234
+ super(Generator, self).__init__()
235
+ self.h = h
236
+ self.n_filters = h.de_filters
237
+ self.vq_dim = h.vq_dim
238
+ self.num_kernels = len(h.resblock_kernel_sizes)
239
+ self.num_upsamples = len(h.upsample_rates)
240
+ self.upsample_initial_channel = self.n_filters * ( 2**self.num_upsamples )
241
+ self.conv_pre = weight_norm(Conv1d(self.vq_dim, self.upsample_initial_channel, 7, 1, padding=3))
242
+ resblock = ResBlock1
243
+
244
+
245
+ self.norm = nn.Identity()
246
+
247
+ self.ups = nn.ModuleList()
248
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
249
+ self.ups.append(weight_norm(
250
+ ConvTranspose1d(
251
+ self.upsample_initial_channel//(2**i), self.upsample_initial_channel//(2**(i+1)),
252
+ k, u,
253
+ padding=(k - u )//2,
254
+ )
255
+ ))
256
+ ch = 1
257
+ self.resblocks = nn.ModuleList()
258
+ for i in range(len(self.ups)):
259
+ ch = self.upsample_initial_channel//(2**(i+1))
260
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
261
+ self.resblocks.append(resblock(h, ch, k, d))
262
+
263
+
264
+ self.activation_post = nn.LeakyReLU(LRELU_SLOPE)
265
+ self.conv_post = weight_norm(Conv1d(ch, h.channel, 7, 1, padding=3))
266
+ self.ups.apply(init_weights)
267
+ self.conv_post.apply(init_weights)
268
+
269
+ def forward(self, x):
270
+ x = self.norm(x)
271
+ x = self.conv_pre(x)
272
+
273
+ for i in range(self.num_upsamples):
274
+ x = self.ups[i](x)
275
+ xs = None
276
+ for j in range(self.num_kernels):
277
+ if xs is None:
278
+ xs = self.resblocks[i*self.num_kernels+j](x)
279
+ else:
280
+ xs += self.resblocks[i*self.num_kernels+j](x)
281
+ x = xs / self.num_kernels
282
+ x = self.activation_post(x)
283
+ x = self.conv_post(x)
284
+ x = torch.tanh(x)
285
+
286
+ return x
287
+
288
+ def remove_weight_norm(self):
289
+ print('Removing weight norm...')
290
+ for l in self.ups:
291
+ remove_weight_norm(l)
292
+ for l in self.resblocks:
293
+ l.remove_weight_norm()
294
+ remove_weight_norm(self.conv_pre)
295
+ remove_weight_norm(self.conv_post)
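
The `Quantizer` above implements residual vector quantization with a straight-through estimator. A toy, self-contained sketch of that loop (not the repo's training code; codebook sizes and dimensions are illustrative):

import torch

def nearest(codebook, x):                             # x: (N, D), codebook: (K, D)
    d = x.pow(2).sum(1, keepdim=True) - 2 * x @ codebook.T + codebook.pow(2).sum(1)
    idx = d.argmin(1)
    return codebook[idx], idx

torch.manual_seed(0)
codebooks = [torch.randn(16, 8) for _ in range(3)]    # 3 residual stages, 16 codes each, dim 8
x = torch.randn(5, 8, requires_grad=True)             # stands in for the encoder output

quantized_out, residual = 0.0, x
for cb in codebooks:
    q_hard, idx = nearest(cb, residual)
    q = residual + (q_hard - residual).detach()       # straight-through: forward uses q_hard, backward sees residual
    quantized_out = quantized_out + q
    residual = residual - q                           # the next stage quantizes the leftover error

quantized_out.sum().backward()                        # gradients reach x despite the argmin
print(x.grad.shape)                                   # torch.Size([5, 8])
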