Commit 3188dd2
Parent(s): d3c69b0

handling recording error properly!

Files changed:
- app.py (+8, -2)
- app.py_BKP (+367, -0)
app.py
CHANGED

@@ -327,12 +327,18 @@ def process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer):
         # Apply bold styling to the button label
         st.sidebar.markdown("<span style='font-weight: bolder;'>Record a 4 sec audio!</span>", unsafe_allow_html=True)
 
-
         # recorded = True # Set the recording state to True after recording
 
         # Add your audio recording code here
         output_wav_file = "output.wav"
-        record_audio(output_wav_file, duration=4)
+
+        try:
+            record_audio(output_wav_file, duration=4)
+        except OSError as e:
+            if "[Errno -9996]" in str(e) and "Invalid input device (no default output device)" in str(e):
+                st.error("Recording not possible as no input device on cloud platforms. Please upload instead.")
+            else:
+                st.error(f"An error occurred while recording: {str(e)}")
 
         # # Use a div to encapsulate the audio element and apply the border
         with st.sidebar.markdown('<div class="audio-container">', unsafe_allow_html=True):
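Note: PortAudio raises errno -9996 when, as on this Space's headless runtime, no default input device exists; the new try/except surfaces that as a friendly message instead of a crash. A complementary approach, sketched below (not part of this commit; the helper name is invented for illustration), is to probe for a default input device up front and only offer recording when one exists:

import pyaudio

def has_default_input_device() -> bool:
    # Illustrative helper, not from app.py: returns True when PortAudio
    # reports a usable default input device.
    p = pyaudio.PyAudio()
    try:
        p.get_default_input_device_info()  # raises OSError when no default input device exists
        return True
    except OSError:
        return False
    finally:
        p.terminate()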
app.py_BKP
ADDED
@@ -0,0 +1,367 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import librosa
import time
from matplotlib import cm
import soundfile as sf
import sounddevice as sd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from PIL import Image
import torch.nn.functional as F
import streamlit as st
import tempfile
import noisereduce as nr
import altair as alt
import pyaudio
import wave
import whisper
from transformers import (
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    AutoModel,
    AutoTokenizer,
    HubertForSequenceClassification
)
from transformers import AutoTokenizer, AutoModelForCausalLM

emo2promptMapping = {
    'Angry': 'ANGRY',
    'Calm': 'CALM',
    'Disgust': 'DISGUSTED',
    'Fearful': 'FEARFUL',
    'Happy': 'HAPPY',
    'Sad': 'SAD',
    'Surprised': 'SURPRISED'
}

# Check if GPU (cuda) is available
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load speech to text model
speech_model = whisper.load_model("base")

# Define labels related info
num_labels = 7
label_mapping = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'sad', 'surprised']

# Define your model name from the Hugging Face model hub
model_weights_path = "https://huggingface.co/netgvarun2005/MultiModalBertHubert/resolve/main/MultiModal_model_state_dict.pth"

# Emo Detector
model_id = "facebook/hubert-base-ls960"
bert_model_name = "bert-base-uncased"
def config():
    # Loading Image using PIL
    im = Image.open('./icon.png')

    # Set the page configuration with the title and icon
    st.set_page_config(page_title="Virtual Therapist", page_icon=im)

    # Add custom CSS styles
    st.markdown("""
        <style>
        .mobile-screen {
            border: 2px solid black;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: flex-start; /* Align content to the top */
            height: 20vh;
            padding: 20px;
            border-radius: 10px;
        }
        </style>
    """, unsafe_allow_html=True)

    # Render mobile screen container and its content
    st.sidebar.title("Sound Recorder")

    # Define a custom style for your title
    title_style = """
        <style>
        h1 {
            font-family: 'Comic Sans MS', cursive, sans-serif;
            color: blue;
            font-size: 22px; /* Add font size here */
        }
        </style>
    """

    # Display the title with the custom style
    st.markdown(title_style, unsafe_allow_html=True)
    st.markdown("# WELCOME! HOW ARE YOU FEELING? PLEASE RECORD AN AUDIO!", unsafe_allow_html=True)
    st.markdown("# BASED ON YOUR EMOTIONAL STATE, I WILL SUGGEST SOME TIPS!", unsafe_allow_html=True)

    return
class MultimodalModel(nn.Module):
    '''
    Custom PyTorch model that takes as input both the audio features and the text embeddings,
    and concatenates the last hidden states from the Hubert and BERT models.
    '''
    def __init__(self, bert_model_name, num_labels):
        super().__init__()
        self.hubert = HubertForSequenceClassification.from_pretrained("netgvarun2005/HubertStandaloneEmoDetector", num_labels=num_labels).hubert
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.classifier = nn.Linear(self.hubert.config.hidden_size + self.bert.config.hidden_size, num_labels)

    def forward(self, input_values, text):
        hubert_output = self.hubert(input_values).last_hidden_state

        bert_output = self.bert(text).last_hidden_state

        # Apply mean pooling along the sequence dimension
        hubert_output = hubert_output.mean(dim=1)
        bert_output = bert_output.mean(dim=1)

        concat_output = torch.cat((hubert_output, bert_output), dim=-1)
        logits = self.classifier(concat_output)
        return logits
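For intuition about MultimodalModel.forward: each encoder's last_hidden_state is mean-pooled over its own sequence dimension, and the two pooled vectors are concatenated before the linear classifier. A minimal shape walk-through with dummy tensors (a sketch only, assuming the 768-dim hidden size of hubert-base and bert-base and the app's 7 labels):

import torch

batch, audio_frames, text_tokens, hidden = 2, 49, 12, 768
hubert_out = torch.randn(batch, audio_frames, hidden)  # stand-in for HuBERT last_hidden_state
bert_out = torch.randn(batch, text_tokens, hidden)     # stand-in for BERT last_hidden_state

pooled_audio = hubert_out.mean(dim=1)                   # (2, 768)
pooled_text = bert_out.mean(dim=1)                      # (2, 768)
fused = torch.cat((pooled_audio, pooled_text), dim=-1)  # (2, 1536)

classifier = torch.nn.Linear(hidden * 2, 7)
print(classifier(fused).shape)                          # torch.Size([2, 7])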
def speechtoText(wavfile):
    return speech_model.transcribe(wavfile)['text']

def resampleaudio(wavfile):
    audio, sr = librosa.load(wavfile, sr=None)

    # Set the desired target sample rate
    target_sample_rate = 16000

    # Resample the audio to the target sample rate
    resampled_audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)

    sf.write(wavfile, resampled_audio, target_sample_rate)
    return wavfile

def noiseReduction(wavfile):
    audio, sr = librosa.load(wavfile, sr=None)

    # Set parameters for noise reduction
    n_fft = 2048      # FFT window size
    hop_length = 512  # Hop length for STFT

    # Perform noise reduction
    reduced_noise = nr.reduce_noise(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length)

    # Save the denoised audio to a new WAV file
    sf.write(wavfile, reduced_noise, sr)
    return wavfile

def removeSilence(wavfile):
    # Load the audio file
    audio_file = wavfile

    audio, sr = librosa.load(audio_file, sr=None)

    # Split the audio file based on silence
    clips = librosa.effects.split(audio, top_db=40)

    # Combine the audio clips
    non_silent_audio = []
    for start, end in clips:
        non_silent_audio.extend(audio[start:end])

    # Save the audio without silence to a new WAV file
    sf.write(wavfile, non_silent_audio, sr)
    return wavfile

def preprocessWavFile(wavfile):
    resampledwavfile = resampleaudio(wavfile)
    denoised_file = noiseReduction(resampledwavfile)
    return removeSilence(denoised_file)
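preprocessWavFile chains the three steps above, overwriting the same WAV in place: resample to 16 kHz, denoise, then trim silence. A quick usage sketch (illustrative; assumes a recording already exists at output.wav):

clean_wav = preprocessWavFile("output.wav")      # the file is rewritten at each stage
audio_array, sr = librosa.load(clean_wav, sr=None)
assert sr == 16000  # resampleaudio wrote the file back at 16 kHz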
@st.cache_data()
def load_model():
    # Load the model
    multiModel = MultimodalModel(bert_model_name, num_labels)

    # Load the model weights directly from Hugging Face Spaces
    multiModel.load_state_dict(torch.hub.load_state_dict_from_url(model_weights_path, map_location=device), strict=False)

    # multiModel.load_state_dict(torch.load(file_path + "/MultiModal_model_state_dict.pth", map_location=device), strict=False)
    tokenizer = AutoTokenizer.from_pretrained("netgvarun2005/MultiModalBertHubertTokenizer")

    # GenAI
    tokenizer_gpt = AutoTokenizer.from_pretrained("netgvarun2005/GPTVirtualTherapistTokenizer", pad_token='<|pad|>', bos_token='<|startoftext|>', eos_token='<|endoftext|>')
    model_gpt = AutoModelForCausalLM.from_pretrained("netgvarun2005/GPTVirtualTherapist")

    return multiModel, tokenizer, model_gpt, tokenizer_gpt
def predict(audio_array, multiModal_model, key, tokenizer, text):
    input_text = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_id)

    input_audio = feature_extractor(
        raw_speech=audio_array,
        sampling_rate=16000,
        padding=True,
        return_tensors="pt"
    )
    logits = multiModal_model(input_audio["input_values"], input_text["input_ids"])

    probabilities = F.softmax(logits, dim=1).to_dense()
    _, predicted = torch.max(probabilities, 1)
    class_prob = probabilities.tolist()
    class_prob = class_prob[0]
    class_prob = [round(value, 2) for value in class_prob]
    maxVal = np.argmax(class_prob)

    # Display the final transcript
    if label_mapping[predicted] == "":
        st.write("Inference impossible, a problem occurred with your audio or your parameters, we apologize :(")

    return (label_mapping[maxVal]).capitalize()
def record_audio(output_file, duration=5):
    # st.sidebar.markdown("Recording...")
    sd.wait()  # Wait for microphone to start
    sd.wait()  # Wait for microphone to start
    time.sleep(0.4)

    st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Recording...</p>", unsafe_allow_html=True)

    chunk = 1024
    sample_format = pyaudio.paInt16
    channels = 2
    fs = 44100

    p = pyaudio.PyAudio()

    stream = p.open(format=sample_format,
                    channels=channels,
                    rate=fs,
                    frames_per_buffer=chunk,
                    input=True)

    frames = []

    for _ in range(int(fs / chunk * duration)):
        data = stream.read(chunk)
        frames.append(data)

    stream.stop_stream()
    stream.close()
    p.terminate()

    wf = wave.open(output_file, 'wb')
    wf.setnchannels(channels)
    wf.setsampwidth(p.get_sample_size(sample_format))
    wf.setframerate(fs)
    wf.writeframes(b''.join(frames))
    wf.close()
    time.sleep(0.5)
    # st.sidebar.markdown("Recording finished!")
    st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Recording finished!</p>", unsafe_allow_html=True)

    time.sleep(0.5)
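record_audio reads int(fs / chunk * duration) buffers of 1024 frames, so the saved file holds just under the requested duration (for 4 s at 44.1 kHz: 172 chunks, about 3.99 s). A small sanity check (a sketch, not part of the commit) using only the standard-library wave module:

import wave

with wave.open("output.wav", "rb") as wf:
    seconds = wf.getnframes() / wf.getframerate()
    print(f"channels={wf.getnchannels()}, rate={wf.getframerate()} Hz, duration={seconds:.2f} s")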
def GenerateText(emo, gpt_tokenizer, gpt_model):
    prompt = f'<startoftext>{emo2promptMapping[emo]}:'

    generated = gpt_tokenizer(prompt, return_tensors="pt").input_ids

    sample_outputs = gpt_model.generate(generated, do_sample=True, top_k=50,
                                        max_length=20, top_p=0.95, temperature=0.2, num_return_sequences=10, no_repeat_ngram_size=1)

    # Extract and split the generated text into words
    outputs = set([gpt_tokenizer.decode(sample_output, skip_special_tokens=True).split(':')[-1] for sample_output in sample_outputs])
    for i, sample_output in enumerate(outputs):
        st.write(f"<span style='font-size: 18px; font-family: Arial, sans-serif; font-weight: bold;'>{i+1}: {sample_output}</span>", unsafe_allow_html=True)
        time.sleep(0.5)
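For example, with emo = 'Happy' the prompt becomes '<startoftext>HAPPY:'; each decoded sample is split on ':' so only the generated tip after the emotion tag is displayed, and the set() deduplicates the ten returned sequences.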
def process_file(ser_model, tokenizer, gpt_model, gpt_tokenizer):
    emo = ""
    button_label = "Show Helpful Tips"
    recorded = False  # Initialize the recording state as False

    if 'stage' not in st.session_state:
        st.session_state.stage = 0

    def set_stage(stage):
        st.session_state.stage = stage

    # Add custom CSS styles
    st.markdown("""
        <style>
        .stRecordButton {
            width: 50px;
            height: 50px;
            border-radius: 50px;
            background-color: red;
            color: black; /* Text color */
            font-size: 16px;
            font-weight: bold;
            border: 2px solid white; /* Solid border */
            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
            cursor: pointer;
            transition: background-color 0.2s;
            display: flex;
            justify-content: center;
            align-items: center;
        }

        .stRecordButton:hover {
            background-color: darkred; /* Change background color on hover */
        }
        </style>
    """, unsafe_allow_html=True)

    if st.sidebar.button("Record a 4 sec audio!", key="record_button", help="Click to start recording", on_click=set_stage, args=(1,)):
        # Your button click action here

        # Apply bold styling to the button label
        st.sidebar.markdown("<span style='font-weight: bolder;'>Record a 4 sec audio!</span>", unsafe_allow_html=True)


        # recorded = True # Set the recording state to True after recording

        # Add your audio recording code here
        output_wav_file = "output.wav"
        record_audio(output_wav_file, duration=4)

        # # Use a div to encapsulate the audio element and apply the border
        with st.sidebar.markdown('<div class="audio-container">', unsafe_allow_html=True):
            # Play recorded sound
            st.audio(output_wav_file, format="wav")

        audio_array, sr = librosa.load(preprocessWavFile(output_wav_file), sr=None)
        st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Generating transcriptions! Please wait...</p>", unsafe_allow_html=True)

        transcription = speechtoText(output_wav_file)

        emo = predict(audio_array, ser_model, 2, tokenizer, transcription)

        # Display the transcription in a textbox
        st.sidebar.text_area("Transcription", transcription, height=25)

        txt = f"You seem to be <b>{(emo2promptMapping[emo]).capitalize()}!</b>\n Click on 'Show Helpful Tips' button to proceed further."
        st.markdown(f"<div class='mobile-screen' style='font-size: 24px;'>{txt} </div>", unsafe_allow_html=True)

        # Store the value of emo in the session state
        st.session_state.emo = emo

    if st.session_state.stage > 0:
        if st.button(button_label, on_click=set_stage, args=(2,)):
            # Retrieve prompt from the emotion
            emo = st.session_state.emo
            GenerateText(emo, gpt_tokenizer, gpt_model)

if __name__ == '__main__':
    config()
    ser_model, tokenizer, gpt_model, gpt_tokenizer = load_model()
    process_file(ser_model, tokenizer, gpt_model, gpt_tokenizer)