import logging

import chardet
import gradio as gr
import pandas as pd
import plotly.express as px
import torch
import torch.nn.functional as F
import torchaudio
import whisper

from data_loading.feature_extractor import (
    PretrainedAudioEmbeddingExtractor,
    PretrainedTextEmbeddingExtractor
)
from models.models import BiFormer
from utils.config_loader import ConfigLoader

# DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')

# Configure logging
logging.basicConfig(level=logging.INFO)

# Constants with emojis and colors
LABEL_TO_EMOTION = {
    0: '😠 Anger',
    1: '🤢 Disgust',
    2: '😨 Fear',
    3: '😄 Joy/Happiness',
    4: '😐 Neutral',
    5: '😢 Sadness',
    6: '😲 Surprise/Enthusiasm'
}
EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
emotion_color_map = {emotion: color for emotion, color in zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS)}
TARGET_SAMPLE_RATE = 16000


def initialize_components(config_path='config.toml'):
    """Initialize configuration, feature extractors, Whisper, and the fusion model."""
    config = ConfigLoader(config_path)
    config.show_config()

    model = BiFormer(
        audio_dim=256,
        text_dim=1024,
        seg_len=95,
        hidden_dim=256,
        hidden_dim_gated=256,
        num_transformer_heads=8,
        num_graph_heads=2,
        positional_encoding=False,
        dropout=0.15,
        mode='mean',
        tr_layer_number=5,
        out_features=256,
        num_classes=7
    )

    checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state)
    model = model.to(DEVICE)
    model.eval()

    return (
        PretrainedAudioEmbeddingExtractor(config),
        PretrainedTextEmbeddingExtractor(config),
        whisper.load_model("base"),
        model
    )


audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()


def load_and_preprocess_audio(audio_path):
    """Load and preprocess audio to mono 16 kHz format."""
    try:
        waveform, orig_sr = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=False)  # Downmix to mono
        if orig_sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sr,
                new_freq=TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)
        return waveform, TARGET_SAMPLE_RATE
    except Exception as e:
        logging.error(f"Audio loading failed: {e}")
        raise


def transcribe_audio(audio_path):
    """Convert speech to text using Whisper."""
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""


def get_predictions(input_data, extractor, is_audio=False):
    """Generic prediction function for audio/text."""
    try:
        if is_audio:
            pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
        else:
            pred, emb = extractor.extract(input_data)
        return F.softmax(pred, dim=-1)[0].tolist(), emb
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return [0.0] * len(LABEL_TO_EMOTION), None


def create_emotion_df(probabilities):
    """Create an emotion probability dataframe with percentages."""
    df = pd.DataFrame({
        'Emotion': list(LABEL_TO_EMOTION.values()),
        'Probability': [round(p * 100, 2) for p in probabilities]
    })
    return df


def create_plot(df, title):
    """Create Plotly bar chart with proper formatting."""
    fig = px.bar(
        df,
        x='Emotion',
        y='Probability',
        title=title,
        color='Emotion',
        color_discrete_map=emotion_color_map
    )
    fig.update_layout(
        xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
        yaxis=dict(title='Probability (%)'),
        margin=dict(l=20, r=20, t=60, b=100),
        height=400,
        showlegend=False
    )
    return fig


def get_top_emotion(probabilities):
    """Return the top emotion formatted with its percentage."""
    max_idx = probabilities.index(max(probabilities))
    return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities) * 100:.1f}%)"


def process_audio(audio_path):
    """Main processing pipeline."""
    try:
        if not audio_path:
            empty = create_emotion_df([0] * len(LABEL_TO_EMOTION))
            return (
                create_plot(empty, "🎧 Audio Analysis"),
                "No audio detected",
                create_plot(empty, "📝 Text Analysis"),
                create_plot(empty, "🤝 Audio-Text Analysis"),
                "🔇 Please provide audio input"
            )

        # Audio processing
        waveform, sr = load_and_preprocess_audio(audio_path)
        audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
        audio_df = create_emotion_df(audio_probs)

        # Text processing
        text = transcribe_audio(audio_path)
        if text.strip():
            text_probs, text_features = get_predictions(text, text_extractor)
        else:
            text_probs, text_features = [0.0] * len(LABEL_TO_EMOTION), None
        text_df = create_emotion_df(text_probs)

        # Combined results (bimodal fusion)
        combined_probs = bimodal_model(audio_features, text_features)
        combined_probs = F.softmax(combined_probs, dim=-1)[0].detach().cpu().numpy().tolist()
        combined_df = create_emotion_df(combined_probs)
        top_emotion = get_top_emotion(combined_probs)

        return (
            create_plot(audio_df, "🎧 Audio Analysis"),
            f"🗣️ Transcription:\n{text}",
            create_plot(text_df, "📝 Text Analysis"),
            create_plot(combined_df, "🤝 Audio-Text Analysis"),
            f"## 🏆 Dominant Emotion: {top_emotion}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        error_df = create_emotion_df([0] * len(LABEL_TO_EMOTION))
        return (
            create_plot(error_df, "🎧 Audio Analysis"),
            "❌ Error processing audio",
            create_plot(error_df, "📝 Text Analysis"),
            create_plot(error_df, "🤝 Audio-Text Analysis"),
            "⚠️ Processing Error"
        )


def create_app():
    """Build the main Gradio interface."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
        gr.Markdown("# Intelligent System for Bilingual Bimodal Emotion Recognition (BiBiER)")
        gr.Markdown("Analyze emotions in Russian and English speech through both audio characteristics and spoken content")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )

        with gr.Row():
            top_emotion = gr.Markdown(
                "## 🏆 Dominant Emotion: Waiting for input ...",
                elem_classes="dominant-emotion"
            )

        with gr.Row():
            with gr.Column():
                audio_plot = gr.Plot(label="Audio Analysis")
            with gr.Column():
                text_plot = gr.Plot(label="Text Analysis")
            with gr.Column():
                combined_plot = gr.Plot(label="Audio-Text Analysis")

        transcription = gr.Textbox(
            label="📜 Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        audio_input.change(
            process_audio,
            inputs=audio_input,
            outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
        )
    return demo


def create_authors():
    """Build the authors tab."""
    df = pd.DataFrame({
        "Name": ["Author", "Author"]
    })
    with gr.Blocks() as demo:
        gr.Dataframe(df)
    return demo


def create_reqs():
    """Create requirements tab with formatted data and explanations."""
    # 1️⃣ Detect file encoding
    with open('requirements.txt', 'rb') as f:
        raw_data = f.read()
    encoding = chardet.detect(raw_data)['encoding']

    # 2️⃣ Parse requirements into library-version pairs
    def parse_requirements(lines):
        requirements = []
        for line in lines:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # Skip empty lines and comments
            parts = line.split('==')
            library = parts[0].strip()
            version = parts[1].strip() if len(parts) > 1 else 'latest'
            requirements.append((library, version))
        return requirements

    # 3️⃣ Load and process requirements
    with open('requirements.txt', 'r', encoding=encoding) as f:
        requirements = parse_requirements(f.readlines())

    # 4️⃣ Create structured data for display
    df = pd.DataFrame({
        "📦 Library": [lib for lib, _ in requirements],
        "🚀 Recommended Version": [ver for _, ver in requirements]
    })

    # 5️⃣ Build interactive components
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📦 Dependency Requirements")
        gr.Markdown("""
        ## Essential Packages for Operation
        These are the core libraries and versions needed to run the application successfully:
        """)
        gr.Dataframe(
            df,
            interactive=True,
            wrap=True,
            elem_id="requirements-table"
        )
        gr.Markdown("_Note: Versions marked 'latest' can use any compatible version_")
    return demo


def create_demo():
    """Assemble the tabbed application."""
    app = create_app()
    authors = create_authors()
    reqs = create_reqs()
    demo = gr.TabbedInterface(
        [app, authors, reqs],
        tab_names=["⭐ App", "🎭 Authors", "📋 Requirements"]
    )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
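
# Optional local smoke test (a hypothetical sketch, not part of the app): it assumes a
# short WAV file named "sample.wav" next to this script and calls the pipeline directly,
# bypassing the Gradio UI, which is handy for debugging the models and the checkpoint.
#
#     outputs = process_audio("sample.wav")
#     print(outputs[1])  # transcription text
#     print(outputs[4])  # "## 🏆 Dominant Emotion: ..." markdown string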