jimbozhang committed
Commit fb96339 · verified · 1 Parent(s): 52527c2

Upload 3 files

Files changed (3)
  1. configuration_ced.py +137 -0
  2. feature_extraction_ced.py +166 -0
  3. modeling_ced.py +549 -0
configuration_ced.py ADDED
@@ -0,0 +1,137 @@
+ # coding=utf-8
+ # Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ CED model configuration"""
+
+
+ from transformers import PretrainedConfig
+ from transformers.utils import logging
+ from transformers.utils.hub import cached_file
+
+ logger = logging.get_logger(__name__)
+
+ CED_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "mispeech/ced-tiny": "https://huggingface.co/mispeech/ced-tiny/resolve/main/config.json",
+ }
+
+
+ class CedConfig(PretrainedConfig):
+     r"""
+     Configuration class for the CED model.
+
+     Args:
+         name (str, *optional*):
+             Name of a pre-defined configuration. Can be "ced-tiny", "ced-mini", "ced-small" or "ced-base".
+         attn_drop_rate (float, *optional*, defaults to 0.0):
+             Dropout probability for the attention weights.
+         depth (int, *optional*, defaults to 12): Number of transformer layers.
+         drop_path_rate (float, *optional*, defaults to 0.0): Stochastic depth (drop path) rate, as in timm.
+         drop_rate (float, *optional*, defaults to 0.0):
+             Dropout probability for the input embeddings.
+         embed_dim (int, *optional*, defaults to 768):
+             Dimensionality of the audio patch embeddings.
+         eval_avg (str, *optional*, defaults to `"mean"`):
+             Type of pooling to use for evaluation. Can be "mean", "token", "dm" or "logit".
+         mlp_ratio (float, *optional*, defaults to 4.0):
+             Ratio of the hidden size of the feedforward layer to the embedding size.
+         num_heads (int, *optional*, defaults to 12): Number of attention heads.
+         outputdim (int, *optional*, defaults to 527): Dimensionality of the output (number of classes).
+         patch_size (int, *optional*, defaults to 16): Size of the patches.
+         patch_stride (int, *optional*, defaults to 16): Stride of the patches.
+         pooling (str, *optional*, defaults to `"mean"`):
+             Type of pooling to use for the output. Can be "mean", "token", "dm" or "logit".
+         qkv_bias (bool, *optional*, defaults to `True`):
+             Whether to include bias terms in the query, key and value projections.
+         target_length (int, *optional*, defaults to 1012): Number of frames in an audio chunk.
+     """
+
+     model_type = "ced"
+
+     def __init__(
+         self,
+         name=None,
+         attn_drop_rate=0.0,
+         depth=12,
+         drop_path_rate=0.0,
+         drop_rate=0.0,
+         embed_dim=768,
+         eval_avg="mean",
+         mlp_ratio=4.0,
+         num_heads=12,
+         outputdim=527,
+         patch_size=16,
+         patch_stride=16,
+         pooling="mean",
+         qkv_bias=True,
+         target_length=1012,
+         **kwargs,
+     ):
+         r"""
+         TODO: Add docstring
+         """
+
+         super().__init__(**kwargs)
+
+         if name == "ced-tiny":
+             embed_dim = 192
+             num_heads = 3
+         elif name == "ced-mini":
+             embed_dim = 256
+             num_heads = 4
+         elif name == "ced-small":
+             embed_dim = 384
+             num_heads = 6
+         elif name == "ced-base":
+             embed_dim = 768
+             num_heads = 12
+         else:
+             logger.info("No model name specified for CedConfig, using default settings.")
+
+         assert pooling in ("mean", "token", "dm", "logit")
+         self.name = name
+         self.attn_drop_rate = attn_drop_rate
+         self.center = kwargs.get("center", True)
+         self.depth = depth
+         self.drop_path_rate = drop_path_rate
+         self.drop_rate = drop_rate
+         self.embed_dim = embed_dim
+         self.eval_avg = eval_avg
+         self.f_max = kwargs.get("f_max", 8000)
+         self.f_min = kwargs.get("f_min", 0)
+         self.hop_size = kwargs.get("hop_size", 160)
+         self.mlp_ratio = mlp_ratio
+         self.n_fft = kwargs.get("n_fft", 512)
+         self.n_mels = kwargs.get("n_mels", 64)
+         self.num_heads = num_heads
+         self.outputdim = outputdim
+         self.pad_last = kwargs.get("pad_last", True)
+         self.patch_size = patch_size
+         self.patch_stride = patch_stride
+         self.pooling = pooling
+         self.qkv_bias = qkv_bias
+         self.target_length = target_length
+         self.win_size = kwargs.get("win_size", 512)
+         self.loss = "BCELoss"  # must name a class in torch.nn.modules.loss
+
+         if self.outputdim == 527:
+             # Load the 527-class (AudioSet) label map shipped with topel/ConvNeXt-Tiny-AT.
+             with open(cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r") as f:
+                 self.id2label = {
+                     int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2].replace('"', "").strip("\n")
+                     for line in f.readlines()[1:]
+                 }
+             self.label2id = {v: k for k, v in self.id2label.items()}
+         else:
+             self.id2label = None
+             self.label2id = None
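
A quick sanity check of the preset logic above may help readers of this commit. The snippet below is an illustrative sketch only, not part of the uploaded files, and assumes `configuration_ced.py` is importable from the working directory.

# Illustrative sketch: how CedConfig resolves preset names.
from configuration_ced import CedConfig

config = CedConfig(name="ced-tiny", outputdim=10)  # outputdim != 527 skips fetching the AudioSet label map
print(config.embed_dim, config.num_heads)          # 192 3  -- the preset overrides the defaults
print(config.n_mels, config.hop_size)              # 64 160 -- Mel-frontend settings fall back to kwargs defaults

custom = CedConfig(depth=6, n_mels=128, outputdim=10)  # no preset name: passed values are kept as-is
print(custom.embed_dim, custom.depth, custom.n_mels)   # 768 6 128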
feature_extraction_ced.py ADDED
@@ -0,0 +1,166 @@
+ # coding=utf-8
+ # Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Feature extractor class for CED.
+ """
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ import torchaudio.transforms as audio_transforms
+
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class CedFeatureExtractor(SequenceFeatureExtractor):
+     r"""
+     CedFeatureExtractor extracts Mel spectrogram features from audio signals.
+
+     Args:
+         f_min (int, *optional*, defaults to 0): Minimum frequency of the Mel filterbank.
+         sampling_rate (int, *optional*, defaults to 16000):
+             Sampling rate of the input audio signal.
+         win_size (int, *optional*, defaults to 512): Window size for the STFT.
+         center (bool, *optional*, defaults to `True`):
+             Whether to pad the signal on both sides so that frames are centered.
+         n_fft (int, *optional*, defaults to 512): Number of FFT points for the STFT.
+         f_max (int, *optional*): Maximum frequency of the Mel filterbank.
+         hop_size (int, *optional*, defaults to 160): Hop size for the STFT.
+         feature_size (int, *optional*, defaults to 64): Number of Mel bands to generate.
+         padding_value (float, *optional*, defaults to 0.0): Value used for padding.
+
+     Returns:
+         BatchFeature: A BatchFeature object containing the extracted features.
+     """
+
+     def __init__(
+         self,
+         f_min: int = 0,
+         sampling_rate: int = 16000,
+         win_size: int = 512,
+         center: bool = True,
+         n_fft: int = 512,
+         f_max: Optional[int] = None,
+         hop_size: int = 160,
+         feature_size: int = 64,
+         padding_value: float = 0.0,
+         **kwargs,
+     ):
+         super().__init__(
+             feature_size=feature_size,
+             sampling_rate=sampling_rate,
+             padding_value=padding_value,
+             **kwargs,
+         )
+         self.f_min = f_min
+         self.win_size = win_size
+         self.center = center
+         self.n_fft = n_fft
+         self.f_max = f_max
+         self.hop_size = hop_size
+
+         self.model_input_names = ["input_values"]
+
+     def __call__(
+         self,
+         x: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
+         sampling_rate: Optional[int] = None,
+         max_length: Optional[int] = None,
+         truncation: bool = False,
+         return_tensors="pt",
+     ) -> BatchFeature:
+         r"""
+         Extracts Mel spectrogram features from an audio signal tensor.
+
+         Args:
+             x: Input audio signal tensor.
+             sampling_rate (int, *optional*, defaults to `None`):
+                 Sampling rate of the input audio signal.
+             max_length (int, *optional*, defaults to `None`):
+                 Maximum length of the input audio signal.
+             truncation (bool, *optional*, defaults to `False`):
+                 Whether to truncate the input signal to `max_length`.
+             return_tensors (str, *optional*, defaults to `"pt"`):
+                 If set to "pt", the return type will be a PyTorch tensor.
+
+         Returns:
+             BatchFeature: A dictionary containing the extracted features.
+         """
+         if sampling_rate is None:
+             sampling_rate = self.sampling_rate
+
+         if return_tensors != "pt":
+             raise NotImplementedError("Only return_tensors='pt' is currently supported.")
+
+         mel_spectrogram = audio_transforms.MelSpectrogram(
+             f_min=self.f_min,
+             sample_rate=sampling_rate,
+             win_length=self.win_size,
+             center=self.center,
+             n_fft=self.n_fft,
+             f_max=self.f_max,
+             hop_length=self.hop_size,
+             n_mels=self.feature_size,
+         )
+         amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)
+
+         if isinstance(x, np.ndarray):
+             if x.ndim == 1:
+                 x = x[np.newaxis, :]
+             if x.ndim != 2:
+                 raise ValueError("np.ndarray input must be 1D or 2D.")
+             x = torch.from_numpy(x)
+         elif isinstance(x, torch.Tensor):
+             if x.dim() == 1:
+                 x = x.unsqueeze(0)
+             if x.dim() != 2:
+                 raise ValueError("torch.Tensor input must be 1D or 2D.")
+         elif isinstance(x, (list, tuple)):
+             longest_length = max(x_.shape[0] for x_ in x)
+             # Pad every signal to the longest one unless an explicit, truncating max_length is given.
+             if max_length is None or (not truncation and max_length < longest_length):
+                 max_length = longest_length
+
+             if all(isinstance(x_, np.ndarray) for x_ in x):
+                 if not all(x_.ndim == 1 for x_ in x):
+                     raise ValueError("All np.ndarray in a list must be 1D.")
+
+                 x_trim = [x_[:max_length] for x_ in x]
+                 x_pad = [np.pad(x_, (0, max_length - x_.shape[0]), mode="constant", constant_values=0) for x_ in x_trim]
+                 x = torch.stack([torch.from_numpy(x_) for x_ in x_pad])
+             elif all(isinstance(x_, torch.Tensor) for x_ in x):
+                 if not all(x_.dim() == 1 for x_ in x):
+                     raise ValueError("All torch.Tensor in a list must be 1D.")
+                 # A negative pad amount truncates, so this both pads and truncates to max_length.
+                 x_pad = [torch.nn.functional.pad(x_, (0, max_length - x_.shape[0]), value=0) for x_ in x]
+                 x = torch.stack(x_pad)
+             else:
+                 raise ValueError("Input list must contain only numpy arrays or only PyTorch tensors.")
+         else:
+             raise ValueError(
+                 "Input must be a numpy array, a list of numpy arrays, a PyTorch tensor, or a list of PyTorch tensors."
+             )
+
+         x = x.float()
+         x = mel_spectrogram(x)
+         x = amplitude_to_db(x)
+         return BatchFeature({"input_values": x})
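
The extractor above accepts a single waveform or a ragged list of waveforms. A minimal usage sketch (illustrative only, not part of the uploaded files, assuming `feature_extraction_ced.py` is importable from the working directory) showing how two clips of different lengths are padded and converted into a (batch, n_mels, frames) log-Mel tensor:

# Illustrative sketch: batching two mono clips of different lengths.
import numpy as np
from feature_extraction_ced import CedFeatureExtractor

extractor = CedFeatureExtractor()  # 16 kHz, 64 Mel bands, 160-sample hop by default
waveforms = [
    np.random.randn(16000).astype(np.float32),  # 1.0 s clip
    np.random.randn(24000).astype(np.float32),  # 1.5 s clip
]

features = extractor(waveforms, sampling_rate=16000, return_tensors="pt")
print(features["input_values"].shape)  # torch.Size([2, 64, 151]): shorter clip zero-padded to the longer one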
modeling_ced.py ADDED
@@ -0,0 +1,549 @@
+ # coding=utf-8
+ # Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ PyTorch CED model."""
+
+ import collections
+ import math
+ from functools import partial
+ from typing import Any, Callable, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+ )
+
+ from .configuration_ced import CedConfig
+
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "CedConfig"
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'Speech synthesizer'"
+ _SEQ_CLASS_EXPECTED_LOSS = 0.69
+
+ # Audio classification docstring
+ _SEQ_CLASS_CHECKPOINT = "mispeech/ced-tiny"
+
+
+ CED_PRETRAINED_MODEL_ARCHIVE_LIST = [
+     "mispeech/ced-tiny",
+     "mispeech/ced-mini",
+     "mispeech/ced-small",
+     "mispeech/ced-base",
+     # See all CED models at https://huggingface.co/models?search=mispeech%2Fced
+ ]
+
+
+ class CedPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = CedConfig
+     base_model_prefix = "ced"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             trunc_normal_(module.weight, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.constant_(module.bias, 0)
+             nn.init.constant_(module.weight, 1.0)
+
+
+ Conv_Kernel = Union[int, Tuple[int, int]]
+
+
+ def to_2tuple(x: Any) -> Tuple[Any, Any]:
+     if isinstance(x, collections.abc.Iterable):
+         return x
+     return (x, x)
+
+
+ class CedAudioPatchEmbed(nn.Module):
+     def __init__(
+         self,
+         input_size: Conv_Kernel = 224,
+         patch_size: Conv_Kernel = 16,
+         patch_stride: Conv_Kernel = 16,
+         in_chans: int = 1,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten: bool = False,
+     ):
+         super().__init__()
+         self.input_size = to_2tuple(input_size)
+         self.patch_size = to_2tuple(patch_size)
+         self.patch_stride = to_2tuple(patch_stride)
+         self.grid_size = (
+             self.input_size[0] // self.patch_stride[0],
+             self.input_size[1] // self.patch_stride[1],
+         )
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride)
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x):
+         x = self.proj(x)
+         if self.flatten:
+             x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
+         x = self.norm(x)
+         return x
+
+
+ class CedAttention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         causal: bool = False,
+     ):
+         super().__init__()
+         assert dim % num_heads == 0, "dim should be divisible by num_heads"
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.causal = causal
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         # if mask is not None:
+         #     # Mask is a tensor of shape [B, T, T]
+         #     # Different from self.causal == True, the mask might be something like:
+         #     # [False, False, True]
+         #     # [False, False, True]
+         #     # [True, True, True]
+         #     # We use -inf to pad here, since if we would pad by any number, the entries at rows only containing
+         #     # [True, True, True] would lead to weights such as: [0.33,0.33,0.33], which is not correct
+         #     mask_value = torch.as_tensor(-float('inf'))
+         #     print(mask.shape, attn.shape)
+         #     attn = attn.masked_fill(mask, mask_value)
+         if self.causal:
+             mask_value = -torch.finfo(attn.dtype).max
+             i, j = attn.shape[-2:]
+             mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
+             attn = attn.masked_fill(mask, mask_value)
+         attn = attn.softmax(dim=-1)
+         # Only for the case that a mask with all True entries on a row is passed.
+         # attn = torch.nan_to_num(attn)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class CedMlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable = nn.GELU,
+         drop: float = 0.0,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ # Drop path is taken from timm
+ # https://github.com/huggingface/pytorch-image-models/blob/7c67d6aca992f039eece0af5f7c29a43d48c00e4/timm/models/layers/drop.py#L155
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+     This is the same as the DropConnect impl I (https://github.com/rwightman) created for EfficientNet, etc networks,
+     however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+     layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+     argument.
+     """
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
+
+ class CedBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads,
+         mlp_ratio=4.0,
+         qkv_bias=False,
+         drop=0.0,
+         attn_drop=0.0,
+         drop_path=0.0,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         attention_type: Callable = CedAttention,
+         attention_kwargs={},
+         **kwargs,
+     ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = attention_type(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             **attention_kwargs,
+         )
+         self.ls1 = nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         self.mlp = CedMlp(
+             in_features=dim,
+             hidden_features=int(dim * mlp_ratio),
+             act_layer=act_layer,
+             drop=drop,
+         )
+         self.ls2 = nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def forward(self, x):
+         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+         return x
+
+
+ # Taken from timm
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.0))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ CED_START_DOCSTRING = r"""
+
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`CedConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ CED_INPUTS_DOCSTRING = r"""
+     Args:
+         input_values (`torch.FloatTensor` of shape `(batch_size, n_mels, sequence_length)`):
+             The sequence of audio features extracted from the audio signal. Can be obtained from a raw audio waveform
+             using `~transformers.CedFeatureExtractor.__call__`.
+ """
+
+
+ @add_start_docstrings(
+     "The bare Ced Model transformer outputting raw hidden-states without any specific head on top.",
+     CED_START_DOCSTRING,
+ )
+ class CedModel(CedPreTrainedModel):
+     def __init__(self, config: CedConfig) -> None:
+         super().__init__(config)
+         self.config = config
+         self.name = config.name
+
+         # Allowed length in number of frames, otherwise the positional embedding will throw an error
+         self.maximal_allowed_length = self.config.target_length
+
+         self.init_bn = torch.nn.BatchNorm2d(config.n_mels, momentum=0.01)
+
+         self.patch_embed = CedAudioPatchEmbed(
+             input_size=(config.n_mels, config.target_length),
+             embed_dim=config.embed_dim,
+             patch_size=config.patch_size,
+             flatten=False,
+             patch_stride=config.patch_stride,
+         )
+
+         self.time_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02)
+         self.freq_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02)
+         norm_layer = partial(nn.LayerNorm, eps=1e-6)
+         act_layer = nn.GELU
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]  # stochastic depth decay rule
+         self.pos_drop = nn.Dropout(p=config.drop_rate)
+         self.blocks = nn.Sequential(
+             *[
+                 CedBlock(
+                     dim=config.embed_dim,
+                     num_heads=config.num_heads,
+                     mlp_ratio=config.mlp_ratio,
+                     qkv_bias=config.qkv_bias,
+                     drop=config.drop_rate,
+                     attn_drop=config.attn_drop_rate,
+                     drop_path=dpr[i],
+                     norm_layer=norm_layer,
+                     act_layer=act_layer,
+                     attention_type=CedAttention,
+                 )
+                 for i in range(config.depth)
+             ]
+         )
+         self.norm = norm_layer(config.embed_dim)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def _freeze_parameters(self):
+         for param in self.parameters():
+             param.requires_grad = False
+         self._requires_grad = False
+
+     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.patch_embed(x)
+         _, _, _, t = x.shape
+         x = x + self.time_pos_embed[:, :, :, :t]
+         x = x + self.freq_pos_embed[:, :, :, :]  # Just to support __getitem__ in posembed
+
+         # x = rearrange(x, 'b c f t -> b (f t) c')
+         x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
+
+         if self.config.pooling == "token":
+             cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+             cls_token = cls_token + self.token_pos_embed
+             x = torch.cat((cls_token, x), dim=1)
+         x = self.pos_drop(x)
+         x = self.blocks(x)
+         x = self.norm(x)
+         return x
+
+     def forward(self, input_values: torch.Tensor):
+         r"""
+         Runs a forward pass of the CED model as an audio encoder.
+         """
+         x = torch.unsqueeze(input_values, 1)
+
+         x = torch.permute(x, (0, 2, 1, 3))
+         x = self.init_bn(x)
+         x = torch.permute(x, (0, 2, 1, 3))
+
+         # Spectrograms longer than `maximal_allowed_length` frames are split along time and batched as chunks.
+         if x.shape[-1] > self.maximal_allowed_length:
+             splits = x.split(self.maximal_allowed_length, -1)
+
+             if splits[-1].shape[-1] < self.maximal_allowed_length:
+                 if self.config.pad_last:
+                     pad = torch.zeros(*x.shape[:-1], self.maximal_allowed_length, device=x.device)
+                     pad[..., : splits[-1].shape[-1]] = splits[-1]
+                     splits = torch.stack((*splits[:-1], pad), dim=0)
+                 else:
+                     splits = torch.stack(splits[:-1], dim=0)
+             else:
+                 splits = torch.stack(splits, dim=0)
+             n_splits = len(splits)
+             x = torch.flatten(splits, 0, 1)  # spl b c f t -> (spl b) c f t
+         else:
+             n_splits = 1
+
+         x = self.forward_features(x)
+         # Re-merge the chunks along the sequence dimension.
+         x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
+
+         return SequenceClassifierOutput(logits=x)
+
+
+ @add_start_docstrings(
+     """
+     Ced model with an audio classification head on top (a linear layer on top of the pooled output).
+     """,
+     CED_START_DOCSTRING,
+ )
+ class CedForAudioClassification(CedPreTrainedModel):
+     def __init__(self, config: CedConfig) -> None:
+         super().__init__(config)
+         self.config = config
+
+         self.encoder = CedModel(config)
+
+         # Classifier head
+         self.outputlayer = nn.Sequential(
+             nn.LayerNorm(config.embed_dim),
+             nn.Linear(config.embed_dim, config.outputdim),
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward_head(self, x: torch.Tensor) -> torch.Tensor:
+         if self.config.pooling == "token":
+             x = x[:, 0]
+             return self.outputlayer(x).sigmoid()
+         elif self.config.pooling == "mean":
+             x = x.mean(1)
+             return self.outputlayer(x).sigmoid()
+         elif self.config.pooling == "logit":
+             x = x.mean(1)
+             return self.outputlayer(x)
+         elif self.config.pooling == "dm":
+             # Unpack using the frequency dimension, which is constant
+             # 'b (f t) d -> b f t d', f=self.encoder.patch_embed.grid_size[0]
+             x = torch.reshape(x, (x.shape[0], self.encoder.patch_embed.grid_size[0], -1, x.shape[-1]))
+
+             # First pool over frequency, then sigmoid the (B, T, D) output
+             x = self.outputlayer(x.mean(1)).sigmoid()
+             return x.mean(1)
+         else:
+             return x.mean(1)
+
+     def freeze_encoder(self):
+         self.encoder._freeze_parameters()
+
+     @add_start_docstrings_to_model_forward(CED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @add_code_sample_docstrings(
+         checkpoint=_SEQ_CLASS_CHECKPOINT,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         model_cls="CedForAudioClassification",
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
+         """
+         Runs a forward pass of the CED model for the audio classification task.
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+         >>> from datasets import load_dataset
+         >>> import torch
+
+         >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+         >>> dataset = dataset.sort("id")
+         >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("mispeech/ced-tiny")
+         >>> model = AutoModelForAudioClassification.from_pretrained("mispeech/ced-tiny")
+
+         >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+
+         >>> with torch.no_grad():
+         ...     logits = model(**inputs).logits
+
+         >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
+         >>> predicted_label = model.config.id2label[predicted_class_ids]
+         >>> predicted_label
+         'Speech synthesizer'
+         ```
+         """
+         last_hidden_states = self.encoder(input_values).logits
+         logits = self.forward_head(last_hidden_states)
+
+         if labels is not None:
+             try:
+                 loss_fct = getattr(nn.modules.loss, self.config.loss)()
+             except AttributeError:
+                 raise NotImplementedError(f"Loss {self.config.loss} not implemented.")
+
+             labels = nn.functional.one_hot(labels, num_classes=self.config.outputdim).float()
+             loss = loss_fct(logits, labels)
+         else:
+             loss = None
+
+         return SequenceClassifierOutput(logits=logits, loss=loss, hidden_states=last_hidden_states)
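
Since these three files ship as custom code on the Hub rather than inside transformers itself, one way to exercise the whole stack is through the Auto classes with trust_remote_code=True. The sketch below is illustrative only and assumes the mispeech/ced-tiny repository maps these classes via auto_map in its config.json and preprocessor_config.json:

# Illustrative sketch: running feature extraction and classification end to end.
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

model_id = "mispeech/ced-tiny"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForAudioClassification.from_pretrained(model_id, trust_remote_code=True)

waveform = torch.randn(16000 * 10)  # 10 s of dummy audio at 16 kHz
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # sigmoid scores over the 527 AudioSet classes ("mean" pooling)

top5 = logits[0].topk(5).indices.tolist()
print([model.config.id2label[i] for i in top5])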