sharmavaruncs committed on
Commit
3188dd2
1 Parent(s): d3c69b0

handling recording error properly!

Files changed (2)
  1. app.py +8 -2
  2. app.py_BKP +367 -0
app.py CHANGED
@@ -327,12 +327,18 @@ def process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer):
         # Apply bold styling to the button label
         st.sidebar.markdown("<span style='font-weight: bolder;'>Record a 4 sec audio!</span>", unsafe_allow_html=True)
 
-
         # recorded = True # Set the recording state to True after recording
 
         # Add your audio recording code here
         output_wav_file = "output.wav"
-        record_audio(output_wav_file, duration=4)
+
+        try:
+            record_audio(output_wav_file, duration=4)
+        except OSError as e:
+            if "[Errno -9996]" in str(e) and "Invalid input device (no default output device)" in str(e):
+                st.error("Recording not possible as no input device on cloud platforms. Please upload instead.")
+            else:
+                st.error(f"An error occurred while recording: {str(e)}")
 
         # # Use a div to encapsulate the audio element and apply the border
         with st.sidebar.markdown('<div class="audio-container">', unsafe_allow_html=True):
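For context on the new except branch: PortAudio reports error code -9996 (invalid device) when no usable input device exists, which is the normal state on hosted platforms such as Hugging Face Spaces where no microphone is attached, so the string check targets exactly that case. An alternative, not part of this commit, is to probe for a default input device before offering recording at all. A minimal sketch under that assumption (has_input_device is a hypothetical helper; PyAudio's get_default_input_device_info() raises an OSError/IOError when no default input device is configured):

    import pyaudio

    def has_input_device() -> bool:
        """Return True if PortAudio reports a default input device (i.e. a microphone is available)."""
        p = pyaudio.PyAudio()
        try:
            p.get_default_input_device_info()  # raises OSError/IOError when no default input device exists
            return True
        except OSError:
            return False
        finally:
            p.terminate()

    # Hypothetical usage inside process_file():
    # if has_input_device():
    #     record_audio(output_wav_file, duration=4)
    # else:
    #     st.error("Recording not possible as no input device on cloud platforms. Please upload instead.")

A pre-flight check like this would let the sidebar offer the upload path up front instead of surfacing the exception text after a failed recording attempt.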
app.py_BKP ADDED
@@ -0,0 +1,367 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import librosa
+import time
+from matplotlib import cm
+import soundfile as sf
+import sounddevice as sd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, random_split
+from PIL import Image
+import torch.nn.functional as F
+import streamlit as st
+import tempfile
+import noisereduce as nr
+import altair as alt
+import pyaudio
+import wave
+import whisper
+from transformers import (
+    HubertForSequenceClassification,
+    Wav2Vec2FeatureExtractor,
+    AutoModel,
+    AutoTokenizer,
+    HubertForSequenceClassification
+)
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+emo2promptMapping = {
+    'Angry':'ANGRY',
+    'Calm':'CALM',
+    'Disgust':'DISGUSTED',
+    'Fearful':'FEARFUL',
+    'Happy': 'HAPPY',
+    'Sad': 'SAD',
+    'Surprised': 'SURPRISED'
+}
+
+# Check if GPU (cuda) is available
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
+
+# Load speech to text model
+speech_model = whisper.load_model("base")
+
+# Define Labels related info
+num_labels=7
+label_mapping = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'sad', 'surprised']
+
+# Define your model name from the Hugging Face model hub
+model_weights_path = "https://huggingface.co/netgvarun2005/MultiModalBertHubert/resolve/main/MultiModal_model_state_dict.pth"
+
+# Emo Detector
+model_id = "facebook/hubert-base-ls960"
+bert_model_name = "bert-base-uncased"
+
+def config():
+    # Loading Image using PIL
+    im = Image.open('./icon.png')
+
+    # Set the page configuration with the title and icon
+    st.set_page_config(page_title="Virtual Therapist", page_icon=im)
+
+    # Add custom CSS styles
+    st.markdown("""
+    <style>
+        .mobile-screen {
+            border: 2px solid black;
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            justify-content: flex-start; /* Align content to the top */
+            height: 20vh;
+            padding: 20px;
+            border-radius: 10px;
+        }
+    </style>
+    """, unsafe_allow_html=True)
+    # Render mobile screen container and its content
+    st.sidebar.title("Sound Recorder")
+
+    # Define a custom style for your title
+    title_style = """
+    <style>
+        h1 {
+            font-family: 'Comic Sans MS', cursive, sans-serif;
+            color: blue;
+            font-size: 22px; /* Add font size here */
+        }
+    </style>
+    """
+
+    # Display the title with the custom style
+    st.markdown(title_style, unsafe_allow_html=True)
+    st.markdown("# WELCOME! HOW ARE YOU FEELING? PLEASE RECORD AN AUDIO!", unsafe_allow_html=True)
+    st.markdown("# BASED ON YOUR EMOTIONAL STATE, I WILL SUGGEST SOME TIPS!", unsafe_allow_html=True)
+
+
+    return
+
+
+class MultimodalModel(nn.Module):
+    '''
+    Custom PyTorch model that takes as input both the audio features and the text embeddings, and concatenates the last hidden states from the Hubert and BERT models.
+    '''
+    def __init__(self, bert_model_name, num_labels):
+        super().__init__()
+        self.hubert = HubertForSequenceClassification.from_pretrained("netgvarun2005/HubertStandaloneEmoDetector", num_labels=num_labels).hubert
+        self.bert = AutoModel.from_pretrained(bert_model_name)
+        self.classifier = nn.Linear(self.hubert.config.hidden_size + self.bert.config.hidden_size, num_labels)
+
+    def forward(self, input_values, text):
+        hubert_output = self.hubert(input_values).last_hidden_state
+
+        bert_output = self.bert(text).last_hidden_state
+
+        # Apply mean pooling along the sequence dimension
+        hubert_output = hubert_output.mean(dim=1)
+        bert_output = bert_output.mean(dim=1)
+
+        concat_output = torch.cat((hubert_output, bert_output), dim=-1)
+        logits = self.classifier(concat_output)
+        return logits
+
+
+def speechtoText(wavfile):
+    return speech_model.transcribe(wavfile)['text']
+
+def resampleaudio(wavfile):
+    audio, sr = librosa.load(wavfile, sr=None)
+
+    # Set the desired target sample rate
+    target_sample_rate = 16000
+
+    # Resample the audio to the target sample rate
+    resampled_audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)
+
+    sf.write(wavfile,resampled_audio, target_sample_rate)
+    return wavfile
+
+
+def noiseReduction(wavfile):
+    audio, sr = librosa.load(wavfile, sr=None)
+
+    # Set parameters for noise reduction
+    n_fft = 2048  # FFT window size
+    hop_length = 512  # Hop length for STFT
+
+    # Perform noise reduction
+    reduced_noise = nr.reduce_noise(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length)
+
+    # Save the denoised audio to a new WAV file
+    sf.write(wavfile,reduced_noise, sr)
+    return wavfile
+
+
+def removeSilence(wavfile):
+    # Load the audio file
+    audio_file = wavfile
+
+    audio, sr = librosa.load(audio_file, sr=None)
+
+    # Split the audio file based on silence
+    clips = librosa.effects.split(audio, top_db=40)
+
+    # Combine the audio clips
+    non_silent_audio = []
+    for start, end in clips:
+        non_silent_audio.extend(audio[start:end])
+
+
+    # Save the audio without silence to a new WAV file
+    sf.write(wavfile,non_silent_audio, sr)
+    return wavfile
+
+def preprocessWavFile(wavfile):
+    resampledwavfile = resampleaudio(wavfile)
+    denoised_file = noiseReduction(resampledwavfile)
+    return removeSilence(denoised_file)
+
+@st.cache_data()
+def load_model():
+    # Load the model
+    multiModel = MultimodalModel(bert_model_name, num_labels)
+
+    # Load the model weights directly from Hugging Face Spaces
+    multiModel.load_state_dict(torch.hub.load_state_dict_from_url(model_weights_path, map_location=device), strict=False)
+
+    # multiModel.load_state_dict(torch.load(file_path + "/MultiModal_model_state_dict.pth",map_location=device),strict=False)
+    tokenizer = AutoTokenizer.from_pretrained("netgvarun2005/MultiModalBertHubertTokenizer")
+
+    # GenAI
+    tokenizer_gpt = AutoTokenizer.from_pretrained("netgvarun2005/GPTVirtualTherapistTokenizer", pad_token='<|pad|>',bos_token='<|startoftext|>',eos_token='<|endoftext|>')
+    model_gpt = AutoModelForCausalLM.from_pretrained("netgvarun2005/GPTVirtualTherapist")
+
+    return multiModel,tokenizer,model_gpt,tokenizer_gpt
+
+
+def predict(audio_array,multiModal_model,key,tokenizer,text):
+    input_text = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_id)
+
+    input_audio = feature_extractor(
+        raw_speech=audio_array,
+        sampling_rate=16000,
+        padding=True,
+        return_tensors="pt"
+    )
+    logits = multiModal_model(input_audio["input_values"], input_text["input_ids"])
+
+    probabilities = F.softmax(logits, dim=1).to_dense()
+    _, predicted = torch.max(probabilities, 1)
+    class_prob = probabilities.tolist()
+    class_prob = class_prob[0]
+    class_prob = [round(value, 2) for value in class_prob]
+    maxVal = np.argmax(class_prob)
+
+    # Display the final transcript
+    if label_mapping[predicted] == "":
+        st.write("Inference impossible, a problem occurred with your audio or your parameters, we apologize :(")
+
+    return (label_mapping[maxVal]).capitalize()
+
+def record_audio(output_file, duration=5):
+    # st.sidebar.markdown("Recording...")
+    sd.wait()  # Wait for microphone to start
+    sd.wait()  # Wait for microphone to start
+    time.sleep(0.4)
+
+    st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Recording...</p>", unsafe_allow_html=True)
+
+    chunk = 1024
+    sample_format = pyaudio.paInt16
+    channels = 2
+    fs = 44100
+
+    p = pyaudio.PyAudio()
+
+    stream = p.open(format=sample_format,
+                    channels=channels,
+                    rate=fs,
+                    frames_per_buffer=chunk,
+                    input=True)
+
+    frames = []
+
+    for _ in range(int(fs / chunk * duration)):
+        data = stream.read(chunk)
+        frames.append(data)
+
+    stream.stop_stream()
+    stream.close()
+    p.terminate()
+
+    wf = wave.open(output_file, 'wb')
+    wf.setnchannels(channels)
+    wf.setsampwidth(p.get_sample_size(sample_format))
+    wf.setframerate(fs)
+    wf.writeframes(b''.join(frames))
+    wf.close()
+    time.sleep(0.5)
+    # st.sidebar.markdown("Recording finished!")
+    st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Recording finished!</p>", unsafe_allow_html=True)
+
+    time.sleep(0.5)
+
+def GenerateText(emo,gpt_tokenizer,gpt_model):
+    prompt = f'<startoftext>{emo2promptMapping[emo]}:'
+
+    generated = gpt_tokenizer(prompt, return_tensors="pt").input_ids
+
+    sample_outputs = gpt_model.generate(generated, do_sample=True, top_k=50,
+                                        max_length=20, top_p=0.95, temperature=0.2, num_return_sequences=10,no_repeat_ngram_size=1)
+
+    # Extract and split the generated text into words
+    outputs = set([gpt_tokenizer.decode(sample_output, skip_special_tokens=True).split(':')[-1] for sample_output in sample_outputs])
+    for i, sample_output in enumerate(outputs):
+        st.write(f"<span style='font-size: 18px; font-family: Arial, sans-serif; font-weight: bold;'>{i+1}: {sample_output}</span>", unsafe_allow_html=True)
+        time.sleep(0.5)
+
+
+def process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer):
+    emo = ""
+    button_label = "Show Helpful Tips"
+    recorded = False  # Initialize the recording state as False
+
+    if 'stage' not in st.session_state:
+        st.session_state.stage = 0
+
+    def set_stage(stage):
+        st.session_state.stage = stage
+
+    # Add custom CSS styles
+    st.markdown("""
+    <style>
+        .stRecordButton {
+            width: 50px;
+            height: 50px;
+            border-radius: 50px;
+            background-color: red;
+            color: black; /* Text color */
+            font-size: 16px;
+            font-weight: bold;
+            border: 2px solid white; /* Solid border */
+            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+            cursor: pointer;
+            transition: background-color 0.2s;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+        }
+
+        .stRecordButton:hover {
+            background-color: darkred; /* Change background color on hover */
+        }
+    </style>
+    """, unsafe_allow_html=True)
+
+    if st.sidebar.button("Record a 4 sec audio!", key="record_button", help="Click to start recording", on_click=set_stage, args=(1,)):
+        # Your button click action here
+
+        # Apply bold styling to the button label
+        st.sidebar.markdown("<span style='font-weight: bolder;'>Record a 4 sec audio!</span>", unsafe_allow_html=True)
+
+
+        # recorded = True # Set the recording state to True after recording
+
+        # Add your audio recording code here
+        output_wav_file = "output.wav"
+        record_audio(output_wav_file, duration=4)
+
+        # # Use a div to encapsulate the audio element and apply the border
+        with st.sidebar.markdown('<div class="audio-container">', unsafe_allow_html=True):
+            # Play recorded sound
+            st.audio(output_wav_file, format="wav")
+
+        audio_array, sr = librosa.load(preprocessWavFile(output_wav_file), sr=None)
+        st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Generating transcriptions! Please wait...</p>", unsafe_allow_html=True)
+
+        transcription = speechtoText(output_wav_file)
+
+        emo = predict(audio_array,ser_model,2,tokenizer,transcription)
+
+        # Display the transcription in a textbox
+        st.sidebar.text_area("Transcription", transcription, height=25)
+
+        txt = f"You seem to be <b>{(emo2promptMapping[emo]).capitalize()}!</b>\n Click on 'Show Helpful Tips' button to proceed further."
+        st.markdown(f"<div class='mobile-screen' style='font-size: 24px;'>{txt} </div>", unsafe_allow_html=True)
+
+        # Store the value of emo in the session state
+        st.session_state.emo = emo
+
+    if st.session_state.stage > 0:
+        if st.button(button_label,on_click=set_stage, args=(2,)):
+            # Retrieve prompt from the emotion
+            emo = st.session_state.emo
+            GenerateText(emo,gpt_tokenizer,gpt_model)
+
+if __name__ == '__main__':
+    config()
+    ser_model,tokenizer,gpt_model,gpt_tokenizer = load_model()
+    process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer)