from typing import Optional

import torch
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from torch import nn
from transformers import Wav2Vec2Model


class Wav2Vec2EncoderLayer(nn.Module):
    """Re-implementation of the stock Wav2Vec2 encoder layer (post-layer-norm variant).

    Maps hidden states of shape (batch, frames, hidden_size) to the same shape; with
    output_attentions=True it additionally returns attention weights of shape
    (batch, num_heads, frames, frames).
    """

    def __init__(self, config, i):
        super().__init__()
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i  # index of this layer within the encoder stack

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Self-attention block with residual connection, followed by post-attention layer norm.
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward block with residual connection and final layer norm.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class Wav2VecWrapper(nn.Module):
    """Wraps a pretrained Wav2Vec2Model and swaps its encoder layers for the custom ones above."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        # Swap in the custom encoder layers, then restore the pretrained weights
        # (the submodule names match, so the saved state dict loads directly).
        self.backbone_model.encoder.layers = nn.ModuleList(
            [Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)]
        )
        self.backbone_model.load_state_dict(state_dict)

    def forward(
        self,
        input_features: torch.Tensor,
        length: Optional[torch.Tensor] = None,
    ):
        # The convolutional feature extractor and feature projection run without gradients
        # (the frozen front end of the backbone).
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)
            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)

        # Convert raw-audio lengths into the number of frames the feature extractor produces.
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())

        hidden_states = self.backbone_model.encoder(
            hidden_states,
            output_hidden_states=self.config.output_hidden_states,
        ).hidden_states
        return {'encoder_hidden_states': hidden_states, 'length': length}

    def get_feat_extract_output_lengths(self, input_length):
        """Map a raw-audio length (samples) to the number of feature-extractor output frames."""
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolution output length (no padding, no dilation).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length
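
# Worked example for get_feat_extract_output_lengths, assuming the wav2vec2-base conv stack
# (kernels 10, 3, 3, 3, 3, 2, 2 with strides 5, 2, 2, 2, 2, 2, 2): one second of 16 kHz audio
# (16000 samples) is reduced to 49 encoder frames:
#   16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49
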
def prepare_mask(length, shape, dtype):
    # Boolean padding mask over encoder frames: True for valid positions, False for padding.
    mask = torch.zeros(shape, dtype=dtype)
    # Mark the last valid frame of each sequence, then back-fill with a reversed cumulative sum
    # so every position up to (and including) that frame becomes True.
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
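

# Minimal end-to-end usage sketch. The checkpoint "facebook/wav2vec2-base-960h" and the
# 16 kHz dummy waveforms are assumptions for illustration; any Wav2Vec2 checkpoint whose
# config exposes _name_or_path and output_hidden_states should behave the same way.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "facebook/wav2vec2-base-960h", output_hidden_states=True
    )
    model = Wav2VecWrapper(config).eval()

    # Two utterances at 16 kHz: one full second, one half second zero-padded to equal length.
    waveforms = torch.randn(2, 16000)
    waveforms[1, 8000:] = 0.0
    lengths = torch.tensor([16000, 8000])

    with torch.no_grad():
        outputs = model(waveforms, lengths)

    last_hidden = outputs['encoder_hidden_states'][-1]  # (batch, frames, hidden_size)
    frame_lengths = outputs['length']                    # e.g. tensor([49, 24]) for wav2vec2-base

    # Padding mask over encoder frames, e.g. for masked mean pooling downstream.
    mask = prepare_mask(frame_lengths, last_hidden.shape[:2], last_hidden.dtype)
    print(last_hidden.shape, frame_lengths, mask.sum(dim=-1))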