from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    PreTrainedModel,
    WhisperConfig,
)
from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
from transformers.modeling_outputs import BaseModelOutput
import torch.nn.functional as F
import torch.nn as nn
import torch
import os
from dataclasses import dataclass
from typing import Optional
import json


@dataclass
class CustomModelOutput(BaseModelOutput):
    """Output container for WhiStress: the stress-head results alongside the
    Whisper backbone's logits/sequences."""

    loss: Optional[torch.FloatTensor] = None
    # Raw logits of the classification head.
    logits: torch.FloatTensor = None
    # Argmax predictions of the head (set by generate_dual).
    head_preds: torch.FloatTensor = None
    labels_head: Optional[torch.FloatTensor] = None
    # Logits produced by the Whisper backbone.
    whisper_logits: torch.FloatTensor = None
    # In forward(): head predictions; in generate_dual(): Whisper token ids.
    preds: Optional[torch.Tensor] = None


# Define a new head (e.g., a classification layer)
class LinearHead(nn.Module):
    """Single linear projection head."""

    def __init__(self, input_dim, output_dim):
        super(LinearHead, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


class FCNN(nn.Module):
    """Two-layer MLP head with ReLU and a 2x-wide hidden layer."""

    def __init__(self, input_dim, output_dim):
        super(FCNN, self).__init__()
        hidden_dim = 2 * input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class WhiStress(PreTrainedModel):
    """Frozen Whisper backbone plus a trainable head (one extra Whisper
    decoder block + an FCNN classifier) that predicts a binary per-token
    label aligned with the transcription."""

    config_class = WhisperConfig
    model_input_names = ["input_features", "labels_head", "whisper_labels"]

    def __init__(
        self,
        config: WhisperConfig,
        layer_for_head: Optional[int] = None,
        whisper_backbone_name="openai/whisper-small.en",
    ):
        super().__init__(config)
        self.whisper_backbone_name = whisper_backbone_name
        # Backbone is kept frozen; only the extra head below is trained.
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            self.whisper_backbone_name,
        ).eval()
        self.processor = WhisperProcessor.from_pretrained(self.whisper_backbone_name)
        input_dim = self.whisper_model.config.d_model  # Model's hidden size
        output_dim = 2  # Number of classes or output features for the new head
        config = self.whisper_model.config
        # add additional decoder block using the existing Whisper config
        self.additional_decoder_block = WhisperDecoderLayer(config)
        self.classifier = FCNN(input_dim, output_dim)
        # Weighted CE to counter class imbalance — presumably a ~70/30
        # negative/positive split; TODO confirm against the training data.
        neg_weight = 1.0
        pos_weight = 0.7 / 0.3
        class_weights = torch.tensor([neg_weight, pos_weight])
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
        # Index of the encoder/decoder hidden-state layer fed to the head
        # (-1 = last layer).
        self.layer_for_head = -1 if layer_for_head is None else layer_for_head

    def to(self, device: str = ("cuda" if torch.cuda.is_available() else "cpu")):
        """Move every component to `device`; returns self."""
        self.whisper_model.to(device)
        self.additional_decoder_block.to(device)
        self.classifier.to(device)
        super().to(device)
        return self

    def load_model(self, save_dir=None):
        """Load the locally-saved head weights (classifier + extra decoder
        block) and the layer-index metadata; the backbone is untouched."""
        if save_dir is not None:
            print('loading model from:', save_dir)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from a trusted source.
        self.classifier.load_state_dict(
            torch.load(
                os.path.join(save_dir, "classifier.pt"),
                weights_only=False,
                map_location=torch.device(device),
            )
        )
        self.additional_decoder_block.load_state_dict(
            torch.load(
                os.path.join(save_dir, "additional_decoder_block.pt"),
                weights_only=False,
                map_location=torch.device(device),
            )
        )
        # read and load the layer_for_head.json
        # the json format is {"layer_for_head": 9}
        with open(os.path.join(save_dir, "metadata.json"), "r") as f:
            metadata = json.load(f)
            self.layer_for_head = metadata["layer_for_head"]
        return

    def train(self, mode: Optional[bool] = True):
        """Set train/eval mode on the trainable parts only.

        The Whisper backbone always stays frozen (eval + requires_grad=False);
        the extra decoder block and classifier follow `mode`.

        Fix: the original ignored `mode`, so ``model.train(False)`` still put
        the head in train mode; it also did not return ``self`` as the
        ``nn.Module.train`` contract requires.
        """
        mode = True if mode is None else mode
        self.training = mode
        # freeze whisper and train classifier
        self.whisper_model.eval()
        for param in self.whisper_model.parameters():
            param.requires_grad = False
        for param in self.additional_decoder_block.parameters():
            param.requires_grad = True
        for param in self.classifier.parameters():
            param.requires_grad = True
        self.additional_decoder_block.train(mode)
        self.classifier.train(mode)
        return self

    def eval(self):
        """Put every component in eval mode; returns self (nn.Module contract)."""
        self.training = False
        self.whisper_model.eval()
        self.additional_decoder_block.eval()
        self.classifier.eval()
        return self

    def forward(
        self,
        input_features,
        attention_mask=None,
        decoder_input_ids=None,
        labels_head=None,
        whisper_labels=None,
    ):
        """Teacher-forced pass.

        Runs the frozen backbone, feeds the selected encoder/decoder hidden
        states through the extra decoder block + classifier, and (when
        ``labels_head`` is given) computes the weighted CE loss, ignoring
        positions labelled -100.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.whisper_model.eval()
        # pass the inputs through the model
        backbone_outputs = self.whisper_model(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            labels=whisper_labels,
        )
        # Extract the hidden states of the selected layer of the decoder.
        decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # Encoder layer that best encapsulates the prosodic features.
        layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # New head: extra decoder block (cross-attending that encoder layer)
        # followed by the classifier.
        additional_decoder_block_outputs = self.additional_decoder_block(
            hidden_states=decoder_last_layer_hidden_states,
            encoder_hidden_states=layer_for_head_hidden_states,
        )
        head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
        # calculate softmax
        head_probs = F.softmax(head_logits, dim=-1)
        preds = head_probs.argmax(dim=-1).to(device)
        if labels_head is not None:
            # Mask predictions wherever the label is the ignore index.
            preds = torch.where(
                torch.isin(
                    labels_head,
                    torch.tensor([-100]).to(device),  # 50257, 50362,
                ),
                torch.tensor(-100),
                preds,
            )
        # Calculate custom loss if labels are provided
        loss = None
        if labels_head is not None:
            # CrossEntropyLoss for the custom head
            loss = self.loss_fct(
                head_logits.reshape(-1, head_logits.size(-1)), labels_head.reshape(-1)
            )
        return CustomModelOutput(
            logits=head_logits,
            labels_head=labels_head,
            whisper_logits=backbone_outputs.logits,
            loss=loss,
            preds=preds,
        )

    def generate(
        self,
        input_features,
        max_length=128,
        labels_head=None,
        whisper_labels=None,
        **generate_kwargs,
    ):
        """
        Generate both the Whisper output and custom head output sequences in alignment.

        Returns the head predictions, with -100 at positions where Whisper
        emitted token id 50256.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Generate the Whisper output sequence.
        # NOTE(review): `labels` is not a generate() argument — it is forwarded
        # to the underlying model; confirm it is actually needed here.
        whisper_outputs = self.whisper_model.generate(
            input_features=input_features,
            max_length=max_length,
            labels=whisper_labels,
            do_sample=False,
            **generate_kwargs,
        )
        # Re-run the backbone teacher-forced on the generated ids so the
        # hidden states align with the generated sequence.
        backbone_outputs = self.whisper_model(
            input_features=input_features,
            decoder_input_ids=whisper_outputs,
            output_hidden_states=True,
        )
        # Extract the hidden states of the selected decoder layer.
        decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # Extract the hidden states of the selected encoder layer.
        layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # Pass the decoder hidden states through the new head (decoder_block + cls).
        additional_decoder_block_outputs = self.additional_decoder_block(
            hidden_states=decoder_last_layer_hidden_states,
            encoder_hidden_states=layer_for_head_hidden_states,
        )
        head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
        # calculate softmax
        head_probs = F.softmax(head_logits, dim=-1)
        preds = head_probs.argmax(dim=-1).to(device)
        preds = torch.where(
            torch.isin(
                whisper_outputs,
                torch.tensor([50256]).to(device),  # 50257, 50362,
            ),
            torch.tensor(-100),
            preds,
        )
        return preds

    def generate_dual(
        self,
        input_features,
        attention_mask=None,
        max_length=200,
        labels_head=None,
        whisper_labels=None,
        **generate_kwargs,
    ):
        """
        Generate both the Whisper output and custom head output sequences in alignment.

        Like generate(), but returns a CustomModelOutput holding the head
        logits/predictions together with the Whisper sequences.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Generate the Whisper output sequence.
        whisper_outputs = self.whisper_model.generate(
            input_features=input_features,
            attention_mask=attention_mask,
            max_length=max_length,
            labels=whisper_labels,
            return_dict_in_generate=True,
            **generate_kwargs,
        )
        # Re-run the backbone teacher-forced on the generated ids.
        backbone_outputs = self.whisper_model(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=whisper_outputs.sequences,
            output_hidden_states=True,
        )
        # Extract the hidden states of the selected decoder layer.
        decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # Extract the hidden states of the selected encoder layer.
        layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
            self.layer_for_head
        ].to(device)
        # Pass the decoder hidden states through the new head (decoder_block + cls).
        additional_decoder_block_outputs = self.additional_decoder_block(
            hidden_states=decoder_last_layer_hidden_states,
            encoder_hidden_states=layer_for_head_hidden_states,
        )
        head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
        head_probs = F.softmax(head_logits, dim=-1)
        preds = head_probs.argmax(dim=-1).to(device)
        preds = torch.where(
            torch.isin(
                whisper_outputs.sequences,
                torch.tensor([50256]).to(device),  # 50257, 50362,
            ),
            torch.tensor(-100),
            preds,
        )
        # NOTE(review): generate() outputs only carry `logits` when the
        # corresponding output flag is enabled (transformers-version
        # dependent) — this may be None; confirm downstream usage.
        return CustomModelOutput(
            logits=head_logits,
            head_preds=preds,
            whisper_logits=whisper_outputs.logits,
            preds=whisper_outputs.sequences,
        )

    def __str__(self):
        return "WhiStress"