# fama-medium / conformer_model.py
# mgaido91's picture
# add FAMA medium
# 5491e76 verified
# Copyright 2024 FBK
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# the code below contains parts copied from the Conformer implementation in
# https://github.com/hlt-mt/FBK-fairseq/blob/master/examples/speech_to_text/models/conformer.py
import math
from itertools import groupby
from typing import Union, Tuple, Optional
import torch
import transformers
from torch import nn, Tensor
from torch.nn import CrossEntropyLoss, functional as F
from transformers import Speech2TextPreTrainedModel, add_start_docstrings, GenerationMixin, Speech2TextProcessor, \
Speech2TextTokenizer, Speech2TextFeatureExtractor
from transformers.modeling_outputs import Seq2SeqModelOutput, BaseModelOutput, Seq2SeqLMOutput
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, \
SPEECH_TO_TEXT_INPUTS_DOCSTRING, shift_tokens_right
from transformers.utils import replace_return_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_conformer import Speech2TextConformerConfig
# Module-level logger, obtained through the `transformers` logging utilities.
logger = logging.get_logger(__name__)
# Name of the configuration class, used by `replace_return_docstrings` when building the model docstrings.
_CONFIG_FOR_DOC = "Speech2TextConformerConfig"
# Shared introduction prepended (via `add_start_docstrings`) to the docstrings of the Conformer models.
CONFORMER_START_DOCSTRING = r"""
This model is an implementation of an attention-based autoregressive encoder-decoder model, in which the encoder
is a Conformer Encoder and decoder is a Transformer Decoder. The encoder expects 80-feature spectrograms as input
as the [`Speech2TextModel`] and its implementation follows that of the paper:
`"When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP"
(Papi, et al, ACL 2024) <https://aclanthology.org/2024.acl-long.200/>`_.
This ensures consistency of results regardless of the presence of padding.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Speech2TextConformerConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
class Conv1dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 1D convolutions (along the temporal
    dimension), each followed by a non-linear activation via gated linear units
    (https://arxiv.org/abs/1911.08460).

    Every convolution has stride 2. After each layer, the time steps that correspond
    to padding in the input are reset to 0, so that the valid positions do not depend
    on how much padding the batch contains.
    """

    def __init__(self, config: "Speech2TextConformerConfig"):
        super(Conv1dSubsampler, self).__init__()
        self.n_layers = len(config.conv_kernel_sizes)
        in_channels = config.input_feat_per_channel * config.input_channels
        mid_channels = config.conv_channels
        out_channels = config.d_model
        # GLU halves the channel dimension: hence "// 2" on the input channels of the
        # layers after the first one, and "* 2" on the output channels of the last one.
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                in_channels if i == 0 else mid_channels // 2,
                mid_channels if i < self.n_layers - 1 else out_channels * 2,
                k,
                stride=2,
                padding=k // 2,
            )
            for i, k in enumerate(config.conv_kernel_sizes)
        )

    @staticmethod
    def subsampled_sequence_len(seq_lens, kernel_size=5, padding=1, stride=2):
        """Return the output lengths of a 1D convolution applied to sequences of lengths ``seq_lens``:
        ``floor((L - kernel_size + 2 * padding) / stride + 1)``."""
        return ((seq_lens.float() - kernel_size + 2 * padding) / stride + 1).floor().long()

    @staticmethod
    def lengths_to_padding_mask(lens: torch.LongTensor, max_len: Optional[int] = None) -> torch.BoolTensor:
        """Build a B x max_len boolean mask that is True on the positions >= the length of each sample.

        Args:
            lens: tensor of shape (B,) with the length of each sample.
            max_len: width of the mask; defaults to the maximum of ``lens``.
        """
        bsz = lens.size(0)
        if max_len is None:
            max_len = int(torch.max(lens).item())
        positions = torch.arange(max_len, device=lens.device).view(1, max_len)
        return positions.expand(bsz, -1) >= lens.view(bsz, 1)

    def forward(self, src_tokens: torch.FloatTensor, padding_mask: Optional[torch.IntTensor] = None) -> torch.Tensor:
        """
        Args:
            src_tokens: input features, B x T x (C x D).
            padding_mask: B x T mask with 1 on valid frames and 0 on padding (i.e. an attention mask),
                or None if no frame is padded.
        Returns:
            The subsampled features, T' x B x d_model.
        """
        x = src_tokens.transpose(1, 2).contiguous()  # B x T x (C x D) -> B x (C x D) x T
        if padding_mask is None:
            actual_src_lengths = torch.full(
                (src_tokens.size(0),), src_tokens.size(1), dtype=torch.long, device=src_tokens.device)
        else:
            actual_src_lengths = padding_mask.sum(dim=1)
        for conv in self.conv_layers:
            x = conv(x)
            x = nn.functional.glu(x, dim=1)
            actual_src_lengths = self.subsampled_sequence_len(
                actual_src_lengths,
                kernel_size=conv.kernel_size[0],
                padding=conv.padding[0],
                stride=conv.stride[0])
            # The mask must span the whole time dimension of x: when every sample of the batch
            # is padded, the longest valid length is shorter than x.size(2).
            x = x.masked_fill(
                self.lengths_to_padding_mask(actual_src_lengths, max_len=x.size(2)).unsqueeze(1), 0)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # -> T x B x (C x D)
        return x
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding of "Attention Is All You Need":
        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
    i.e. sines on the even feature indices and cosines on the odd ones, exactly as in the paper.
    The Fairseq version differs slightly; see
    :func:`~fairseq.modules.sinusoidal_positional_embedding.SinusoidalPositionalEmbedding.get_embedding`
    for details.

    Args:
        d_model (int): size of each encoding vector.
        max_len (int): number of positions precomputed.
    """

    def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        # 1 / 10000^(2i / d_model) for each even feature index 2i
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        steps = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        angles = steps * inv_freq  # max_len x ceil(d_model / 2)
        table = torch.zeros(max_len, d_model, requires_grad=False)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Stored as a buffer (not a parameter) of shape 1 x max_len x d_model.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, length: int) -> Tensor:
        """Return the encodings of the first ``length`` positions (1 x length x d_model)."""
        first_positions = self.pe[:, :length]
        return first_positions
class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in the `"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
    <https://arxiv.org/pdf/1901.02860.pdf>`_.

    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout applied to the attention weights
        batch_unsafe_relative_shift (bool): if True, use the efficient relative shift (as in fairseq), whose
            results depend on the amount of padding; otherwise use the row-wise, padding-aware shift.

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, time1, time2): boolean Tensor, True on the positions to be masked.
          Optional: if None, no position is masked.

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-head attention module.
        - **attn** (batch, num_heads, time1, time2): the attention weights.
    """

    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 16,
            dropout_p: float = 0.1,
            batch_unsafe_relative_shift: bool = False
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        # NOTE: scores are scaled by sqrt(d_model), not sqrt(d_head); changing it would alter trained models
        self.sqrt_dim = math.sqrt(d_model)
        self.query_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.query_proj.weight)
        nn.init.zeros_(self.query_proj.bias)
        self.key_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.key_proj.weight)
        nn.init.zeros_(self.key_proj.bias)
        self.value_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.zeros_(self.value_proj.bias)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.pos_proj.weight)
        self.dropout = nn.Dropout(p=dropout_p)
        # u and v are the trainable parameters of the Transformer-XL attention computation
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        nn.init.xavier_uniform_(self.u_bias)
        nn.init.xavier_uniform_(self.v_bias)
        self.out_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)
        self.relative_shift_func = self._relative_shift_unsafe if batch_unsafe_relative_shift else self._relative_shift

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
        # Attention weights computation using Q + u as in Transformer-XL
        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        # Relative positional weights computation using Q + v as in Transformer-XL
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Right shifting mechanism described in Transformer-XL
        pos_score = self.relative_shift_func(pos_score, mask)
        # Final attention weights obtained summing the attention with its relative positional embeddings
        score = (content_score + pos_score) / self.sqrt_dim
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the heads
            # -1e9 is not representable in reduced precision: the fill value depends on the scores' dtype
            score.masked_fill_(mask, -1e9 if score.dtype == torch.float32 else -1e4)
        attn = F.softmax(score, dim=-1)
        # set to 0.0 all attention weights of padding elements
        if mask is not None:
            attn = attn.masked_fill(mask, 0.0)
        attn = self.dropout(attn)
        # Attention computation
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
        """
        This method performs the relative shift operation row-wise.
        Although inefficient, it shifts each row accounting for its padding,
        so that the result does not change depending on whether a given row
        is padded or not. If ``padding_mask`` is None, no row is considered padded.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        assert seq_length1 == seq_length2, "Currently we support only self-attention"
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        if padding_mask is None:
            seq_lengths = [seq_length1] * batch_size
        else:
            # Number of non-padding rows, read from the first column of the mask
            # (assumes right padding, i.e. the first position is never padding).
            seq_lengths = (seq_length1 - (padding_mask[:, :, 0]).sum(-1)).tolist()
        for b_i, seq_len in enumerate(seq_lengths):
            padded_batch_pos_scores = padded_pos_score[b_i, :, :seq_len, :seq_len + 1]
            padded_batch_pos_scores = padded_batch_pos_scores.reshape(num_heads, seq_len + 1, seq_len)
            pos_score[b_i, :, :seq_len, :seq_len] = padded_batch_pos_scores[:, 1:, :]
        if padding_mask is not None:
            pos_score.masked_fill_(padding_mask.unsqueeze(1), 0.0)
        return pos_score

    def _relative_shift_unsafe(self, pos_score: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
        """
        This implementation reflects other open source ones (e.g. fairseq), which
        shift the values from the row above in the batch. Although efficient,
        this leads to inconsistencies in the results, as the same row has different
        values according to whether it is padded (and how much it is) or not.
        ``padding_mask`` is not used.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        return pos_score
