# BiBiER / app.py
import gradio as gr
import torchaudio
import pandas as pd
import torch.nn.functional as F
import whisper
import logging
import plotly.express as px
from utils.config_loader import ConfigLoader
from data_loading.feature_extractor import (
PretrainedAudioEmbeddingExtractor,
PretrainedTextEmbeddingExtractor
)
import chardet
import torch
from models.models import BiFormer
# DEVICE = torch.device('cuda')  # uncomment to run inference on a GPU
DEVICE = torch.device('cpu')
# Configure logging
logging.basicConfig(level=logging.INFO)
# Constants with emojis and colors
LABEL_TO_EMOTION = {
    0: '😠 Anger',
    1: '🤢 Disgust',
    2: '😨 Fear',
    3: '😄 Joy/Happiness',
    4: '😐 Neutral',
    5: '😒 Sadness',
    6: '😲 Surprise/Enthusiasm'
}
EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
emotion_color_map = {emotion: color for emotion, color in zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS)}
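# All audio is down-mixed to mono and resampled to this rate before feature extraction.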
TARGET_SAMPLE_RATE = 16000
def initialize_components(config_path='config.toml'):
"""Initialize configuration and models."""
config = ConfigLoader(config_path)
config.show_config()
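    # NOTE: these BiFormer hyperparameters must match the architecture the
    # checkpoint below was trained with, or load_state_dict() will fail.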
model = BiFormer(
audio_dim=256,
text_dim=1024,
seg_len=95,
hidden_dim=256,
hidden_dim_gated=256,
num_transformer_heads=8,
num_graph_heads=2,
positional_encoding=False,
dropout=0.15,
mode='mean',
tr_layer_number=5,
out_features=256,
num_classes=7
)
checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state)
model = model.to(DEVICE)
model.eval()
return (
PretrainedAudioEmbeddingExtractor(config),
PretrainedTextEmbeddingExtractor(config),
whisper.load_model("base"),
model
)
audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()
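# Extractors, Whisper, and the fusion model are created once at import time so
# every Gradio callback reuses them instead of reloading weights per request.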
def load_and_preprocess_audio(audio_path):
"""Load and preprocess audio to mono 16kHz format."""
try:
waveform, orig_sr = torchaudio.load(audio_path)
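        # Down-mix multi-channel audio to mono by averaging across channels.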
waveform = waveform.mean(dim=0, keepdim=False)
if orig_sr != TARGET_SAMPLE_RATE:
resampler = torchaudio.transforms.Resample(
orig_freq=orig_sr,
new_freq=TARGET_SAMPLE_RATE
)
waveform = resampler(waveform)
return waveform, TARGET_SAMPLE_RATE
except Exception as e:
logging.error(f"Audio loading failed: {e}")
raise
def transcribe_audio(audio_path):
"""Convert speech to text using Whisper."""
try:
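        # fp16=False keeps Whisper in full precision and avoids the FP16 warning on CPU.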
result = whisper_model.transcribe(audio_path, fp16=False)
return result.get('text', '')
except Exception as e:
logging.error(f"Transcription failed: {e}")
return ""
def get_predictions(input_data, extractor, is_audio=False):
"""Generic prediction function for audio/text."""
try:
if is_audio:
pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
else:
pred, emb = extractor.extract(input_data)
return F.softmax(pred, dim=-1)[0].tolist(), emb
except Exception as e:
logging.error(f"Prediction failed: {e}")
return [0.0] * len(LABEL_TO_EMOTION), None
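# Both extractors are assumed to expose the same extract() interface, returning a
# (logits, embeddings) pair; hypothetical usage:
#   probs, emb = get_predictions("I am happy today", text_extractor)
#   probs, emb = get_predictions(waveform, audio_extractor, is_audio=True)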
def create_emotion_df(probabilities):
"""Create sorted emotion probability dataframe with percentages."""
df = pd.DataFrame({
'Emotion': list(LABEL_TO_EMOTION.values()),
'Probability': [round(p*100, 2) for p in probabilities]
})
return df
def create_plot(df, title):
"""Create Plotly bar chart with proper formatting."""
fig = px.bar(
df,
x='Emotion',
y='Probability',
title=title,
color='Emotion',
color_discrete_map=emotion_color_map
)
fig.update_layout(
xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
yaxis=dict(title='Probability (%)'),
margin=dict(l=20, r=20, t=60, b=100),
height=400,
showlegend=False
)
return fig
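# Quick styling preview with uniform, purely illustrative probabilities:
#   create_plot(create_emotion_df([1 / 7] * 7), "Preview").show()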
def get_top_emotion(probabilities):
"""Return formatted top emotion with percentage."""
max_idx = probabilities.index(max(probabilities))
return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities)*100:.1f}%)"
def process_audio(audio_path):
"""Main processing pipeline."""
try:
if not audio_path:
empty = create_emotion_df([0]*len(LABEL_TO_EMOTION))
            return (
                create_plot(empty, "🎧 Audio Analysis"),
                "No audio detected",
                create_plot(empty, "📝 Text Analysis"),
                create_plot(empty, "🤝 Audio-Text Analysis"),
                "🔇 Please provide audio input"
            )
# Audio processing
waveform, sr = load_and_preprocess_audio(audio_path)
audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
audio_df = create_emotion_df(audio_probs)
# Text processing
text = transcribe_audio(audio_path)
        text_probs, text_features = (
            get_predictions(text, text_extractor)
            if text.strip()
            else ([0.0] * len(LABEL_TO_EMOTION), None)
        )
text_df = create_emotion_df(text_probs)
# Combined results
        if audio_features is None or text_features is None:
            raise ValueError("Missing audio or text features for bimodal fusion")
        combined_logits = bimodal_model(audio_features, text_features)
        combined_probs = F.softmax(combined_logits, dim=-1)[0].detach().cpu().tolist()
combined_df = create_emotion_df(combined_probs)
top_emotion = get_top_emotion(combined_probs)
        return (
            create_plot(audio_df, "🎧 Audio Analysis"),
            f"🗣️ Transcription:\n{text}",
            create_plot(text_df, "📝 Text Analysis"),
            create_plot(combined_df, "🤝 Audio-Text Analysis"),
            f"## 🏆 Dominant Emotion: {top_emotion}"
        )
except Exception as e:
logging.error(f"Processing failed: {e}")
error_df = create_emotion_df([0]*len(LABEL_TO_EMOTION))
        return (
            create_plot(error_df, "🎧 Audio Analysis"),
            "❌ Error processing audio",
            create_plot(error_df, "📝 Text Analysis"),
            create_plot(error_df, "🤝 Audio-Text Analysis"),
            "⚠️ Processing Error"
        )
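# Smoke test for the full pipeline outside the UI (hypothetical sample path):
#   audio_fig, text_out, text_fig, fused_fig, verdict = process_audio("samples/demo.wav")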
def create_app():
"""Build enhanced Gradio interface."""
with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
gr.Markdown("# Intelligent system for Bilingual Bimodal Emotion Recognition (BiBiER)")
gr.Markdown("Analyze emotions in Russian and English speech through both audio characteristics and spoken content")
with gr.Row():
audio_input = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Record or Upload Audio",
format="wav",
interactive=True
)
with gr.Row():
            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input...",
elem_classes="dominant-emotion")
with gr.Row():
with gr.Column():
audio_plot = gr.Plot(label="Audio Analysis")
with gr.Column():
text_plot = gr.Plot(label="Text Analysis")
with gr.Column():
combined_plot = gr.Plot(label="Audio-Text Analysis")
        transcription = gr.Textbox(
            label="📜 Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )
audio_input.change(
process_audio,
inputs=audio_input,
outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
)
return demo
def create_authors():
df = pd.DataFrame({
"Name": ["Author", "Author"]
})
with gr.Blocks() as demo:
gr.Dataframe(df)
return demo
def create_reqs():
"""Create requirements tab with formatted data and explanations."""
# 1️⃣ Detect file encoding
with open('requirements.txt', 'rb') as f:
raw_data = f.read()
encoding = chardet.detect(raw_data)['encoding']
# 2️⃣ Parse requirements into library-version pairs
def parse_requirements(lines):
requirements = []
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
continue # Skip empty lines and comments
parts = line.split('==')
library = parts[0].strip()
version = parts[1].strip() if len(parts) > 1 else 'latest'
requirements.append((library, version))
return requirements
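    # Example: parse_requirements(["torch==2.1.0", "# dev only", "gradio"])
    # returns [("torch", "2.1.0"), ("gradio", "latest")].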
# 3️⃣ Load and process requirements
with open('requirements.txt', 'r', encoding=encoding) as f:
requirements = parse_requirements(f.readlines())
# 4️⃣ Create structured data for display
    df = pd.DataFrame({
        "📦 Library": [lib for lib, _ in requirements],
        "🚀 Recommended Version": [ver for _, ver in requirements]
    })
# 5️⃣ Build interactive components
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ“¦ Dependency Requirements")
gr.Markdown("""
## Essential Packages for Operation
These are the core libraries and versions needed to run the application successfully:
""")
gr.Dataframe(
df,
interactive=True,
wrap=True,
elem_id="requirements-table"
)
gr.Markdown("_Note: Versions marked 'latest' can use any compatible version_")
return demo
def create_demo():
app = create_app()
authors = create_authors()
reqs = create_reqs()
demo = gr.TabbedInterface(
[app, authors, reqs],
tab_names=["⭐ App", "🎭 Authors", "πŸ“‹ Requirements"]
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch()