import logging

import chardet
import gradio as gr
import pandas as pd
import plotly.express as px
import torch
import torch.nn.functional as F
import torchaudio
import whisper

from utils.config_loader import ConfigLoader
from data_loading.feature_extractor import (
    PretrainedAudioEmbeddingExtractor,
    PretrainedTextEmbeddingExtractor
)
from models.models import BiFormer

# DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')
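# A possible alternative, if automatic device selection is preferred:
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')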

# Configure logging
logging.basicConfig(level=logging.INFO)

# Constants with emojis and colors
LABEL_TO_EMOTION = {
    0: '😠 Anger',
    1: '🤢 Disgust',
    2: '😨 Fear',
    3: '😊 Joy/Happiness',
    4: '😐 Neutral',
    5: '😢 Sadness',
    6: '😲 Surprise/Enthusiasm'
}
EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
emotion_color_map = {emotion: color for emotion, color in zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS)}

TARGET_SAMPLE_RATE = 16000

def initialize_components(config_path='config.toml'):
    """Initialize configuration and models."""
    config = ConfigLoader(config_path)
    config.show_config()

    model = BiFormer(
        audio_dim=256,
        text_dim=1024,
        seg_len=95,
        hidden_dim=256,
        hidden_dim_gated=256,
        num_transformer_heads=8,
        num_graph_heads=2,
        positional_encoding=False,
        dropout=0.15,
        mode='mean',
        tr_layer_number=5,
        out_features=256,
        num_classes=7
    )
    checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state)
    model = model.to(DEVICE)
    model.eval()

    return (
        PretrainedAudioEmbeddingExtractor(config),
        PretrainedTextEmbeddingExtractor(config),
        whisper.load_model("base"),
        model
    )

audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()
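# Note: all models are loaded once at import time and shared across requests.
# If memory were a concern, a lazy-loading variant (initializing on first call
# inside process_audio) would be a reasonable alternative.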

def load_and_preprocess_audio(audio_path):
    """Load audio and convert it to mono 16 kHz."""
    try:
        waveform, orig_sr = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=False)  # downmix channels to mono
        if orig_sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sr,
                new_freq=TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)
        return waveform, TARGET_SAMPLE_RATE
    except Exception as e:
        logging.error(f"Audio loading failed: {e}")
        raise
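# Example usage (hypothetical file path):
#   waveform, sr = load_and_preprocess_audio("speech.wav")
#   # waveform -> 1-D float tensor, sr -> 16000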

def transcribe_audio(audio_path):
    """Convert speech to text using Whisper."""
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""
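# The multilingual "base" Whisper checkpoint detects the spoken language
# automatically, which is what allows both Russian and English input here;
# fp16=False avoids half-precision inference, which is unsupported on CPU.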

def get_predictions(input_data, extractor, is_audio=False):
    """Generic prediction function for audio/text.

    Assumes each extractor's `extract` returns a (logits, embedding) pair.
    """
    try:
        if is_audio:
            pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
        else:
            pred, emb = extractor.extract(input_data)
        return F.softmax(pred, dim=-1)[0].tolist(), emb
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return [0.0] * len(LABEL_TO_EMOTION), None
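# Example (hypothetical input):
#   probs, emb = get_predictions("hello there", text_extractor)
#   # probs is a 7-way probability list aligned with LABEL_TO_EMOTION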

def create_emotion_df(probabilities):
    """Create an emotion probability dataframe with values as percentages."""
    df = pd.DataFrame({
        'Emotion': list(LABEL_TO_EMOTION.values()),
        'Probability': [round(p * 100, 2) for p in probabilities]
    })
    return df

def create_plot(df, title):
    """Create a Plotly bar chart with consistent per-emotion colors."""
    fig = px.bar(
        df,
        x='Emotion',
        y='Probability',
        title=title,
        color='Emotion',
        color_discrete_map=emotion_color_map
    )
    fig.update_layout(
        xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
        yaxis=dict(title='Probability (%)'),
        margin=dict(l=20, r=20, t=60, b=100),
        height=400,
        showlegend=False
    )
    return fig

def get_top_emotion(probabilities):
    """Return formatted top emotion with percentage."""
    max_idx = probabilities.index(max(probabilities))
    return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities) * 100:.1f}%)"

def process_audio(audio_path):
    """Main processing pipeline."""
    try:
        if not audio_path:
            empty = create_emotion_df([0] * len(LABEL_TO_EMOTION))
            return (
                create_plot(empty, "🎧 Audio Analysis"),
                "No audio detected",
                create_plot(empty, "📝 Text Analysis"),
                create_plot(empty, "🤝 Audio-Text Analysis"),
                "🔊 Please provide audio input"
            )

        # Audio processing
        waveform, sr = load_and_preprocess_audio(audio_path)
        audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
        audio_df = create_emotion_df(audio_probs)

        # Text processing (fall back to zero probabilities and no embedding
        # when the transcription is empty)
        text = transcribe_audio(audio_path)
        if text.strip():
            text_probs, text_features = get_predictions(text, text_extractor)
        else:
            text_probs, text_features = [0.0] * len(LABEL_TO_EMOTION), None
        text_df = create_emotion_df(text_probs)

        # Combined results
        combined_probs = bimodal_model(audio_features, text_features)
        combined_probs = F.softmax(combined_probs, dim=-1)[0].detach().cpu().numpy().tolist()
        combined_df = create_emotion_df(combined_probs)
        top_emotion = get_top_emotion(combined_probs)

        return (
            create_plot(audio_df, "🎧 Audio Analysis"),
            f"🗣️ Transcription:\n{text}",
            create_plot(text_df, "📝 Text Analysis"),
            create_plot(combined_df, "🤝 Audio-Text Analysis"),
            f"## 🏆 Dominant Emotion: {top_emotion}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        error_df = create_emotion_df([0] * len(LABEL_TO_EMOTION))
        return (
            create_plot(error_df, "🎧 Audio Analysis"),
            "❌ Error processing audio",
            create_plot(error_df, "📝 Text Analysis"),
            create_plot(error_df, "🤝 Audio-Text Analysis"),
            "⚠️ Processing Error"
        )

def create_app():
    """Build the main Gradio interface."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
        gr.Markdown("# Intelligent System for Bilingual Bimodal Emotion Recognition (BiBiER)")
        gr.Markdown("Analyze emotions in Russian and English speech through both audio characteristics and spoken content.")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )
        with gr.Row():
            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            with gr.Column():
                audio_plot = gr.Plot(label="Audio Analysis")
            with gr.Column():
                text_plot = gr.Plot(label="Text Analysis")
            with gr.Column():
                combined_plot = gr.Plot(label="Audio-Text Analysis")

        transcription = gr.Textbox(
            label="📝 Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        audio_input.change(
            process_audio,
            inputs=audio_input,
            outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
        )
    return demo

def create_authors():
    """Build the authors tab (placeholder names)."""
    df = pd.DataFrame({
        "Name": ["Author", "Author"]
    })
    with gr.Blocks() as demo:
        gr.Dataframe(df)
    return demo

def create_reqs():
    """Create requirements tab with formatted data and explanations."""
    # 1️⃣ Detect file encoding
    with open('requirements.txt', 'rb') as f:
        raw_data = f.read()
    encoding = chardet.detect(raw_data)['encoding']

    # 2️⃣ Parse requirements into library-version pairs
    def parse_requirements(lines):
        requirements = []
        for line in lines:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # Skip empty lines and comments
            parts = line.split('==')
            library = parts[0].strip()
            version = parts[1].strip() if len(parts) > 1 else 'latest'
            requirements.append((library, version))
        return requirements
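    # Note: this parser only recognizes exact `==` pins; any other specifier
    # (e.g. `>=`, extras, or environment markers) is reported as 'latest'.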
    # 3️⃣ Load and process requirements
    with open('requirements.txt', 'r', encoding=encoding) as f:
        requirements = parse_requirements(f.readlines())

    # 4️⃣ Create structured data for display
    df = pd.DataFrame({
        "📦 Library": [lib for lib, _ in requirements],
        "📌 Recommended Version": [ver for _, ver in requirements]
    })

    # 5️⃣ Build interactive components
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📦 Dependency Requirements")
        gr.Markdown("""
        ## Essential Packages for Operation
        These are the core libraries and versions needed to run the application successfully:
        """)
        gr.Dataframe(
            df,
            interactive=True,
            wrap=True,
            elem_id="requirements-table"
        )
        gr.Markdown("_Note: Versions marked 'latest' can use any compatible version._")
    return demo

def create_demo():
    """Assemble the tabbed interface from the three sub-apps."""
    app = create_app()
    authors = create_authors()
    reqs = create_reqs()
    demo = gr.TabbedInterface(
        [app, authors, reqs],
        tab_names=["✅ App", "👥 Authors", "📋 Requirements"]
    )
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch()