sharmavaruncs committed
Commit 608d2a8 · 1 Parent(s): 2f5687a

Added spinner for HelpfulTips

Files changed (2)
  1. .app.py.swp +0 -0
  2. app.py +184 -136
.app.py.swp DELETED
Binary file (16.4 kB)
 
app.py CHANGED
@@ -1,3 +1,4 @@
+ # Importing necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
@@ -6,17 +7,13 @@ import librosa
import time
from matplotlib import cm
import soundfile as sf
- import sounddevice as sd
import torch
import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader, random_split
from PIL import Image
import torch.nn.functional as F
import streamlit as st
import tempfile
import noisereduce as nr
- import altair as alt
import pyaudio
import wave
import whisper
@@ -25,12 +22,12 @@ from transformers import (
Wav2Vec2FeatureExtractor,
AutoModel,
AutoTokenizer,
- HubertForSequenceClassification
+ HubertForSequenceClassification,
+ AutoModelForCausalLM
)
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import webbrowser
from streamlit.components.v1 import html

+ # Mapping Hubert model's output to GPT input
emo2promptMapping = {
'Angry':'ANGRY',
'Calm':'CALM',
@@ -54,15 +51,24 @@ speech_model = whisper.load_model("base")
num_labels=7
label_mapping = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'sad', 'surprised']

- # Define your model name from the Hugging Face model hub
+ # Define the model's name from the Hugging Face model hub
model_weights_path = "https://huggingface.co/netgvarun2005/MultiModalBertHubert/resolve/main/MultiModal_model_state_dict.pth"

- # Emo Detector
+ # Model name initialization
model_id = "facebook/hubert-base-ls960"
bert_model_name = "bert-base-uncased"


def open_page(url):
+ """
+ Function to invoke javascript code to redirect to an external URL.
+
+ Parameters:
+ External URL to redirect to.
+
+ Returns:
+ None
+ """
open_script= """
<script type="text/javascript">
window.open('%s', '_blank').focus();
@@ -71,22 +77,24 @@ def open_page(url):
html(open_script)

def config():
+ """
+ Configure the Streamlit application settings and styles.
+
+ This function sets the page configuration, including the title and icon, adds custom CSS styles
+ for specific elements, and defines a custom style for the application title.
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
# Loading Image using PIL
im = Image.open('./icon.png')

# Set the page configuration with the title and icon
st.set_page_config(page_title="Virtual Therapist", page_icon=im)

- # if st.sidebar.markdown("**Open External Audio Recorder**"):
- # # url = 'https://voice-recorder-online.com/'
- # # # webbrowser.open_new_tab(url)
- # # st.markdown(f'''
- # # <a href={url}><button style="background-color:GreenYellow;">Stackoverflow</button></a>
- # # ''', unsafe_allow_html=True)
- # st.markdown("<a href='https://voice-recorder-online.com/' target='_blank'>Redirecting to the external audio recorder</a>.", unsafe_allow_html=True)
-
- # st.sidebar.button('[**Open External Audio Recorder**]()')
-
# Add custom CSS styles
st.markdown("""
<style>
@@ -102,8 +110,6 @@ def config():
}
</style>
""", unsafe_allow_html=True)
- # Render mobile screen container and its content
- #st.sidebar.title("Sound Recorder")

# Define a custom style for your title
title_style = """
@@ -120,7 +126,6 @@ def config():
st.markdown(title_style, unsafe_allow_html=True)
st.markdown("# WELCOME! HOW ARE YOU FEELING? PLEASE RECORD AN AUDIO!", unsafe_allow_html=True)
st.markdown("# BASED ON YOUR EMOTIONAL STATE, I WILL SUGGEST SOME TIPS!", unsafe_allow_html=True)
-

return

@@ -150,9 +155,33 @@ class MultimodalModel(nn.Module):

@st.cache_resource(show_spinner=False)
def speechtoText(wavfile):
+ """
+ Convert speech from a WAV audio file to text using a pre-trained Whisper ASR model.
+
+ This function takes a WAV audio file as input and utilizes a pre-trained Whisper ASR model
+ to transcribe the speech into text.
+
+ Parameters:
+ wavfile (str): The file path to the input WAV audio file.
+
+ Returns:
+ str: The transcribed text from the speech in the audio file.
+ """
return speech_model.transcribe(wavfile)['text']

def resampleaudio(wavfile):
+ """
+ Resample an audio file to a target sample rate and save it back to the same file.
+
+ This function loads an audio file in WAV format, resamples it to the specified target sample rate,
+ and then saves the resampled audio back to the same file, overwriting the original content.
+
+ Parameters:
+ wavfile (str): The file path to the input WAV audio file.
+
+ Returns:
+ str: The file path to the resampled WAV audio file.
+ """
audio, sr = librosa.load(wavfile, sr=None)

# Set the desired target sample rate
@@ -160,12 +189,25 @@ def resampleaudio(wavfile):

# Resample the audio to the target sample rate
resampled_audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)
-
+
+ # Write to the original file
sf.write(wavfile,resampled_audio, target_sample_rate)
return wavfile


def noiseReduction(wavfile):
+ """
+ Apply noise reduction to an audio file and save the denoised audio back to the same file.
+
+ This function loads an audio file in WAV format, performs noise reduction using the specified parameters,
+ and then saves the denoised audio back to the same file, overwriting the original content.
+
+ Parameters:
+ wavfile (str): The file path to the input WAV audio file.
+
+ Returns:
+ str: The file path to the denoised WAV audio file.
+ """
audio, sr = librosa.load(wavfile, sr=None)

# Set parameters for noise reduction
@@ -181,6 +223,18 @@ def noiseReduction(wavfile):


def removeSilence(wavfile):
+ """
+ Remove silence from an audio file and save the trimmed audio back to the same file.
+
+ This function loads an audio file in WAV format, identifies and removes silence based on a specified threshold,
+ and then saves the trimmed audio back to the same file, overwriting the original content.
+
+ Parameters:
+ wavfile (str): The file path to the input WAV audio file.
+
+ Returns:
+ str: The file path to the audio file with silence removed.
+ """
# Load the audio file
audio_file = wavfile

@@ -194,18 +248,44 @@ def removeSilence(wavfile):
for start, end in clips:
non_silent_audio.extend(audio[start:end])

-
# Save the audio without silence to a new WAV file
sf.write(wavfile,non_silent_audio, sr)
return wavfile

def preprocessWavFile(wavfile):
+ """
+ Perform a series of audio preprocessing steps on a WAV file.
+
+ This function takes an input WAV audio file, applies a series of preprocessing steps,
+ including resampling, noise reduction, and silence removal, and returns the path to the
+ preprocessed audio file.
+
+ Parameters:
+ wavfile (str): The file path to the input WAV audio file.
+
+ Returns:
+ str: The file path to the preprocessed WAV audio file.
+ """
resampledwavfile = resampleaudio(wavfile)
denoised_file = noiseReduction(resampledwavfile)
return removeSilence(denoised_file)

@st.cache_resource()
def load_model():
+ """
+ Load and configure various models and tokenizers for a multi-modal application.
+
+ This function loads a multi-modal model and its weights from a specified source,
+ initializes tokenizers for the model and an additional language model, and returns
+ these components for use in a multi-modal application.
+
+ Returns:
+ tuple: A tuple containing the following components:
+ - multiModel (MultimodalModel): The multi-modal model.
+ - tokenizer (AutoTokenizer): Tokenizer for the multi-modal model.
+ - model_gpt (AutoModelForCausalLM): Language model for text generation.
+ - tokenizer_gpt (AutoTokenizer): Tokenizer for the language model.
+ """
# Load the model
multiModel = MultimodalModel(bert_model_name, num_labels)

@@ -216,16 +296,34 @@ def load_model():
tokenizer = AutoTokenizer.from_pretrained("netgvarun2005/MultiModalBertHubertTokenizer")

# GenAI
- #tokenizer_gpt = AutoTokenizer.from_pretrained("netgvarun2005/GPTVirtualTherapistTokenizer", pad_token='<|pad|>',bos_token='<|startoftext|>',eos_token='<|endoftext|>')
tokenizer_gpt = AutoTokenizer.from_pretrained("netgvarun2005/GPTTherapistDeepSpeedTokenizer", pad_token='<|pad|>',bos_token='<|startoftext|>',eos_token='<|endoftext|>')
- #model_gpt = AutoModelForCausalLM.from_pretrained("netgvarun2005/GPTVirtualTherapist")
model_gpt = AutoModelForCausalLM.from_pretrained("netgvarun2005/GPTTherapistDeepSpeedModel")

return multiModel,tokenizer,model_gpt,tokenizer_gpt


def predict(audio_array,multiModal_model,key,tokenizer,text):
+ """
+ Perform multimodal prediction using an audio feature array and text input.
+
+ This function takes an audio feature array and text as input, tokenizes the text,
+ extracts audio features, and uses a multi-modal model to predict a class label based on
+ the combined audio and text inputs.
+
+ Parameters:
+ audio_array (numpy.ndarray): A numpy array containing audio features.
+ multiModal_model: The multi-modal model for prediction.
+ key: A key for identifying the model (e.g., model_id).
+ tokenizer: Tokenizer for processing the text input.
+ text (str): The input text for prediction.
+
+ Returns:
+ str: The predicted class label.
+ """
+ # Tokenize the input text
input_text = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+
+ # Extract audio features using a feature extractor
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_id)

input_audio = feature_extractor(
@@ -234,8 +332,11 @@ def predict(audio_array,multiModal_model,key,tokenizer,text):
padding=True,
return_tensors="pt"
)
+
+ # Make predictions with the multi-modal model
logits = multiModal_model(input_audio["input_values"], input_text["input_ids"])

+ # Calculate class probabilities
probabilities = F.softmax(logits, dim=1).to_dense()
_, predicted = torch.max(probabilities, 1)
class_prob = probabilities.tolist()
@@ -243,130 +344,68 @@
class_prob = [round(value, 2) for value in class_prob]
maxVal = np.argmax(class_prob)

- # Display the final transcript
+ # Display the final transcript and handle inference issues
if label_mapping[predicted] == "":
st.write("Inference impossible, a problem occurred with your audio or your parameters, we apologize :(")

return (label_mapping[maxVal]).capitalize()

- def record_audio(output_file, duration=5):
- # st.sidebar.markdown("Recording...")
- sd.wait() # Wait for microphone to start
- sd.wait() # Wait for microphone to start
- time.sleep(0.4)
-
- st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Recording...</p>", unsafe_allow_html=True)
-
- chunk = 1024
- sample_format = pyaudio.paInt16
- channels = 2
- fs = 44100
-
- p = pyaudio.PyAudio()
-
- stream = p.open(format=sample_format,
- channels=channels,
- rate=fs,
- frames_per_buffer=chunk,
- input=True)
-
- frames = []
-
- for _ in range(int(fs / chunk * duration)):
- data = stream.read(chunk)
- frames.append(data)
-
- stream.stop_stream()
- stream.close()
- p.terminate()
-
- wf = wave.open(output_file, 'wb')
- wf.setnchannels(channels)
- wf.setsampwidth(p.get_sample_size(sample_format))
- wf.setframerate(fs)
- wf.writeframes(b''.join(frames))
- wf.close()
- time.sleep(0.5)
- # st.sidebar.markdown("Recording finished!")
- st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Recording finished!</p>", unsafe_allow_html=True)
-
- time.sleep(0.5)
-
def GenerateText(emo,gpt_tokenizer,gpt_model):
+ """
+ Generate text based on a given emotion using a GPT-2 model.
+
+ This function takes an emotion as input, generates text based on the emotion prompt,
+ and displays multiple generated text samples.
+
+ Parameters:
+ emo (str): The emotion for which text should be generated.
+ gpt_tokenizer: Tokenizer for processing the GPT-2 model input.
+ gpt_model: The GPT-2 model for text generation.
+
+ Returns:
+ None
+ """
+ # Create a prompt based on the input emotion
prompt = f'<startoftext>{emo2promptMapping[emo]}:'

+ # Tokenize the prompt and convert it to input tensors
generated = gpt_tokenizer(prompt, return_tensors="pt").input_ids

+ # Move the generated tensor and GPT model to the specified device (e.g., GPU)
generated = generated.to(device)
gpt_model.to(device)

+ # Generate multiple text samples based on the prompt
sample_outputs = gpt_model.generate(generated, do_sample=True, top_k=50,
max_length=30, top_p=0.95, temperature=1.1, num_return_sequences=10)#,no_repeat_ngram_size=1)

# Extract and split the generated text into words
outputs = set([gpt_tokenizer.decode(sample_output, skip_special_tokens=True).split(':')[-1] for sample_output in sample_outputs])
+
+ # Display the generated text samples with a delay for readability
for i, sample_output in enumerate(outputs):
st.write(f"<span style='font-size: 18px; font-family: Arial, sans-serif; font-weight: bold;'>{i+1}: {sample_output}</span>", unsafe_allow_html=True)
time.sleep(0.5)


def process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer):
+ """
+ Process and analyze an uploaded WAV file, generating transcriptions and helpful tips.
+
+ This function allows users to upload a WAV audio file, processes the file to obtain transcriptions,
+ predicts the user's emotional state, and displays helpful tips based on the predicted emotion.
+
+ Parameters:
+ ser_model: The emotion analysis model for predicting emotions.
+ tokenizer: Tokenizer for processing text inputs.
+ gpt_model: The GPT-3 model for generating text.
+ gpt_tokenizer: Tokenizer for processing GPT-3 model inputs.
+
+ Returns:
+ None
+ """
emo = ""
button_label = "Show Helpful Tips"
- # recorded = False # Initialize the recording state as False
-
- # if 'stage' not in st.session_state:
- # st.session_state.stage = 0
-
- # def set_stage(stage):
- # st.session_state.stage = stage
-
- # # Add custom CSS styles
- # st.markdown("""
- # <style>
- # .stRecordButton {
- # width: 50px;
- # height: 50px;
- # border-radius: 50px;
- # background-color: red;
- # color: black; /* Text color */
- # font-size: 16px;
- # font-weight: bold;
- # border: 2px solid white; /* Solid border */
- # box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
- # cursor: pointer;
- # transition: background-color 0.2s;
- # display: flex;
- # justify-content: center;
- # align-items: center;
- # }
-
- # .stRecordButton:hover {
- # background-color: darkred; /* Change background color on hover */
- # }
- # </style>
- # """, unsafe_allow_html=True)
- # Redirect the user to the external website
- #st.markdown("<a href='https://voice-recorder-online.com/' target='_blank'>Redirecting to the external audio recorder</a>.", unsafe_allow_html=True)
-
- # if st.sidebar.button("Record a 4 sec audio!", key="record_button", help="Click to start recording", on_click=set_stage, args=(1,)):
- # # Your button click action here
-
- # # Apply bold styling to the button label
- # st.sidebar.markdown("<span style='font-weight: bolder;'>Record a 4 sec audio!</span>", unsafe_allow_html=True)
-
- # # recorded = True # Set the recording state to True after recording
-
- # # Add your audio recording code here
- # output_wav_file = "output.wav"
-
- # try:
- # record_audio(output_wav_file, duration=4)
-
- # # # Use a div to encapsulate the audio element and apply the border
- # with st.sidebar.markdown('<div class="audio-container">', unsafe_allow_html=True):
- # # Play recorded sound
- # st.audio(output_wav_file, format="wav")
uploaded_file = st.file_uploader("Upload your file! It should be .wav", type=["wav"])

if uploaded_file is not None:
@@ -382,9 +421,7 @@ def process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer):
temp_file.write(audio_content)

try:
-
audio_array, sr = librosa.load(preprocessWavFile(temp_filename), sr=None)
- #st.sidebar.markdown("<p style='font-size: 14px; font-weight: bold;'>Generating transcriptions! Please wait...</p>", unsafe_allow_html=True)
with st.spinner(st.markdown("<p style='font-size: 14px; font-weight: bold;'>Generating transcriptions in the side pane! Please wait...</p>", unsafe_allow_html=True)):
transcription = speechtoText(temp_filename)
emo = predict(audio_array,ser_model,2,tokenizer,transcription)
@@ -400,21 +437,32 @@ def process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer):
# Store the value of emo in the session state
st.session_state.emo = emo
if st.button(button_label):
- # Retrieve prompt from the emotion
- emo = st.session_state.emo
- GenerateText(emo,gpt_tokenizer,gpt_model)
- # except OSError as e:
- # if "[Errno -9996]" in str(e) and "Invalid input device (no default output device)" in str(e):
- # st.error("Recording not possible as no input device on cloud platforms. Please upload instead.")
- # else:
- # st.error(f"An error occurred while recording: {str(e)}")
+ with st.spinner(st.markdown("<p style='font-size: 14px; font-weight: bold;'>Generating tips (it may take upto 3-4 mins depending upon network speed! Please wait...</p>", unsafe_allow_html=True)):
+ # Retrieve prompt from the emotion
+ emo = st.session_state.emo
+ # Call the function for GENAI
+ GenerateText(emo,gpt_tokenizer,gpt_model)

- # if st.session_state.stage > 0:

- if __name__ == '__main__':
+ def main():
+ """
+ Main function for running a Streamlit-based multi-modal text generation application.
+
+ This function configures the Streamlit application, loads necessary models and tokenizers,
+ and allows users to process audio files to generate transcriptions and helpful tips.
+
+ Returns:
+ None
+ """
config()
if st.sidebar.button("**Open External Audio Recorder!**"):
open_page("https://voice-recorder-online.com/")

+ # Load the models, and tokenizers
ser_model,tokenizer,gpt_model,gpt_tokenizer = load_model()
+
+ # Process and analyze uploaded audio files
process_file(ser_model,tokenizer,gpt_model,gpt_tokenizer)
+
+
+ if __name__ == '__main__':
+ main()
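
For context on the change this commit makes: the "Show Helpful Tips" button handler is now wrapped in st.spinner, so the page shows a busy indicator while the GPT model generates suggestions. Below is a minimal, self-contained sketch of that spinner pattern, not the app's own code; generate_tips is a hypothetical stand-in for the app's GenerateText call, and the labels are illustrative.

# Minimal sketch of the st.spinner pattern added in this commit (assumed names).
import time

import streamlit as st

def generate_tips(emotion: str) -> list[str]:
    # Hypothetical stand-in for the app's GenerateText(); simulates a slow GPT call.
    time.sleep(2)
    return [f"Tip {i + 1} for feeling {emotion.lower()}" for i in range(3)]

emotion = st.selectbox("Detected emotion", ["Calm", "Angry", "Sad"])

if st.button("Show Helpful Tips"):
    # st.spinner shows its message only while the enclosed block is executing,
    # so the slow generation call must run inside the with-block.
    with st.spinner("Generating tips, please wait..."):
        tips = generate_tips(emotion)
    for tip in tips:
        st.write(tip)

Because the spinner disappears as soon as the with-block exits, the results are written after the block so the indicator is replaced by the generated tips.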