import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import SequenceClassifierOutput

from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel

class CustomModelForAudioClassification(nn.Module):
    """Wav2vec 2.0 encoder followed by a multi-level classification head."""

    def __init__(self, config):
        super().__init__()
        # The multi-level classifier combines features from every encoder
        # layer, so the upstream model must expose all of its hidden states.
        assert config.output_hidden_states, "The upstream model must return all hidden states"
        self.config = config
        self.encoder = Wav2VecWrapper(config)
        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.FloatTensor],
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # Reuse precomputed encoder outputs when the caller provides them;
        # otherwise run the waveform through the encoder.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                length=length,
            )
        logits = self.classifier(**encoder_outputs)
        # Loss is left to the caller (e.g. a training loop computing
        # cross-entropy over the returned logits).
        loss = None
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs["encoder_hidden_states"],
        )
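

# A minimal usage sketch, assuming a Hugging Face Wav2Vec2-style config and
# that `Wav2VecWrapper` returns a dict whose keys match the classifier's
# keyword arguments (including "encoder_hidden_states"). The checkpoint name
# and the 16 kHz waveforms below are illustrative placeholders, not part of
# this repository.
#
#   from transformers import AutoConfig
#
#   config = AutoConfig.from_pretrained(
#       "facebook/wav2vec2-base",
#       output_hidden_states=True,  # required by the assert in __init__
#   )
#   model = CustomModelForAudioClassification(config)
#
#   waveform = torch.randn(2, 16000)        # batch of two 1-second clips
#   lengths = torch.tensor([16000, 12000])  # valid samples per clip
#   with torch.no_grad():
#       out = model(waveform, length=lengths)
#   print(out.logits.shape)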