class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from
    Transformer-XL, the relative sinusoidal positional encoding scheme. The relative positional encoding allows the
    self-attention module to generalize better on different input lengths and the resulting encoder is more robust
    to the variance of the utterance length. The input is layer-normalized before the attention (pre-norm) and
    dropout is applied to the output.

    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
        batch_unsafe_relative_shift (bool): whether to use the padding-dependent (but faster) relative shift

    Inputs: x, encoder_padding_mask, output_attention
        - **x** (batch, time, dim): Tensor containing input vector
        - **encoder_padding_mask** (batch, time): non-zero (or True) on padding positions.
          If None, no position is considered padding.
        - **output_attention** (bool): whether to return the attention weights

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self attention module.
        - **attn**: the attention weights if ``output_attention`` is True, None otherwise.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1, batch_unsafe_relative_shift: bool = False):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p, batch_unsafe_relative_shift)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
            self, x: Tensor, encoder_padding_mask: Optional[Tensor] = None, output_attention: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        batch_size, seq_length, _ = x.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
        if encoder_padding_mask is None:
            encoder_padding_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=x.device)
        # we need attention padding mask (attn_mask) to be applied during the attention calculation,
        # we obtain it from the encoder_padding_mask (B x T) by repeating it T times and
        # taking the logical or to correctly mask both T x T dimensions
        att_mask = encoder_padding_mask.unsqueeze(1).repeat([1, seq_length, 1])
        att_mask = att_mask.logical_or(att_mask.transpose(1, 2))  # B x T x T
        x = self.layer_norm(x)
        outputs, attn = self.attention(x, x, x, pos_embedding=pos_embedding, mask=att_mask)
        return self.dropout(outputs), (attn if output_attention else None)
