iyosha committed
Commit 73c9c96 · verified · 1 Parent(s): 6c0f20f

Upload 12 files

app.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import gradio as gr
+ from pathlib import Path
+ from whistress import WhiStressInferenceClient
+
+ CURRENT_DIR = Path(__file__).parent
+ # Load the model
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = WhiStressInferenceClient(device=device)
+
+
+ def get_whistress_predictions(audio):
+     """
+     Get the transcription and emphasis scores for the given audio input.
+     Args:
+         audio (Tuple[int, numpy.ndarray]): The audio input as a (sampling_rate, waveform) tuple, as provided by Gradio.
+     Returns:
+         List[Tuple[str, int]]: A list of tuples containing words and their emphasis scores.
+     """
+     audio = {
+         "sampling_rate": audio[0],
+         "array": audio[1],
+     }
+     return model.predict(audio=audio, transcription=None, return_pairs=True)
+
+
+ # App UI
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column(scale=2):
+             gr.Markdown(
+                 """
+                 # WhiStress: Enriching Transcriptions with Sentence Stress Detection
+                 WhiStress detects which words are emphasized (stressed) in your speech.
+
+                 Check out our paper: 📚 [WhiStress](https://arxiv.org/)
+
+                 ## Architecture
+                 The model is built on [Whisper](https://arxiv.org/abs/2212.04356),
+                 using [`whisper-small.en`](https://huggingface.co/openai/whisper-small.en)
+                 as the backbone.
+                 WhiStress adds a decoder-based classifier that predicts the stress label of each transcription token.
+
+                 ## Training Data
+                 The model was trained on [TinyStress-15K](https://huggingface.co/datasets/loud-whisper-project/tinyStories-audio-emphasized),
+                 which is derived from the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset.
+
+                 ## Inference Demo
+                 Upload an audio file or record your own voice to transcribe the speech and highlight the emphasized words.
+
+                 For best performance, please speak clearly.
+                 """
+             )
+         with gr.Column(scale=1):
+             # Display the model architecture diagram
+             gr.Image(
+                 f"{CURRENT_DIR}/assets/whistress_model.svg",
+                 label="Architecture",
+             )
+
+     gr.Interface(
+         get_whistress_predictions,
+         gr.Audio(
+             sources=["microphone", "upload"],
+             label="Upload speech or record your own",
+             type="numpy",
+         ),
+         gr.HighlightedText(),
+         allow_flagging="never",
+     )
+
+
+ def launch():
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     launch()
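For reference, a minimal sketch of exercising `get_whistress_predictions` outside the Gradio UI, assuming the definitions in app.py above are in scope. The synthetic sine-wave input and the printed output are illustrative placeholders, not part of the repository:

import numpy as np

# gr.Audio(type="numpy") hands the callback a (sampling_rate, waveform) tuple;
# here we fabricate one second of a 220 Hz tone as placeholder audio.
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
waveform = 0.1 * np.sin(2 * np.pi * 220 * t)

pairs = get_whistress_predictions((sr, waveform))
print(pairs)  # e.g. [("I", 0), ("didn't", 1), ("say", 0), ...], i.e. word/stress pairs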
assets/whistress_model.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.1.0
+ torchaudio==2.1.0
+ torchlibrosa==0.1.0
+ librosa==0.10.2.post1
+ transformers==4.44.0
+ numpy==1.26.4
+ gradio==5.31.0
whistress/__init__.py ADDED
@@ -0,0 +1 @@
+ from .inference_client import WhiStressInferenceClient
whistress/inference_client/__init__.py ADDED
@@ -0,0 +1 @@
+ from .whistress_client import WhiStressInferenceClient
whistress/inference_client/utils.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ from transformers import WhisperConfig
+ import librosa
+ import numpy as np
+ import pathlib
+ from torch.nn import functional as F
+ from ..model import WhiStress
+
+
+ PATH_TO_WEIGHTS = pathlib.Path(__file__).parent.parent / "weights"
+
+
+ def get_loaded_model(device="cuda"):
+     whisper_model_name = "openai/whisper-small.en"
+     whisper_config = WhisperConfig()
+     whistress_model = WhiStress(
+         whisper_config, layer_for_head=9, whisper_backbone_name=whisper_model_name
+     ).to(device)
+     whistress_model.processor.tokenizer.model_input_names = [
+         "input_ids",
+         "attention_mask",
+         "labels_head",
+     ]
+     whistress_model.load_model(PATH_TO_WEIGHTS)
+     whistress_model.to(device)
+     whistress_model.eval()
+     return whistress_model
+
+
+ def get_word_emphasis_pairs(
+     transcription_preds, emphasis_preds, processor, filter_special_tokens=True
+ ):
+     emphasis_preds_list = emphasis_preds.tolist()
+     transcription_preds_words = [
+         processor.tokenizer.decode([i], skip_special_tokens=False)
+         for i in transcription_preds
+     ]
+     if filter_special_tokens:
+         special_tokens_indices = [
+             i
+             for i, x in enumerate(transcription_preds)
+             if x in processor.tokenizer.all_special_ids
+         ]
+         emphasis_preds_list = [
+             x
+             for i, x in enumerate(emphasis_preds_list)
+             if i not in special_tokens_indices
+         ]
+         transcription_preds_words = [
+             x
+             for i, x in enumerate(transcription_preds_words)
+             if i not in special_tokens_indices
+         ]
+     return list(zip(transcription_preds_words, emphasis_preds_list))
+
+
+ def inference_from_audio(audio: np.ndarray, model: WhiStress, device: str):
+     input_features = model.processor.feature_extractor(
+         audio, sampling_rate=16000, return_tensors="pt"
+     )["input_features"]
+     out_model = model.generate_dual(input_features=input_features.to(device))
+     emphasis_probs = F.softmax(out_model.logits, dim=-1)
+     emphasis_preds = torch.argmax(emphasis_probs, dim=-1)
+     # shift the emphasis predictions one position to the right to align them with the decoder input tokens
+     emphasis_preds_right_shifted = torch.cat(
+         (emphasis_preds[:, -1:], emphasis_preds[:, :-1]), dim=1
+     )
+     word_emphasis_pairs = get_word_emphasis_pairs(
+         out_model.preds[0],
+         emphasis_preds_right_shifted[0],
+         model.processor,
+         filter_special_tokens=True,
+     )
+     return word_emphasis_pairs
+
+
+ def prepare_audio(audio, target_sr=16000):
+     # resample to 16kHz
+     sr = audio["sampling_rate"]
+     y = audio["array"]
+     y = np.array(y, dtype=float)
+     y_resampled = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
+     # Normalize the audio (scale to [-1, 1])
+     y_resampled /= max(abs(y_resampled))
+     return y_resampled
+
+
+ def merge_stressed_tokens(tokens_with_stress):
+     """
+     tokens_with_stress is a list of tuples: (token_string, stress_value),
+     e.g.:
+     [(" I", 0), (" didn", 1), ("'t", 0), (" say", 0), (" he", 0), (" stole", 0),
+      (" the", 0), (" money", 0), (".", 0)]
+     Returns a list of merged tuples, combining subwords into full words.
+     """
+     merged = []
+
+     current_word = ""
+     current_stress = 0  # 0 means not stressed, 1 means stressed
+
+     for token, stress in tokens_with_stress:
+         # A token that starts with a space begins a new word
+         # (as does the very first token, when current_word is still empty).
+         if token.startswith(" ") or current_word == "":
+             # If we already have something in current_word, push it into merged
+             # before starting a new one
+             if current_word:
+                 merged.append((current_word, current_stress))
+
+             # Start a new word
+             current_word = token
+             current_stress = stress
+         else:
+             # Otherwise, it's a subword that should be appended to the previous word
+             current_word += token
+             # If any sub-token is stressed, the whole merged word is stressed
+             current_stress = max(current_stress, stress)
+
+     # Don't forget to append the final word
+     if current_word:
+         merged.append((current_word, current_stress))
+
+     return merged
+
+
+ def inference_from_audio_and_transcription(
+     audio: np.ndarray, transcription, model: WhiStress, device: str
+ ):
+     input_features = model.processor.feature_extractor(
+         audio, sampling_rate=16000, return_tensors="pt"
+     )["input_features"]
+     # convert the transcription to input_ids
+     input_ids = model.processor.tokenizer(
+         transcription,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=30,
+     )["input_ids"]
+     out_model = model(
+         input_features=input_features.to(device),
+         decoder_input_ids=input_ids.to(device),
+     )
+     emphasis_probs = F.softmax(out_model.logits, dim=-1)
+     emphasis_preds = torch.argmax(emphasis_probs, dim=-1)
+     emphasis_preds_right_shifted = torch.cat(
+         (emphasis_preds[:, -1:], emphasis_preds[:, :-1]), dim=1
+     )
+     word_emphasis_pairs = get_word_emphasis_pairs(
+         input_ids[0],
+         emphasis_preds_right_shifted[0],
+         model.processor,
+         filter_special_tokens=True,
+     )
+     return word_emphasis_pairs
+
+
+ def scored_transcription(audio, model, strip_words=True, transcription: str = None, device="cuda"):
+     audio_arr = prepare_audio(audio)
+     token_stress_pairs = None
+     if transcription:  # use the ground-truth transcription if one is provided
+         token_stress_pairs = inference_from_audio_and_transcription(audio_arr, transcription, model, device)
+     else:
+         token_stress_pairs = inference_from_audio(audio_arr, model, device)
+         # token_stress_pairs = inference_from_audio(audio_arr, model)
+     word_level_stress = merge_stressed_tokens(token_stress_pairs)
+     if strip_words:
+         word_level_stress = [(word.strip(), stress) for word, stress in word_level_stress]
+     return word_level_stress
@@ -0,0 +1,26 @@
 
+ import numpy as np
+ from .utils import get_loaded_model, scored_transcription
+ from typing import Union, Dict
+
+
+ class WhiStressInferenceClient:
+     def __init__(self, device="cuda"):
+         self.device = device
+         self.whistress = get_loaded_model(self.device)
+
+     def predict(
+         self, audio: Dict[str, Union[np.ndarray, int]], transcription=None, return_pairs=True
+     ):
+         word_emphasis_pairs = scored_transcription(
+             audio=audio,
+             model=self.whistress,
+             device=self.device,
+             strip_words=True,
+             transcription=transcription,
+         )
+         if return_pairs:
+             return word_emphasis_pairs
+         # returns the transcription string and the list of emphasized words
+         return " ".join([x[0] for x in word_emphasis_pairs]), [
+             x[0] for x in word_emphasis_pairs if x[1] == 1
+         ]
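A minimal usage sketch of the client defined above. The audio file path and the librosa-based loading are illustrative assumptions; any 1-D float waveform plus its sampling rate works:

import librosa
from whistress import WhiStressInferenceClient

client = WhiStressInferenceClient(device="cpu")  # or "cuda"
waveform, sr = librosa.load("example.wav", sr=None)  # placeholder path

# Word/stress pairs, e.g. [("I", 0), ("didn't", 1), ("say", 0), ...]
pairs = client.predict(audio={"sampling_rate": sr, "array": waveform})

# Alternatively, get the plain transcription plus the list of emphasized words
text, emphasized = client.predict(
    audio={"sampling_rate": sr, "array": waveform}, return_pairs=False
)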
whistress/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import WhiStress
whistress/model/model.py ADDED
@@ -0,0 +1,318 @@
+ from transformers import (
+     WhisperForConditionalGeneration,
+     WhisperProcessor,
+     PreTrainedModel,
+     WhisperConfig,
+ )
+ from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
+ from transformers.modeling_outputs import BaseModelOutput
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import torch
+ import os
+ from dataclasses import dataclass
+ from typing import Optional
+ import json
+
+
+ @dataclass
+ class CustomModelOutput(BaseModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     head_preds: torch.FloatTensor = None
+     labels_head: Optional[torch.FloatTensor] = None
+     whisper_logits: torch.FloatTensor = None
+     preds: Optional[torch.Tensor] = None
+
+
+ # Define a new head (e.g., a classification layer)
+ class LinearHead(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super(LinearHead, self).__init__()
+         self.linear = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.linear(x)
+
+
+ class FCNN(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super(FCNN, self).__init__()
+         hidden_dim = 2 * input_dim
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, x):
+         x = F.relu(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class WhiStress(PreTrainedModel):
+
+     config_class = WhisperConfig
+     model_input_names = ["input_features", "labels_head", "whisper_labels"]
+
+     def __init__(
+         self,
+         config: WhisperConfig,
+         layer_for_head: Optional[int] = None,
+         whisper_backbone_name="openai/whisper-small.en",
+     ):
+         super().__init__(config)
+         self.whisper_backbone_name = whisper_backbone_name
+         self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
+             self.whisper_backbone_name,
+         ).eval()
+         self.processor = WhisperProcessor.from_pretrained(self.whisper_backbone_name)
+
+         input_dim = self.whisper_model.config.d_model  # model's hidden size
+         output_dim = 2  # number of classes for the new head
+
+         config = self.whisper_model.config
+         # add an additional decoder block using the existing Whisper config
+         self.additional_decoder_block = WhisperDecoderLayer(config)
+         self.classifier = FCNN(input_dim, output_dim)
+         # add class weights for the CE loss
+         neg_weight = 1.0
+         pos_weight = 0.7 / 0.3
+         class_weights = torch.tensor([neg_weight, pos_weight])
+         self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
+         self.layer_for_head = -1 if layer_for_head is None else layer_for_head
+
+     def to(self, device: str = ("cuda" if torch.cuda.is_available() else "cpu")):
+         self.whisper_model.to(device)
+         self.additional_decoder_block.to(device)
+         self.classifier.to(device)
+         super().to(device)
+         return self
+
+     def load_model(self, save_dir=None):
+         # load only the classifier and the extra decoder layer (saved locally)
+         if save_dir is not None:
+             print("loading model from:", save_dir)
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             self.classifier.load_state_dict(
+                 torch.load(
+                     os.path.join(save_dir, "classifier.pt"),
+                     weights_only=False,
+                     map_location=torch.device(device),
+                 )
+             )
+             self.additional_decoder_block.load_state_dict(
+                 torch.load(
+                     os.path.join(save_dir, "additional_decoder_block.pt"),
+                     weights_only=False,
+                     map_location=torch.device(device),
+                 )
+             )
+             # read metadata.json; its format is {"layer_for_head": 9}
+             with open(os.path.join(save_dir, "metadata.json"), "r") as f:
+                 metadata = json.load(f)
+             self.layer_for_head = metadata["layer_for_head"]
+         return
+
+     def train(self, mode: Optional[bool] = True):
+         # freeze Whisper and train only the additional decoder block and the classifier
+         self.whisper_model.eval()
+         for param in self.whisper_model.parameters():
+             param.requires_grad = False
+         for param in self.additional_decoder_block.parameters():
+             param.requires_grad = True
+         for param in self.classifier.parameters():
+             param.requires_grad = True
+         self.additional_decoder_block.train()
+         self.classifier.train()
+
+     def eval(self):
+         self.whisper_model.eval()
+         self.additional_decoder_block.eval()
+         self.classifier.eval()
+
+     def forward(
+         self,
+         input_features,
+         attention_mask=None,
+         decoder_input_ids=None,
+         labels_head=None,
+         whisper_labels=None,
+     ):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.whisper_model.eval()
+
+         # pass the inputs through the backbone
+         backbone_outputs = self.whisper_model(
+             input_features=input_features,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             output_hidden_states=True,
+             labels=whisper_labels,
+         )
+
+         # Extract the decoder hidden states of the layer used for the head
+         decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
+             self.layer_for_head
+         ].to(device)
+
+         # Extract the hidden states of the encoder layer that best captures the prosodic features
+         layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
+             self.layer_for_head
+         ].to(device)
+
+         # Pass the decoder hidden states through the new head (decoder block + classifier)
+         additional_decoder_block_outputs = self.additional_decoder_block(
+             hidden_states=decoder_last_layer_hidden_states,
+             encoder_hidden_states=layer_for_head_hidden_states,
+         )
+         head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
+
+         # calculate softmax
+         head_probs = F.softmax(head_logits, dim=-1)
+         preds = head_probs.argmax(dim=-1).to(device)
+         if labels_head is not None:
+             preds = torch.where(
+                 torch.isin(
+                     labels_head, torch.tensor([-100]).to(device)  # 50257, 50362,
+                 ),
+                 torch.tensor(-100),
+                 preds,
+             )
+         # Calculate the custom loss if labels are provided
+         loss = None
+         if labels_head is not None:
+             # CrossEntropyLoss for the custom head
+             loss = self.loss_fct(
+                 head_logits.reshape(-1, head_logits.size(-1)), labels_head.reshape(-1)
+             )
+         return CustomModelOutput(
+             logits=head_logits,
+             labels_head=labels_head,
+             whisper_logits=backbone_outputs.logits,
+             loss=loss,
+             preds=preds,
+         )
+
+     def generate(
+         self,
+         input_features,
+         max_length=128,
+         labels_head=None,
+         whisper_labels=None,
+         **generate_kwargs,
+     ):
+         """
+         Generate both the Whisper output and custom head output sequences in alignment.
+         """
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         # Generate the Whisper output sequence
+         whisper_outputs = self.whisper_model.generate(
+             input_features=input_features,
+             max_length=max_length,
+             labels=whisper_labels,
+             do_sample=False,
+             **generate_kwargs,
+         )
+
+         # pass the inputs through the backbone again to obtain the hidden states
+         backbone_outputs = self.whisper_model(
+             input_features=input_features,
+             decoder_input_ids=whisper_outputs,
+             output_hidden_states=True,
+         )
+
+         # Extract the decoder hidden states of the layer used for the head
+         decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
+             self.layer_for_head
+         ].to(device)
+
+         # Extract the encoder hidden states of the layer used for the head
+         layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
+             self.layer_for_head
+         ].to(device)
+
+         # Pass the decoder hidden states through the new head (decoder block + classifier)
+         additional_decoder_block_outputs = self.additional_decoder_block(
+             hidden_states=decoder_last_layer_hidden_states,
+             encoder_hidden_states=layer_for_head_hidden_states,
+         )
+         head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
+         # calculate softmax
+         head_probs = F.softmax(head_logits, dim=-1)
+         preds = head_probs.argmax(dim=-1).to(device)
+         preds = torch.where(
+             torch.isin(
+                 whisper_outputs, torch.tensor([50256]).to(device)  # 50257, 50362,
+             ),
+             torch.tensor(-100),
+             preds,
+         )
+         # preds_shifted = torch.cat((preds[:, 1:], preds[:, :1]), dim=1)
+         return preds
+
+     def generate_dual(
+         self,
+         input_features,
+         attention_mask=None,
+         max_length=200,
+         labels_head=None,
+         whisper_labels=None,
+         **generate_kwargs,
+     ):
+         """
+         Generate both the Whisper output and custom head output sequences in alignment.
+         """
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         # Generate the Whisper output sequence
+         whisper_outputs = self.whisper_model.generate(
+             input_features=input_features,
+             attention_mask=attention_mask,
+             max_length=max_length,
+             labels=whisper_labels,
+             return_dict_in_generate=True,
+             **generate_kwargs,
+         )
+
+         # pass the inputs through the backbone again to obtain the hidden states
+         backbone_outputs = self.whisper_model(
+             input_features=input_features,
+             attention_mask=attention_mask,
+             decoder_input_ids=whisper_outputs.sequences,
+             output_hidden_states=True,
+         )
+
+         # Extract the decoder hidden states of the layer used for the head
+         decoder_last_layer_hidden_states = backbone_outputs.decoder_hidden_states[
+             self.layer_for_head
+         ].to(device)
+
+         # Extract the encoder hidden states of the layer used for the head
+         layer_for_head_hidden_states = backbone_outputs.encoder_hidden_states[
+             self.layer_for_head
+         ].to(device)
+
+         # Pass the decoder hidden states through the new head (decoder block + classifier)
+         additional_decoder_block_outputs = self.additional_decoder_block(
+             hidden_states=decoder_last_layer_hidden_states,
+             encoder_hidden_states=layer_for_head_hidden_states,
+         )
+         head_logits = self.classifier(additional_decoder_block_outputs[0].to(device))
+         head_probs = F.softmax(head_logits, dim=-1)
+         preds = head_probs.argmax(dim=-1).to(device)
+         preds = torch.where(
+             torch.isin(
+                 whisper_outputs.sequences, torch.tensor([50256]).to(device)  # 50257, 50362,
+             ),
+             torch.tensor(-100),
+             preds,
+         )
+         return CustomModelOutput(
+             logits=head_logits,
+             head_preds=preds,
+             whisper_logits=whisper_outputs.logits,
+             preds=whisper_outputs.sequences,
+         )
+
+     def __str__(self):
+         return "WhiStress"
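For orientation, a sketch of how this class is wired up elsewhere in the repository (mirroring `get_loaded_model` in `whistress/inference_client/utils.py`); the relative weights path assumes the repository root as the working directory:

from transformers import WhisperConfig
from whistress.model import WhiStress

model = WhiStress(
    WhisperConfig(), layer_for_head=9, whisper_backbone_name="openai/whisper-small.en"
).to("cpu")
# Loads classifier.pt, additional_decoder_block.pt and metadata.json from the weights directory
model.load_model("whistress/weights")
model.eval()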
whistress/weights/additional_decoder_block.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7d440821c831364c5046e859843926120550a38143f89e1bace82a2ed03cc77
+ size 37809834
whistress/weights/classifier.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:599257b647cbca9fc21aac4ede87651cd43d03c3338e705bd59d919ee19ebc6f
+ size 4739176
whistress/weights/metadata.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "layer_for_head": 9
+ }