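"""Gradio demo for bilingual bimodal (audio + text) speech emotion recognition.

Pipeline sketch: Whisper transcribes the input audio, pretrained extractors
produce audio and text embeddings, and a BiFormer fusion model combines them
to predict one of seven emotion classes.
"""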
import logging

import chardet
import gradio as gr
import pandas as pd
import plotly.express as px
import torch
import torch.nn.functional as F
import torchaudio
import whisper

from data_loading.feature_extractor import (
    PretrainedAudioEmbeddingExtractor,
    PretrainedTextEmbeddingExtractor
)
from models.models import BiFormer
from utils.config_loader import ConfigLoader


# Inference runs on CPU by default; switch to torch.device('cuda') for GPU
DEVICE = torch.device('cpu')

# Configure logging
logging.basicConfig(level=logging.INFO)

# Constants with emojis and colors
LABEL_TO_EMOTION = {
    0: '😠 Anger',
    1: '🀒 Disgust',
    2: '😨 Fear',
    3: 'πŸ˜„ Joy/Happiness',
    4: '😐 Neutral',
    5: '😒 Sadness',
    6: '😲 Surprise/Enthusiasm'
}
EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
emotion_color_map = dict(zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS))
TARGET_SAMPLE_RATE = 16000  # Whisper expects 16 kHz mono input


def initialize_components(config_path='config.toml'):
    """Initialize configuration and models."""
    config = ConfigLoader(config_path)
    config.show_config()
    model = BiFormer(
        audio_dim=256,
        text_dim=1024,
        seg_len=95,
        hidden_dim=256,
        hidden_dim_gated=256,
        num_transformer_heads=8,
        num_graph_heads=2,
        positional_encoding=False,
        dropout=0.15,
        mode='mean',
        tr_layer_number=5,
        out_features=256,
        num_classes=7
    )
    # Fine-tuned fusion-model checkpoint, expected next to this script
    checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state)
    model = model.to(DEVICE)
    model.eval()
    return (
        PretrainedAudioEmbeddingExtractor(config),
        PretrainedTextEmbeddingExtractor(config),
        whisper.load_model("base"),
        model
    )


audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()


def load_and_preprocess_audio(audio_path):
    """Load and preprocess audio to mono 16kHz format."""
    try:
        waveform, orig_sr = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0)  # downmix multi-channel audio to mono

        if orig_sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sr,
                new_freq=TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)

        return waveform, TARGET_SAMPLE_RATE
    except Exception as e:
        logging.error(f"Audio loading failed: {e}")
        raise


def transcribe_audio(audio_path):
    """Convert speech to text using Whisper."""
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""


def get_predictions(input_data, extractor, is_audio=False):
    """Generic prediction function for audio/text."""
    try:
        if is_audio:
            pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
        else:
            pred, emb = extractor.extract(input_data)

        return F.softmax(pred, dim=-1)[0].tolist(), emb
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return [0.0] * len(LABEL_TO_EMOTION), None
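
# Usage, mirroring the calls made in process_audio below:
#   audio_probs, audio_emb = get_predictions(waveform, audio_extractor, is_audio=True)
#   text_probs, text_emb = get_predictions(text, text_extractor)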


def create_emotion_df(probabilities):
    """Create sorted emotion probability dataframe with percentages."""
    df = pd.DataFrame({
        'Emotion': list(LABEL_TO_EMOTION.values()),
        'Probability': [round(p*100, 2) for p in probabilities]
    })
    return df


def create_plot(df, title):
    """Create Plotly bar chart with proper formatting."""
    fig = px.bar(
        df,
        x='Emotion',
        y='Probability',
        title=title,
        color='Emotion',
        color_discrete_map=emotion_color_map
    )
    fig.update_layout(
        xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
        yaxis=dict(title='Probability (%)'),
        margin=dict(l=20, r=20, t=60, b=100),
        height=400,
        showlegend=False
    )
    return fig


def get_top_emotion(probabilities):
    """Return formatted top emotion with percentage."""
    max_idx = probabilities.index(max(probabilities))
    return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities)*100:.1f}%)"


def process_audio(audio_path):
    """Main processing pipeline."""
    try:
        if not audio_path:
            empty = create_emotion_df([0]*len(LABEL_TO_EMOTION))
            return (
                create_plot(empty, "🎧 Audio Analysis"),
                "No audio detected",
                create_plot(empty, "πŸ“ Text Analysis"),
                create_plot(empty, "🀝 Audio-Text Analysis"),
                "πŸ”‡ Please provide audio input"
            )

        # Audio processing
        waveform, sr = load_and_preprocess_audio(audio_path)
        audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
        audio_df = create_emotion_df(audio_probs)

        # Text processing (fall back to zeros when transcription is empty;
        # the original conditional expression returned a bare list that
        # could not be unpacked into two names)
        text = transcribe_audio(audio_path)
        if text.strip():
            text_probs, text_features = get_predictions(text, text_extractor)
        else:
            text_probs, text_features = [0.0] * len(LABEL_TO_EMOTION), None
        text_df = create_emotion_df(text_probs)

        # Combined results: fuse audio and text embeddings with BiFormer
        with torch.no_grad():
            combined_logits = bimodal_model(audio_features, text_features)
        combined_probs = F.softmax(combined_logits, dim=-1)[0].cpu().numpy().tolist()
        combined_df = create_emotion_df(combined_probs)
        top_emotion = get_top_emotion(combined_probs)

        return (
            create_plot(audio_df, "🎧 Audio Analysis"),
            f"πŸ—£οΈ Transcription:\n{text}",
            create_plot(text_df, "πŸ“ Text Analysis"),
            create_plot(combined_df, "🀝 Audio-Text Analysis"),
            f"## πŸ† Dominant Emotion: {top_emotion}"
        )

    except Exception as e:
        logging.error(f"Processing failed: {e}")
        error_df = create_emotion_df([0]*len(LABEL_TO_EMOTION))
        return (
            create_plot(error_df, "🎧 Audio Analysis"),
            "❌ Error processing audio",
            create_plot(error_df, "πŸ“ Text Analysis"),
            create_plot(error_df, "🀝 Audio-Text Analysis"),
            "⚠️ Processing Error"
        )


def create_app():
    """Build enhanced Gradio interface."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
        gr.Markdown("# Intelligent system for Bilingual Bimodal Emotion Recognition (BiBiER)")
        gr.Markdown("Analyze emotions in Russian and English speech through both audio characteristics and spoken content")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )

        with gr.Row():
            top_emotion = gr.Markdown("## πŸ† Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")

        with gr.Row():
            with gr.Column():
                audio_plot = gr.Plot(label="Audio Analysis")
            with gr.Column():
                text_plot = gr.Plot(label="Text Analysis")
            with gr.Column():
                combined_plot = gr.Plot(label="Audio-Text Analysis")

        transcription = gr.Textbox(
            label="πŸ“œ Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        audio_input.change(
            process_audio,
            inputs=audio_input,
            outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
        )

    return demo


def create_authors():
    """Create authors tab listing project contributors."""
    df = pd.DataFrame({
        "Name": ["Author", "Author"]
    })
    with gr.Blocks() as demo:
        gr.Dataframe(df)
    return demo


def create_reqs():
    """Create requirements tab with formatted data and explanations."""
    # 1️⃣ Detect file encoding
    with open('requirements.txt', 'rb') as f:
        raw_data = f.read()
        encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'  # fall back if detection fails

    # 2️⃣ Parse requirements into library-version pairs
    def parse_requirements(lines):
        requirements = []
        for line in lines:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # Skip empty lines and comments
            parts = line.split('==')
            library = parts[0].strip()
            version = parts[1].strip() if len(parts) > 1 else 'latest'
            requirements.append((library, version))
        return requirements

    # 3️⃣ Load and process requirements
    with open('requirements.txt', 'r', encoding=encoding) as f:
        requirements = parse_requirements(f.readlines())

    # 4️⃣ Create structured data for display
    df = pd.DataFrame({
        "πŸ“¦ Library": [lib for lib, _ in requirements],
        "πŸš€ Recommended Version": [ver for _, ver in requirements]
    })

    # 5️⃣ Build interactive components
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# πŸ“¦ Dependency Requirements")
        gr.Markdown("""
        ## Essential Packages for Operation
        These are the core libraries and versions needed to run the application successfully:
        """)
        gr.Dataframe(
            df,
            interactive=True,
            wrap=True,
            elem_id="requirements-table"
        )
        gr.Markdown("_Note: Versions marked 'latest' can use any compatible version_")

    return demo


def create_demo():
    app = create_app()
    authors = create_authors()
    reqs = create_reqs()
    demo = gr.TabbedInterface(
        [app, authors, reqs],
        tab_names=["⭐ App", "🎭 Authors", "πŸ“‹ Requirements"]
    )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
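    # Note: launch() serves locally (http://127.0.0.1:7860 by default);
    # launch(share=True) would additionally create a temporary public URL.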