class FeedForwardModule(nn.Module):
    """
    Pre-norm feed-forward module of the Conformer block:
    LayerNorm -> Linear (dim -> dim * expansion_factor) -> Swish (SiLU) -> Dropout
    -> Linear (back to dim) -> Dropout.
    The residual connection is not part of this module.

    Args:
        encoder_dim (int): Dimension of conformer encoder
        expansion_factor (int): Expansion factor of the hidden layer.
        dropout_p (float): Ratio of dropout

    Inputs: x (batch, time, dim)
    Outputs: (batch, time, dim) Tensor produced by the feed forward module.
    """

    def __init__(
            self,
            encoder_dim: int = 512,
            expansion_factor: int = 4,
            dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        hidden_dim = encoder_dim * expansion_factor
        self.layernorm = nn.LayerNorm(encoder_dim)
        self.dropout_module = nn.Dropout(p=dropout_p)
        self.first_linear = self._xavier_linear(encoder_dim, hidden_dim)
        self.second_linear = self._xavier_linear(hidden_dim, encoder_dim)

    @staticmethod
    def _xavier_linear(in_features: int, out_features: int) -> nn.Linear:
        """Create a biased linear layer with Xavier-uniform weights and zero bias."""
        layer = nn.Linear(in_features, out_features, bias=True)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.dropout_module(F.silu(self.first_linear(self.layernorm(x))))
        return self.dropout_module(self.second_linear(hidden))
class ConformerConvModule(nn.Module):
    """
    Conformer convolution module: it starts with a pointwise convolution and a gated linear unit (GLU),
    followed by a single 1-D depthwise convolution. Batchnorm is applied just after the convolution
    to aid training deep models. Then, the Swish (SiLU) activation function is applied, followed by a second
    pointwise convolution. Dropout is applied at the end.
    Padding positions are reset to 0 after each step that may mix or shift values, so that they do not
    leak into the valid positions through the depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int, optional): Size of the depthwise convolving kernel (must be odd). Default: 31
        expansion_factor (int, optional): expansion of the first pointwise convolution (only 2 is supported)
        dropout_p (float, optional): probability of dropout
        no_syncbatchnorm (bool, optional): use BatchNorm1d instead of SyncBatchNorm

    Inputs: x, encoder_padding_mask
        - **x** (batch, time, dim): Tensor containing input sequences
        - **encoder_padding_mask** (batch, time): non-zero on padding positions, or None

    Outputs: (batch, time, dim) Tensor produced by the conformer convolution module, with the same dtype
        as the input (no implicit cast to float32).
    """

    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 31,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
            no_syncbatchnorm: bool = False,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
        self.layernorm = nn.LayerNorm(in_channels)
        self.batchnorm = nn.SyncBatchNorm(in_channels) if not no_syncbatchnorm else nn.BatchNorm1d(in_channels)
        self.first_pointwise_conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels * expansion_factor,
            kernel_size=(1, ),
            stride=(1, ),
            padding=0,
            bias=True,
        )
        self.second_pointwise_conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=(1, ),
            stride=(1, ),
            padding=0,
            bias=True,
        )
        self.depthwise_conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=(kernel_size, ),
            stride=(1, ),
            groups=in_channels,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        self.dropout_module = nn.Dropout(p=dropout_p)

    @staticmethod
    def _mask_padding(x: Tensor, bool_padding_mask: Optional[Tensor]) -> Tensor:
        """Zero the padding positions of ``x`` (B x C x T), preserving its dtype."""
        if bool_padding_mask is None:
            return x
        return x.masked_fill(bool_padding_mask, 0.0)

    def forward(self, x: Tensor, encoder_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = self.layernorm(x).transpose(1, 2)  # B x T x C -> B x C x T
        x = self.first_pointwise_conv1d(x)
        x = F.glu(x, dim=1)
        bool_padding_mask = None
        if encoder_padding_mask is not None:
            bool_padding_mask = encoder_padding_mask.unsqueeze(1).bool()  # B x 1 x T
        x = self._mask_padding(x, bool_padding_mask)
        x = self.depthwise_conv1d(x)
        x = self._mask_padding(x, bool_padding_mask)
        x = self.batchnorm(x)
        x = self._mask_padding(x, bool_padding_mask)
        x = F.silu(x)
        x = self.second_pointwise_conv1d(x)
        x = self._mask_padding(x, bool_padding_mask)
        x = self.dropout_module(x)
        return x.transpose(1, 2)
class ConformerEncoderLayer(nn.Module):
    """
    Conformer block: two half-step (Macaron-style) feed-forward modules sandwich the multi-headed
    self-attention module and the convolution module, each wrapped in a residual connection, and the
    result is layer-normalized.

    All hyper-parameters are read from the configuration:
        d_model, encoder_attention_heads, feed_forward_expansion_factor, conv_expansion_factor,
        conformer_feedforward_dropout, conformer_attention_dropout, conformer_conv_dropout,
        conformer_conv_kernel_size, conformer_half_step_residual (if True the feed-forward outputs are
        scaled by 0.5 before the residual sum), no_syncbatchnorm and, optionally,
        batch_unsafe_relative_shift (default False).

    Inputs: x, encoder_padding_mask, output_attentions
        - **x** (time, batch, dim): Tensor containing input vector
        - **encoder_padding_mask** (batch, time): non-zero on padding positions
        - **output_attentions** (bool): whether to return the self-attention weights

    Returns: outputs, attn
        - **outputs** (time, batch, dim): Tensor produced by the conformer block.
        - **attn**: the self-attention weights if ``output_attentions`` is True, None otherwise.
    """

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__()
        # Hyper-parameters, kept as attributes of the layer.
        self.encoder_dim = config.d_model
        self.num_attention_heads = config.encoder_attention_heads
        self.feed_forward_expansion_factor = config.feed_forward_expansion_factor
        self.conv_expansion_factor = config.conv_expansion_factor
        self.feed_forward_dropout_p = config.conformer_feedforward_dropout
        self.attention_dropout_p = config.conformer_attention_dropout
        self.conv_dropout_p = config.conformer_conv_dropout
        self.conv_kernel_size = config.conformer_conv_kernel_size
        self.half_step_residual = config.conformer_half_step_residual
        self.no_syncbatchnorm = config.no_syncbatchnorm
        self.batch_unsafe_relative_shift = getattr(config, 'batch_unsafe_relative_shift', False)
        self.feed_forward_residual_factor = 0.5 if self.half_step_residual else 1
        # Sub-modules (creation order determines the parameter order in the state dict).
        self.first_feed_forward = self._build_feed_forward()
        self.attention = MultiHeadedSelfAttentionModule(
            d_model=self.encoder_dim,
            num_heads=self.num_attention_heads,
            dropout_p=self.attention_dropout_p,
            batch_unsafe_relative_shift=self.batch_unsafe_relative_shift,
        )
        self.conv_module = ConformerConvModule(
            in_channels=self.encoder_dim,
            kernel_size=self.conv_kernel_size,
            expansion_factor=self.conv_expansion_factor,
            dropout_p=self.conv_dropout_p,
            no_syncbatchnorm=self.no_syncbatchnorm,
        )
        self.second_feed_forward = self._build_feed_forward()
        self.layernorm = nn.LayerNorm(self.encoder_dim)

    def _build_feed_forward(self):
        """Create one of the two (identically configured) feed-forward modules of the block."""
        return FeedForwardModule(
            encoder_dim=self.encoder_dim,
            expansion_factor=self.feed_forward_expansion_factor,
            dropout_p=self.feed_forward_dropout_p,
        )

    def forward(
            self, x: Tensor, encoder_padding_mask: Tensor, output_attentions: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        hidden = x.transpose(0, 1)  # T x B x C -> B x T x C
        hidden = self.first_feed_forward(hidden) * self.feed_forward_residual_factor + hidden
        attn_out, attn = self.attention(hidden, encoder_padding_mask, output_attentions)
        hidden = attn_out + hidden
        hidden = self.conv_module(hidden, encoder_padding_mask) + hidden
        hidden = self.second_feed_forward(hidden) * self.feed_forward_residual_factor + hidden
        # back to T x B x C
        return self.layernorm(hidden).transpose(1, 0), attn
class CTCCompressStrategy:
    """
    Strategies to compress a sequence of vectors according to the CTC predictions.

    Every strategy takes:
        - prob_ctc: the CTC probabilities, indexed as [batch, time, token] (B x T x V in the encoder)
        - predicted: for each sample, a list of (predicted_token, num_consecutive_time_steps) tuples
        - dtype, device: dtype and device of the returned weights matrix
    and returns a B x T x T' weights matrix, whose column t' of each sample holds the weights used to
    combine the input time steps into the t'-th output vector, together with the list of output lengths.
    """
    # Number of consecutive time steps merged into one vector by the `fixed` strategy.
    # It is a class attribute: assigning it (as ConformerEncoder.__init__ does) affects every user.
    FIXED_RATIO = 4

    @staticmethod
    def new_lengths(batch_predicted):
        """Return, for each sample, the number of (token, count) groups, i.e. its compressed length."""
        return [len(p) for p in batch_predicted]

    @staticmethod
    def avg(prob_ctc, predicted, dtype, device):
        """Replace each group of consecutive time steps with the same prediction by their plain average."""
        new_lengths = CTCCompressStrategy.new_lengths(predicted)
        new_maxlen = max(new_lengths)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
        for b_idx, pred in enumerate(predicted):
            processed_inputs_cnt = 0
            for t_idx, same in enumerate(pred):
                new_processed_inputs_cnt = processed_inputs_cnt + same[1]
                weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = 1.0 / same[1]
                processed_inputs_cnt = new_processed_inputs_cnt
        return weights_matrix.to(device), new_lengths

    @staticmethod
    def weighted(prob_ctc, predicted, dtype, device):
        """Replace each group by the average of its vectors weighted by the probability of the predicted token."""
        new_lengths = CTCCompressStrategy.new_lengths(predicted)
        new_maxlen = max(new_lengths)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
        for b_idx, pred in enumerate(predicted):
            processed_inputs_cnt = 0
            for t_idx, same in enumerate(pred):
                new_processed_inputs_cnt = processed_inputs_cnt + same[1]
                # Get the probabilities of the prediction for the different time steps as weight
                weights = prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]]
                weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
                    weights / weights.sum()
                processed_inputs_cnt = new_processed_inputs_cnt
        return weights_matrix, new_lengths

    @staticmethod
    def softmax(prob_ctc, predicted, dtype, device):
        """Like `weighted`, but the weights are the softmax (over the group) of the predicted-token probabilities."""
        new_lengths = CTCCompressStrategy.new_lengths(predicted)
        new_maxlen = max(new_lengths)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
        for b_idx, pred in enumerate(predicted):
            processed_inputs_cnt = 0
            for t_idx, same in enumerate(pred):
                new_processed_inputs_cnt = processed_inputs_cnt + same[1]
                # Get the probabilities of the prediction for the different time steps as weight.
                # The slice is 1-D: make the softmax dimension explicit (the implicit one is deprecated).
                weights = F.softmax(prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]], dim=0)
                weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
                    weights / weights.sum()
                processed_inputs_cnt = new_processed_inputs_cnt
        return weights_matrix, new_lengths

    @staticmethod
    def fixed(prob_ctc, predicted, dtype, device):
        """Ignore the predicted tokens and average chunks of FIXED_RATIO consecutive time steps
        (the last chunk of each sample may be shorter)."""
        new_maxlen = math.ceil(prob_ctc.shape[1] / CTCCompressStrategy.FIXED_RATIO)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
        new_lengths = []
        for b_idx, pred in enumerate(predicted):
            # number of valid time steps of the sample
            original_len = sum(x[1] for x in pred)
            new_len = 0
            for new_t_idx in range(new_maxlen):
                processed_inputs_cnt = new_t_idx * CTCCompressStrategy.FIXED_RATIO
                processed_inputs_cnt_end = processed_inputs_cnt + CTCCompressStrategy.FIXED_RATIO
                if processed_inputs_cnt_end > original_len:
                    processed_inputs_cnt_end = original_len
                weights_matrix[b_idx, processed_inputs_cnt:processed_inputs_cnt_end, new_t_idx] = \
                    1.0 / (processed_inputs_cnt_end - processed_inputs_cnt)
                new_len += 1
                if processed_inputs_cnt_end == original_len:
                    break
            new_lengths.append(new_len)
        return weights_matrix.to(device), new_lengths
