import pathlib

import librosa
import numpy as np
import torch
from torch.nn import functional as F
from transformers import WhisperConfig

from ..model import WhiStress

PATH_TO_WEIGHTS = pathlib.Path(__file__).parent.parent / "weights"


def get_loaded_model(device="cuda"):
    # Build WhiStress on top of the whisper-small.en backbone and load the
    # trained head weights from the local `weights` directory.
    whisper_model_name = "openai/whisper-small.en"
    whisper_config = WhisperConfig()
    whistress_model = WhiStress(
        whisper_config, layer_for_head=9, whisper_backbone_name=whisper_model_name
    ).to(device)
    whistress_model.processor.tokenizer.model_input_names = [
        "input_ids",
        "attention_mask",
        "labels_head",
    ]
    whistress_model.load_model(PATH_TO_WEIGHTS)
    whistress_model.to(device)
    whistress_model.eval()
    return whistress_model


def get_word_emphasis_pairs(
    transcription_preds, emphasis_preds, processor, filter_special_tokens=True
):
    # Decode each predicted token id and pair it with its emphasis prediction,
    # optionally dropping special tokens (e.g. start/end-of-transcript markers).
    emphasis_preds_list = emphasis_preds.tolist()
    transcription_preds_words = [
        processor.tokenizer.decode([i], skip_special_tokens=False)
        for i in transcription_preds
    ]
    if filter_special_tokens:
        special_tokens_indices = [
            i
            for i, x in enumerate(transcription_preds)
            if x in processor.tokenizer.all_special_ids
        ]
        emphasis_preds_list = [
            x
            for i, x in enumerate(emphasis_preds_list)
            if i not in special_tokens_indices
        ]
        transcription_preds_words = [
            x
            for i, x in enumerate(transcription_preds_words)
            if i not in special_tokens_indices
        ]
    return list(zip(transcription_preds_words, emphasis_preds_list))


def inference_from_audio(audio: np.ndarray, model: WhiStress, device: str):
    input_features = model.processor.feature_extractor(
        audio, sampling_rate=16000, return_tensors="pt"
    )["input_features"]
    out_model = model.generate_dual(input_features=input_features.to(device))
    emphasis_probs = F.softmax(out_model.logits, dim=-1)
    emphasis_preds = torch.argmax(emphasis_probs, dim=-1)
    # The head predicts emphasis for the next decoder position, so roll the
    # predictions one step to the right to align them with the decoded tokens.
    emphasis_preds_right_shifted = torch.cat(
        (emphasis_preds[:, -1:], emphasis_preds[:, :-1]), dim=1
    )
    word_emphasis_pairs = get_word_emphasis_pairs(
        out_model.preds[0],
        emphasis_preds_right_shifted[0],
        model.processor,
        filter_special_tokens=True,
    )
    return word_emphasis_pairs


def prepare_audio(audio, target_sr=16000):
    # `audio` is a dict with "array" and "sampling_rate" keys.
    # Resample to 16 kHz, which is what the Whisper feature extractor expects.
    sr = audio["sampling_rate"]
    y = np.array(audio["array"], dtype=float)
    y_resampled = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    # Normalize the audio (scale to [-1, 1]); skip if the signal is all zeros.
    peak = np.max(np.abs(y_resampled))
    if peak > 0:
        y_resampled /= peak
    return y_resampled


def merge_stressed_tokens(tokens_with_stress):
    """
    tokens_with_stress is a list of (token_string, stress_value) tuples, e.g.:
    [(" I", 0), (" didn", 1), ("'t", 0), (" say", 0), (" he", 0),
     (" stole", 0), (" the", 0), (" money", 0), (".", 0)]
    Returns a list of merged tuples, combining subwords into full words.
    """
    merged = []
    current_word = ""
    current_stress = 0  # 0 means not stressed, 1 means stressed

    for token, stress in tokens_with_stress:
        # A token that starts with a space marks the beginning of a new word;
        # so does the very first token (current_word is still empty).
        if token.startswith(" ") or current_word == "":
            # If we already have something in current_word, push it into merged
            # before starting a new one.
            if current_word:
                merged.append((current_word, current_stress))
            # Start a new word.
            current_word = token
            current_stress = stress
        else:
            # Otherwise it's a subword that belongs to the previous word.
            current_word += token
            # If any sub-token is stressed, the whole merged word is stressed.
            current_stress = max(current_stress, stress)

    # Don't forget to append the final word.
    if current_word:
        merged.append((current_word, current_stress))

    return merged


def inference_from_audio_and_transcription(
    audio: np.ndarray, transcription, model: WhiStress, device: str
):
    input_features = model.processor.feature_extractor(
        audio, sampling_rate=16000, return_tensors="pt"
    )["input_features"]
    # Convert the given transcription to decoder input ids.
    input_ids = model.processor.tokenizer(
        transcription,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=30,
    )["input_ids"]
    out_model = model(
        input_features=input_features.to(device),
        decoder_input_ids=input_ids.to(device),
    )
    emphasis_probs = F.softmax(out_model.logits, dim=-1)
    emphasis_preds = torch.argmax(emphasis_probs, dim=-1)
    # As in inference_from_audio, roll the predictions one step to the right so
    # they line up with the decoder tokens.
    emphasis_preds_right_shifted = torch.cat(
        (emphasis_preds[:, -1:], emphasis_preds[:, :-1]), dim=1
    )
    word_emphasis_pairs = get_word_emphasis_pairs(
        input_ids[0],
        emphasis_preds_right_shifted[0],
        model.processor,
        filter_special_tokens=True,
    )
    return word_emphasis_pairs


def scored_transcription(
    audio, model, strip_words=True, transcription: str = None, device="cuda"
):
    # Returns a list of (word, stress) pairs for the given audio. If a ground
    # truth transcription is provided, it is used as the decoder input instead
    # of Whisper's own transcription.
    audio_arr = prepare_audio(audio)
    if transcription:
        token_stress_pairs = inference_from_audio_and_transcription(
            audio_arr, transcription, model, device
        )
    else:
        token_stress_pairs = inference_from_audio(audio_arr, model, device)
    word_level_stress = merge_stressed_tokens(token_stress_pairs)
    if strip_words:
        word_level_stress = [
            (word.strip(), stress) for word, stress in word_level_stress
        ]
    return word_level_stress
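

# Usage sketch (not part of the original module): one way to run the pipeline
# end to end. The `if __name__ == "__main__"` guard, the "example.wav" path,
# and the use of librosa.load are illustrative assumptions; it also assumes the
# trained head weights are available under PATH_TO_WEIGHTS.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load a waveform at its native sampling rate; prepare_audio() resamples it
    # to 16 kHz before feature extraction.
    wav, sr = librosa.load("example.wav", sr=None)  # hypothetical input file
    model = get_loaded_model(device=device)
    word_stress = scored_transcription(
        {"array": wav, "sampling_rate": sr}, model, device=device
    )
    # Each entry is a (word, stress) pair, e.g. ("didn't", 1) for an emphasized word.
    print(word_stress)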