import time
import torch
import librosa
import numpy as np
import gradio as gr
from .generate_graph import create_behaviour_gantt_plot
from transformers import Wav2Vec2Processor

SAMPLING_RATE = 16_000

class AudioProcessor:
    def __init__(
        self,
        emotion_model,
        segmentation_model,
        device,
        behaviour_model=None,
    ):
        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.emotion_model = emotion_model
        self.behaviour_model = behaviour_model
        self.device = device
        self.audio_emotion_labels = {
            0: "Neutralità",
            1: "Rabbia",
            2: "Paura",
            3: "Gioia",
            4: "Sorpresa",
            5: "Tristezza",
            6: "Disgusto",
        }
        self.emotion_translation = {
            "neutrality": "Neutralità",
            "anger": "Rabbia",
            "fear": "Paura",
            "joy": "Gioia",
            "surprise": "Sorpresa",
            "sadness": "Tristezza",
            "disgust": "Disgusto",
        }
        self.behaviour_labels = {
            0: "frustrated",
            1: "delighted",
            2: "dysregulated",
        }
        self.behaviour_translation = {
            "frustrated": "frustrazione",
            "delighted": "incantato",
            "dysregulated": "disregolazione",
        }
        self.segmentation_model = segmentation_model
        self._set_emotion_model()
        if self.behaviour_model:
            self._set_behaviour_model()
        self.behaviour_confidence = 0.6
        self.chart_generator = None
    def _set_emotion_model(self):
        self.emotion_model.to(self.device)
        self.emotion_model.eval()

    def _set_behaviour_model(self):
        self.behaviour_model.to(self.device)
        self.behaviour_model.eval()
    def _prepare_transcribed_text(self, chunks):
        formatted_timestamps = []
        predictions = []
        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formatted_start = time.strftime('%H:%M:%S', time.gmtime(start))
            formatted_end = time.strftime('%H:%M:%S', time.gmtime(end))
            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")
            predictions.append(f"**[{chunk[2]}]**")
        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "<br/>".join(
            [
                f"{formatted_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}"
                for i in range(len(transcribed_texts))
            ]
        )
        print(f"Transcribed text:\n{transcribed_text}")
        return transcribed_text
    def __call__(self, audio_path: str):
        """
        Segments the audio file and predicts per-chunk emotion and behaviour labels.

        Args:
            audio_path (str): Path of the audio file to be processed.

        Returns:
            tuple: A one-element tuple containing the behaviour Gantt plot.
        """
        try:
            input_frames, _ = librosa.load(
                audio_path,
                sr=SAMPLING_RATE
            )
        except Exception as e:
            raise gr.Error(f"Error loading audio file: {e}.")
print("Segmenting audio...") | |
out = self.segmentation_model( | |
inputs={ | |
"raw": input_frames, | |
"sampling_rate": SAMPLING_RATE, | |
}, | |
chunk_length_s=30, | |
stride_length_s=5, | |
return_timestamps=True, | |
) | |
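        # `out["chunks"]` follows the Hugging Face ASR pipeline format when
        # return_timestamps=True: a list of dicts like
        # {"timestamp": (start_s, end_s), "text": "..."}. The final chunk's end
        # timestamp can be None, which is handled below.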
        emotion_chunks = []
        behaviour_chunks = []
        timestamps = []
        predicted_labels = []
        all_probabilities = []
        print("Analyzing chunks...")
        for chunk in out["chunks"]:
            # trim audio from timestamps
            start = int(chunk["timestamp"][0] * SAMPLING_RATE)
            end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))
            audio = input_frames[start:end]

            inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE)
            print(f"Inputs: {inputs}")
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs['input_features'] = inputs['input_features'].to(self.device)
            inputs['input_ids'] = inputs['input_ids'].to(self.device)
            inputs['text_attention_mask'] = inputs['text_attention_mask'].to(self.device)

            print("Predicting emotion for chunk...")
            logits = self.emotion_model(**inputs).logits
            logits = logits.detach().cpu()
            softmax = torch.nn.Softmax(dim=1)
            probabilities = softmax(logits).squeeze(0)
            prediction = probabilities.argmax().item()
            # the id-to-label mapping lives on the emotion classifier's config, not on the processor
            predicted_label = self.emotion_model.config.id2label[prediction]
            label_translation = self.emotion_translation[predicted_label]
            emotion_chunks.append(
                (
                    start,
                    end,
                    label_translation,
                    chunk["text"],
                    np.round(probabilities[prediction].item(), 2),
                )
            )
            timestamps.append((start, end))
            predicted_labels.append(label_translation)
            all_probabilities.append(probabilities[prediction].item())
            inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs = inputs.input_features.to(self.device)

            print("Predicting behaviour for chunk...")
            logits = self.behaviour_model(inputs).logits
            probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
            behaviour_chunks.append(
                (
                    start,
                    end,
                    chunk["text"],
                    # index 2 corresponds to "dysregulated" (see self.behaviour_labels)
                    np.round(probabilities[2].item(), 2),
                    label_translation,
                )
            )
        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)
        # transcribed_text = self._prepare_transcribed_text(emotion_chunks)
        return (
            behaviour_gantt,
            # transcribed_text,
        )
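
# Minimal usage sketch (illustrative only): `load_emotion_model` and
# `load_behaviour_model` are hypothetical loaders for the fine-tuned checkpoints,
# and the ASR checkpoint name is an assumption; the real app wires these up in its
# Gradio entry point. Because of the relative import above, run this module with
# `python -m <package>.<module>` rather than as a standalone script.
if __name__ == "__main__":
    from transformers import pipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    segmentation_model = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-small",  # assumption: any ASR checkpoint with timestamp support
        device=device,
    )
    audio_processor = AudioProcessor(
        emotion_model=load_emotion_model(),        # hypothetical helper
        segmentation_model=segmentation_model,
        device=device,
        behaviour_model=load_behaviour_model(),    # hypothetical helper
    )
    (gantt_figure,) = audio_processor("example.wav")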