class ConformerEncoderDecoderPreTrainedModel(Speech2TextPreTrainedModel):
    """
    Base class of the Conformer models: it inherits all the behavior of [`Speech2TextPreTrainedModel`],
    only replacing its configuration class with [`Speech2TextConformerConfig`].
    """

    config_class = Speech2TextConformerConfig
class ConformerEncoder(ConformerEncoderDecoderPreTrainedModel):
    """
    Conformer encoder consisting of *config.encoder_layers* layers. Each layer is a
    [`ConformerEncoderLayer`].
    The input features are first subsampled by a [`Conv1dSubsampler`]. If ``config.ctc_compress_strategy``
    is not "none", after the ``config.ctc_encoder_layer``-th layer the sequence is compressed by merging
    consecutive vectors with the same CTC prediction (see [`CTCCompressStrategy`]).
    Args:
        config: Speech2TextConformerConfig
    """

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__(config)
        self.dropout = config.dropout
        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.subsample = Conv1dSubsampler(config)
        self.layers = nn.ModuleList([ConformerEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.ctc_flag = False
        if config.ctc_compress_strategy != "none":
            self.ctc_flag = True
            self.ctc_fc = nn.Linear(config.encoder_embed_dim, config.src_vocab_size)
            self.ctc_layer = config.ctc_encoder_layer
            self.ctc_compress_method = getattr(CTCCompressStrategy, config.ctc_compress_strategy)
            self.ctc_compress_max_out_size = config.ctc_compress_max_out_size
            # NOTE: FIXED_RATIO is a class attribute, so this affects every user of CTCCompressStrategy
            CTCCompressStrategy.FIXED_RATIO = config.ctc_compress_fixed_ratio
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def ensure_max_ctc_out_len(self, batch_predicted):
        """
        Ensures that the output of the CTC compression is not longer than the ctc_compress_max_out_size.
        If there are samples violating this constraint, consecutive predictions are merged so to shorten the sentence.
        E.g. if the ctc_compress_max_out_size is set to 3, and the output of the CTC compression would be
        long 5, the first and second predictions are merged, as well as the third and the fourth. So, the
        corresponding vectors will be merged according to the CTC compression strategy.
        A non-positive ctc_compress_max_out_size disables the check.
        """
        if self.ctc_compress_max_out_size > 0:
            def merge_sublist(elements):
                """
                Takes a list of Tuples (predicted_element, num_corresponding_vectors) and returns
                a single tuple with the predicted_element having the highest number of corresponding_vectors
                (in case of a tie, the first reaching that number is returned) and the total sum of the
                num_corresponding_vectors.
                E.g. if the input is [(a, 3), (b, 5), (c, 6), (a, 4)], the output will be (a, 18).
                """
                sum_num_vectors = 0
                max_element = None
                max_element_cnt = 0
                temp_dict = {}
                for predicted_element, num_corresponding_vectors in elements:
                    temp_dict[predicted_element] = temp_dict.get(predicted_element, 0) + num_corresponding_vectors
                    if temp_dict[predicted_element] > max_element_cnt:
                        max_element_cnt = temp_dict[predicted_element]
                        max_element = predicted_element
                    sum_num_vectors += num_corresponding_vectors
                return max_element, sum_num_vectors

            for b_idx, p in enumerate(batch_predicted):
                pred_len = len(p)
                if pred_len > self.ctc_compress_max_out_size:
                    reduction_factor = math.ceil(pred_len / self.ctc_compress_max_out_size)
                    batch_predicted[b_idx] = [
                        merge_sublist(p[i:i + reduction_factor]) for i in range(0, pred_len, reduction_factor)
                    ]
        return batch_predicted

    def average_same_ctc_features(self, x_ctc, x, input_lengths):
        """
        Compresses ``x`` by merging consecutive time steps with the same CTC prediction.
        Args:
            x_ctc: T x B x V CTC logits.
            x: T x B x C encoder states.
            input_lengths: (B,) tensor with the number of valid time steps of each sample.
        Returns:
            the compressed states (T' x B x C) and their lengths, as a tensor like ``input_lengths``.
        """
        with torch.no_grad():
            # The alignment, and hence the weights matrix, is not differentiable.
            batch_predicted = []
            prob_ctc = F.softmax(x_ctc, dim=-1).transpose(0, 1)  # from T x B x D to B x T x D
            for b in range(prob_ctc.shape[0]):
                predicted = prob_ctc[b][: input_lengths[b]].argmax(-1).tolist()
                batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
            batch_predicted = self.ensure_max_ctc_out_len(batch_predicted)
            weights_matrix, new_lengths = self.ctc_compress_method(
                prob_ctc, batch_predicted, x.dtype, x.device)
        # The combination is computed outside no_grad so that gradients flow back to x.
        # x is T x B x C -> B x C x T; weights_matrix is B x T x T'
        compressed_output = x.permute(1, 2, 0).bmm(weights_matrix)  # B x C x T'
        new_lengths = torch.tensor(new_lengths, dtype=input_lengths.dtype, device=input_lengths.device)
        return compressed_output.permute(2, 0, 1), new_lengths

    @staticmethod
    def lengths_to_padding_mask(lens: torch.LongTensor, max_len: Optional[int] = None) -> Tensor:
        """Build a B x max_len boolean mask, True on the positions >= the length of each sample.
        ``max_len`` defaults to the maximum of ``lens``."""
        bsz = lens.size(0)
        if max_len is None:
            max_len = int(torch.max(lens).item())
        positions = torch.arange(max_len, device=lens.device).view(1, max_len)
        return positions.expand(bsz, -1) >= lens.view(bsz, 1)

    def apply_ctc(self, x, input_lengths):
        """Compute the CTC logits of ``x`` (T x B x C) and compress it; returns the compressed states,
        the CTC logits and the B x T' boolean padding mask of the compressed states."""
        x_ctc = self.ctc_fc(x)
        x, input_lengths = self.average_same_ctc_features(x_ctc, x, input_lengths)
        # The mask must span the whole compressed sequence, which (e.g. with the "fixed" strategy)
        # can be longer than the longest compressed sample.
        padding_mask = ConformerEncoder.lengths_to_padding_mask(input_lengths, max_len=x.shape[0])
        return x, x_ctc, padding_mask

    def forward(
            self,
            input_features,
            attention_mask=None,
            head_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
                Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
                padding and conversion into a tensor of type `torch.FloatTensor`. See
                [`~Speech2TextFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
                `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Not supported yet: must be None.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        inputs_embeds = self.subsample(input_features, attention_mask)  # T x B x C
        inputs_embeds = self.embed_scale * inputs_embeds
        # subsample attention mask if necessary
        if attention_mask is not None:
            attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[0], attention_mask)
        hidden_states = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)
        # padding mask (B x T): 1 on padding positions
        if attention_mask is not None:
            padding_mask = attention_mask.ne(1).long()
        else:
            seq_len, bsz = inputs_embeds.shape[:2]
            padding_mask = torch.zeros((bsz, seq_len), dtype=torch.long, device=inputs_embeds.device)
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # TODO: implement head mask
        assert head_mask is None, "Head masking is not yet implemented for Conformer model"
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states.transpose(0, 1),)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    padding_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    padding_mask,
                    output_attentions=output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
            if self.ctc_flag and self.ctc_layer == idx + 1:
                # number of valid positions of each sample before the compression
                input_lengths = (~padding_mask.bool()).sum(dim=1)
                # the CTC logits (second output) are not returned by this encoder
                hidden_states, _, padding_mask = self.apply_ctc(hidden_states, input_lengths)
        hidden_states = hidden_states.transpose(0, 1)  # T x B x C -> B x T x C
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
@add_start_docstrings(
    "The bare Conformer Model outputting raw hidden-states without any specific head on top.",
    CONFORMER_START_DOCSTRING,
)
class ConformerEncoderDecoderModel(ConformerEncoderDecoderPreTrainedModel):
    # Conformer encoder followed by a Speech2Text Transformer decoder, with no language modeling head.

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__(config)
        self.encoder = ConformerEncoder(config)
        self.decoder = Speech2TextDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder token embedding module."""
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder token embedding module with ``value``."""
        self.decoder.embed_tokens = value

    def get_encoder(self):
        """Return the Conformer encoder."""
        return self.encoder

    def get_decoder(self):
        """Return the Transformer decoder."""
        return self.decoder

    @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, AutoModel
        >>> from datasets import load_dataset

        >>> # AutoModel is registered to ConformerEncoderDecoderForConditionalGeneration; its
        >>> # `.model` attribute is this bare encoder-decoder model.
        >>> model = AutoModel.from_pretrained("FBK-MT/balbetto-asr-small-test", trust_remote_code=True).model
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("FBK-MT/balbetto-asr-small-test")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(
        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
        ... )
        >>> input_features = inputs.input_features
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # downsample encoder attention mask to the encoder output length (dim 1 of the B x T x C output)
        # NOTE(review): the encoder may shrink sequences via apply_ctc; confirm that
        # _get_feature_vector_attention_mask yields correct per-sample lengths in that case.
        if attention_mask is not None:
            encoder_attention_mask = self._get_feature_vector_attention_mask(
                encoder_outputs[0].shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=encoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # A caller may pass a BaseModelOutput as encoder_outputs together with
            # return_dict=False; concatenating a tuple with a ModelOutput raises a
            # TypeError, so convert it to a plain tuple first.
            if isinstance(encoder_outputs, BaseModelOutput):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + tuple(encoder_outputs)

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
@add_start_docstrings(
    "The Conformer Model with a language modeling head.",
    CONFORMER_START_DOCSTRING,
)
class ConformerEncoderDecoderForConditionalGeneration(ConformerEncoderDecoderPreTrainedModel, GenerationMixin):
    # Conformer encoder-decoder with a bias-free linear projection onto the vocabulary.
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__(config)
        self.model = ConformerEncoderDecoderModel(config)
        self.lm_head = nn.Linear(config.d_model, self.config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        """Return the Conformer encoder of the wrapped encoder-decoder model."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the Transformer decoder of the wrapped encoder-decoder model."""
        return self.model.get_decoder()

    def get_output_embeddings(self):
        """Return the language modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language modeling head with ``new_embeddings``."""
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> import transformers
        >>> from datasets import load_dataset

        >>> pipe = transformers.pipeline(
        ...     "automatic-speech-recognition",
        ...     model='FBK-MT/balbetto-asr-small-test',
        ...     feature_extractor='FBK-MT/balbetto-asr-small-test',
        ...     trust_remote_code=True)
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> # the ASR pipeline returns a dict whose "text" entry holds the decoded transcription
        >>> transcription = pipe(ds[0]["audio"])["text"]
        >>> transcription
        'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Teacher forcing: derive decoder inputs from the labels when none are given.
        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Move labels next to the logits (a no-op on a single device). Use reshape rather
            # than view so that non-contiguous label tensors (e.g. slices) are accepted too.
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(lm_logits.reshape(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder each layer's cached states along dim 0 according to ``beam_idx``."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
# Module-level registration of the Conformer classes with the transformers Auto* machinery.
# register_for_auto_class marks the custom classes for the saved `auto_map`, presumably so that
# they can be loaded from the Hub with `trust_remote_code=True` (confirm against the shipped config).
Speech2TextConformerConfig.register_for_auto_class()
# NOTE(review): both AutoModel and AutoModelForSpeechSeq2Seq resolve to the LM-head class,
# not to the bare ConformerEncoderDecoderModel.
ConformerEncoderDecoderForConditionalGeneration.register_for_auto_class("AutoModel")
ConformerEncoderDecoderForConditionalGeneration.register_for_auto_class("AutoModelForSpeechSeq2Seq")
# In-process registration of the "conformer_encoder_decoder" model type with the Auto factories.
transformers.AutoConfig.register("conformer_encoder_decoder", Speech2TextConformerConfig)
transformers.AutoModel.register(
Speech2TextConformerConfig, ConformerEncoderDecoderForConditionalGeneration)
transformers.AutoModelForSpeechSeq2Seq.register(
Speech2TextConformerConfig, ConformerEncoderDecoderForConditionalGeneration)
# Reuse the stock Speech2Text processor for this config.
transformers.AutoProcessor.register(Speech2TextConformerConfig, Speech2TextProcessor)
# Mutates a transformers-internal name mapping, presumably so the ASR pipeline treats this model
# type as speech seq2seq. This relies on private internals; verify on transformers upgrades.
transformers.models.auto.modeling_auto.MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES['conformer_encoder_decoder'] = \
"ConformerEncoderDecoderForConditionalGeneration"
# Stock Speech2Text tokenizer (the tuple is presumably (slow, fast); no fast class is given)
# and stock Speech2Text feature extractor.
transformers.TOKENIZER_MAPPING.register(Speech2TextConformerConfig, (Speech2TextTokenizer, None))
transformers.FEATURE_EXTRACTOR_MAPPING.register(Speech2TextConformerConfig, Speech2TextFeatureExtractor)