from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    PreTrainedModel,
    WhisperConfig,
)
from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
from transformers.modeling_outputs import BaseModelOutput
import torch.nn.functional as F
import torch.nn as nn
import torch
import os
from dataclasses import dataclass
from typing import Optional
import json


@dataclass
class CustomModelOutput(BaseModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    head_preds: torch.FloatTensor = None
    labels_head: Optional[torch.FloatTensor] = None
    whisper_logits: torch.FloatTensor = None
    preds: Optional[torch.Tensor] = None


# Define a new head (e.g., a classification layer)
class LinearHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearHead, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


class FCNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(FCNN, self).__init__()
        hidden_dim = 2 * input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class WhiStress(PreTrainedModel):
    config_class = WhisperConfig
    model_input_names = ["input_features", "labels_head", "whisper_labels"]

    def __init__(
        self,
        config: WhisperConfig,
        layer_for_head: Optional[int] = None,
        whisper_backbone_name="openai/whisper-small.en",
    ):
        super().__init__(config)
        self.whisper_backbone_name = whisper_backbone_name
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            self.whisper_backbone_name,
        ).eval()
        self.processor = WhisperProcessor.from_pretrained(self.whisper_backbone_name)
        input_dim = self.whisper_model.config.d_model  # model's hidden size
        output_dim = 2  # number of classes for the new head
        config = self.whisper_model.config
        # add an additional decoder block using the existing Whisper config
        self.additional_decoder_block = WhisperDecoderLayer(config)
        self.classifier = FCNN(input_dim, output_dim)
        # add weighted loss for CE
        neg_weight = 1.0
        pos_weight = 0.7 / 0.3
        class_weights = torch.tensor([neg_weight, pos_weight])
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
        self.layer_for_head = -1 if layer_for_head is None else layer_for_head

    def to(self, device: str = ("cuda" if torch.cuda.is_available() else "cpu")):
        self.whisper_model.to(device)
        self.additional_decoder_block.to(device)
        self.classifier.to(device)
        super().to(device)
        return self

    def load_model(self, save_dir=None):
        # load only the classifier and extra decoder layer (saved locally)
        if save_dir is not None:
            print("loading model from:", save_dir)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.classifier.load_state_dict(
                torch.load(
                    os.path.join(save_dir, "classifier.pt"),
                    weights_only=False,
                    map_location=torch.device(device),
                )
            )
            self.additional_decoder_block.load_state_dict(
                torch.load(
                    os.path.join(save_dir, "additional_decoder_block.pt"),
                    weights_only=False,
                    map_location=torch.device(device),
                )
            )
            # read metadata.json to recover which layer feeds the head
            # expected format: {"layer_for_head": 9}
            with open(os.path.join(save_dir, "metadata.json"), "r") as f:
                metadata = json.load(f)
            self.layer_for_head = metadata["layer_for_head"]
        return
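
    # NOTE: illustrative sketch, not part of the original code. It shows the
    # inverse of load_model and the directory layout that load_model expects
    # (classifier.pt, additional_decoder_block.pt, metadata.json). The method
    # name is hypothetical.
    def save_model_sketch(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        torch.save(
            self.classifier.state_dict(), os.path.join(save_dir, "classifier.pt")
        )
        torch.save(
            self.additional_decoder_block.state_dict(),
            os.path.join(save_dir, "additional_decoder_block.pt"),
        )
        with open(os.path.join(save_dir, "metadata.json"), "w") as f:
            json.dump({"layer_for_head": self.layer_for_head}, f)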

    def train(self, mode: Optional[bool] = True):
        # keep Whisper frozen; only the extra decoder block and classifier are trained
        self.whisper_model.eval()
        # the Whisper backbone does not require gradients
        for param in self.whisper_model.parameters():
            param.requires_grad = False
        for param in self.additional_decoder_block.parameters():
            param.requires_grad = True
        for param in self.classifier.parameters():
            param.requires_grad = True
        self.additional_decoder_block.train()
        self.classifier.train()

    def eval(self):
        self.whisper_model.eval()
        self.additional_decoder_block.eval()
        self.classifier.eval()

    def forward(
        self,
        input_features,
        attention_mask=None,
        decoder_input_ids=None,
        labels_head=None,
        whisper_labels=None,
    ):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.whisper_model.eval()
        # pass the inputs through the frozen Whisper backbone
        backbone_outputs = self.whisper_model(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            labels=whisper_labels,
        )
        # extract the hidden states of the selected decoder layer
        decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # extract the hidden states of the encoder layer that best encapsulates the prosodic features
        layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # pass the decoder hidden states through the new head (decoder block + classifier)
        additional_decoder_block_outputs = self.additional_decoder_block(
            hidden_states=decoder_last_layer_hidden_states,
            encoder_hidden_states=layer_for_head_hidden_states,
        )
        head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
        # softmax over the classes, then argmax for hard predictions
        head_probs = F.softmax(head_logits, dim=-1)
        preds = head_probs.argmax(dim=-1).to(device)
        if labels_head is not None:
            # mask predictions at ignored positions (label == -100)
            preds = torch.where(
                torch.isin(labels_head, torch.tensor([-100]).to(device)),  # 50257, 50362,
                torch.tensor(-100),
                preds,
            )
        # calculate the cross-entropy loss for the custom head if labels are provided
        loss = None
        if labels_head is not None:
            loss = self.loss_fct(
                head_logits.reshape(-1, head_logits.size(-1)), labels_head.reshape(-1)
            )
        return CustomModelOutput(
            logits=head_logits,
            labels_head=labels_head,
            whisper_logits=backbone_outputs.logits,
            loss=loss,
            preds=preds,
        )
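
    # Shape summary for forward (illustrative, inferred from the code above):
    #   input_features: (batch, n_mels, frames) log-mel spectrogram
    #   decoder/encoder hidden states at layer_for_head: (batch, seq_len, d_model)
    #   head_logits: (batch, seq_len, 2) per-token stress logits
    #   preds: (batch, seq_len) argmax over classes, -100 at ignored positions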

    def generate(
        self,
        input_features,
        max_length=128,
        labels_head=None,
        whisper_labels=None,
        **generate_kwargs,
    ):
        """
        Generate the Whisper output sequence and return the custom head
        predictions aligned with it.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # generate the Whisper output sequence
        whisper_outputs = self.whisper_model.generate(
            input_features=input_features,
            max_length=max_length,
            labels=whisper_labels,
            do_sample=False,
            **generate_kwargs,
        )
        # re-run the backbone on the generated sequence to expose the hidden states
        backbone_outputs = self.whisper_model(
            input_features=input_features,
            decoder_input_ids=whisper_outputs,
            output_hidden_states=True,
        )
        # extract the hidden states of the selected decoder layer
        decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # extract the hidden states of the encoder layer feeding the head
        layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # pass the decoder hidden states through the new head (decoder block + classifier)
        additional_decoder_block_outputs = self.additional_decoder_block(
            hidden_states=decoder_last_layer_hidden_states,
            encoder_hidden_states=layer_for_head_hidden_states,
        )
        head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
        head_probs = F.softmax(head_logits, dim=-1)
        preds = head_probs.argmax(dim=-1).to(device)
        # mask predictions at end-of-text positions
        preds = torch.where(
            torch.isin(whisper_outputs, torch.tensor([50256]).to(device)),  # 50257, 50362,
            torch.tensor(-100),
            preds,
        )
        # preds_shifted = torch.cat((preds[:, 1:], preds[:, :1]), dim=1)
        return preds

    def generate_dual(
        self,
        input_features,
        attention_mask=None,
        max_length=200,
        labels_head=None,
        whisper_labels=None,
        **generate_kwargs,
    ):
        """
        Generate both the Whisper output and custom head output sequences in alignment.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # generate the Whisper output sequence
        whisper_outputs = self.whisper_model.generate(
            input_features=input_features,
            attention_mask=attention_mask,
            max_length=max_length,
            labels=whisper_labels,
            return_dict_in_generate=True,
            **generate_kwargs,
        )
        # re-run the backbone on the generated sequence to expose the hidden states
        backbone_outputs = self.whisper_model(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=whisper_outputs.sequences,
            output_hidden_states=True,
        )
        # extract the hidden states of the selected decoder layer
        decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # extract the hidden states of the encoder layer feeding the head
        layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # pass the decoder hidden states through the new head (decoder block + classifier)
        additional_decoder_block_outputs = self.additional_decoder_block(
            hidden_states=decoder_last_layer_hidden_states,
            encoder_hidden_states=layer_for_head_hidden_states,
        )
        head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
        head_probs = F.softmax(head_logits, dim=-1)
        preds = head_probs.argmax(dim=-1).to(device)
        # mask predictions at end-of-text positions
        preds = torch.where(
            torch.isin(whisper_outputs.sequences, torch.tensor([50256]).to(device)),  # 50257, 50362,
            torch.tensor(-100),
            preds,
        )
        return CustomModelOutput(
            logits=head_logits,
            head_preds=preds,
            whisper_logits=whisper_outputs.logits,
            preds=whisper_outputs.sequences,
        )

    def __str__(self):
        return "WhiStress"