Leo Liu committed on
Commit
ce0d2f1
·
verified ·
1 Parent(s): 6c203d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -69
app.py CHANGED
@@ -3,61 +3,85 @@ import torch
3
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
4
  import torchaudio
5
  import os
6
- import jieba
7
- import magic
 
8
 
9
- # Device setup: automatically selects CUDA or CPU
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- # Load Whisper model for Cantonese audio transcription
13
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
14
  language = "zh"
15
- pipe = pipeline(task="automatic-speech-recognition", model=MODEL_NAME, chunk_length_s=60, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
16
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
17
 
18
- # Transcription function (supports long audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def transcribe_audio(audio_path):
20
  waveform, sample_rate = torchaudio.load(audio_path)
21
- duration = waveform.shape[1] / sample_rate
 
 
 
 
 
 
22
  if duration > 60:
 
 
23
  results = []
24
- for start in range(0, int(duration), 50):
25
- end = min(start + 60, int(duration))
26
- chunk = waveform[:, start * sample_rate:end * sample_rate]
27
- temp_filename = f"temp_chunk_{start}.wav"
28
- torchaudio.save(temp_filename, chunk, sample_rate)
29
- result = pipe(temp_filename)["text"]
30
- results.append(result)
31
- os.remove(temp_filename)
32
- return " ".join(results)
33
- return pipe(audio_path)["text"]
34
-
35
- # Load sentiment analysis model
 
36
  sentiment_pipe = pipeline("text-classification", model="Leo0129/CustomModel-multilingual-sentiment-analysis", device=device)
37
 
38
- # Text splitting function (using jieba for Chinese text)
39
- def split_text(text, max_length=512):
40
- words = list(jieba.cut(text))
41
- chunks, current_chunk = [], ""
42
- for word in words:
43
- if len(current_chunk) + len(word) < max_length:
44
- current_chunk += word
45
- else:
46
- chunks.append(current_chunk)
47
- current_chunk = word
48
- if current_chunk:
49
- chunks.append(current_chunk)
50
- return chunks
51
-
52
- # Function to rate sentiment quality based on most frequent result
53
  def rate_quality(text):
54
- chunks = split_text(text)
55
- results = []
56
- for chunk in chunks:
57
- result = sentiment_pipe(chunk)[0]
58
- label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
59
- results.append(label_map.get(result["label"], "Unknown"))
60
- return max(set(results), key=results.count)
61
 
62
  # Streamlit main interface
63
  def main():
@@ -66,7 +90,6 @@ def main():
66
  # Custom CSS styling
67
  st.markdown("""
68
  <style>
69
- @import url('https://fonts.googleapis.com/css2?family=Comic+Neue:wght@700&display=swap');
70
  .header {
71
  background: linear-gradient(45deg, #FF9A6C, #FF6B6B);
72
  border-radius: 15px;
@@ -75,38 +98,19 @@ def main():
75
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
76
  margin-bottom: 2rem;
77
  }
78
- .subtitle {
79
- font-family: 'Comic Neue', cursive;
80
- color: #4B4B4B;
81
- font-size: 1.2rem;
82
- margin: 1rem 0;
83
- padding: 1rem;
84
- background: rgba(255,255,255,0.9);
85
- border-radius: 10px;
86
- border-left: 5px solid #FF6B6B;
87
- }
88
  </style>
89
  """, unsafe_allow_html=True)
90
 
91
- # Header
92
  st.markdown("""
93
  <div class="header">
94
  <h1 style='margin:0;'>πŸŽ™οΈ Customer Service Quality Analyzer</h1>
95
- <p style='color: white; font-size: 1.2rem;'>Evaluate the service quality with simple uploading!</p>
96
  </div>
97
  """, unsafe_allow_html=True)
98
 
99
- # Audio file uploader
100
- uploaded_file = st.file_uploader("πŸ“€ Please upload your Cantonese customer service audio file", type=["wav", "mp3", "flac"])
101
 
102
  if uploaded_file is not None:
103
- file_type = magic.from_buffer(uploaded_file.read(), mime=True)
104
- uploaded_file.seek(0)
105
- if not file_type.startswith("audio/"):
106
- st.error("⚠️ Sorry, the uploaded file format is not supported. Please upload an audio file.")
107
- return
108
-
109
- st.audio(uploaded_file, format="audio/wav")
110
  temp_audio_path = "uploaded_audio.wav"
111
  with open(temp_audio_path, "wb") as f:
112
  f.write(uploaded_file.getbuffer())
@@ -114,25 +118,23 @@ def main():
114
  progress_bar = st.progress(0)
115
 
116
  # Step 1: Audio transcription
117
- with st.spinner('πŸ“ Step 1: Transcribing audio, please wait... '):
118
  transcript = transcribe_audio(temp_audio_path)
119
  progress_bar.progress(50)
120
  st.write("**Transcript:**", transcript)
121
 
122
  # Step 2: Sentiment Analysis
123
- with st.spinner('πŸ§‘β€βš–οΈ Step 2: Analyzing sentiment, please wait... '):
124
  quality_rating = rate_quality(transcript)
125
  progress_bar.progress(100)
126
  st.write("**Sentiment Analysis Result:**", quality_rating)
127
 
128
- # Download analysis results
129
  result_text = f"Transcript:\n{transcript}\n\nSentiment Analysis Result: {quality_rating}"
130
  st.download_button(label="πŸ“₯ Download Analysis Report", data=result_text, file_name="analysis_report.txt")
131
 
132
- # Customer support info
133
- st.markdown("❓If you encounter any issues, please contact customer support: πŸ“§ **support@hellotoby.com**")
134
 
135
  os.remove(temp_audio_path)
136
 
137
  if __name__ == "__main__":
138
- main()
 
3
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
4
  import torchaudio
5
  import os
6
+ import re
7
+ from difflib import SequenceMatcher
8
+ import numpy as np
9
 
10
+ # Device setup
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Load Whisper model with adjusted parameters for better memory handling
14
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
15
  language = "zh"
16
+ pipe = pipeline(
17
+ task="automatic-speech-recognition",
18
+ model=MODEL_NAME,
19
+ chunk_length_s=30,
20
+ device=device,
21
+ generate_kwargs={
22
+ "no_repeat_ngram_size": 3,
23
+ "repetition_penalty": 1.3,
24
+ "temperature": 0.7,
25
+ "top_p": 0.9,
26
+ "top_k": 50
27
+ }
28
+ )
29
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
30
 
31
# Similarity check to remove repeated phrases
def is_similar(a, b, threshold=0.8):
    """Return True when strings *a* and *b* are near-duplicates.

    Similarity is the ``difflib.SequenceMatcher`` ratio (0.0 to 1.0); the pair
    counts as similar only when the ratio is strictly greater than
    *threshold*.
    """
    return SequenceMatcher(None, a, b).ratio() > threshold


def remove_repeated_phrases(text):
    """Drop sentences that nearly repeat the sentence kept just before them.

    The text is split after each full-width sentence terminator (ideographic
    full stop, full-width exclamation mark, full-width question mark), with
    the terminator kept on its sentence. Each sentence is compared only with
    the most recently kept one. Surviving sentences are stripped and joined
    with single spaces.

    Args:
        text: Transcript text, possibly containing repeated sentences.

    Returns:
        The text with consecutive near-duplicate sentences removed. Empty
        input yields an empty string.
    """
    # Written with \u escapes (U+3002, U+FF01, U+FF1F) so the pattern
    # survives any re-encoding of this file.
    sentences = re.split(r'(?<=[\u3002\uff01\uff1f])', text)

    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        # Splitting on a zero-width match leaves an empty trailing fragment
        # whenever the text ends with a terminator; skip it so the result
        # carries no stray trailing space.
        if not sentence:
            continue
        if cleaned_sentences and is_similar(sentence, cleaned_sentences[-1]):
            continue
        cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)
42
+
43
# Remove punctuation
def remove_punctuation(text):
    """Return *text* with every punctuation character removed.

    Word characters (letters, digits, CJK ideographs) and whitespace are
    kept. The underscore is removed explicitly because the regex class
    ``\\w`` counts it as a word character.
    """
    return re.sub(r'[^\w\s]|_', '', text)
46
+
47
# Transcription function (adjusted for punctuation and repetition removal)
def transcribe_audio(audio_path):
    """Transcribe an audio file and return cleaned, punctuation-free text.

    The audio is loaded with torchaudio, averaged to a single channel when it
    has more than one, and passed to the module-level ``pipe`` as a raw
    array. Audio longer than 60 seconds is processed in 55-second windows
    that advance by 50 seconds. Repeated sentences are removed before
    punctuation is stripped, because the repetition filter finds sentence
    boundaries by their punctuation.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The transcript with repeated phrases and punctuation removed.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Average across the first axis when it has more than one entry
    # (presumably the channel axis -- confirm against torchaudio.load).
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    waveform = waveform.squeeze(0).numpy()
    total_samples = waveform.shape[0]

    duration = total_samples / sample_rate
    if duration <= 60:
        text = pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
        return remove_punctuation(remove_repeated_phrases(text))

    chunk_size = sample_rate * 55
    step_size = sample_rate * 50
    results = []

    for start in range(0, total_samples, step_size):
        chunk = waveform[start:start + chunk_size]
        # Keep the raw text here: stripping punctuation per chunk would leave
        # the repetition filter with no sentence boundaries to split on.
        results.append(pipe({"sampling_rate": sample_rate, "raw": chunk})["text"])
        # Once a window reaches the end of the audio, any further window
        # would lie entirely inside it and only duplicate text.
        if start + chunk_size >= total_samples:
            break

    return remove_punctuation(remove_repeated_phrases(" ".join(results)))
72
+
73
+ # Sentiment analysis model
74
  sentiment_pipe = pipeline("text-classification", model="Leo0129/CustomModel-multilingual-sentiment-analysis", device=device)
75
 
76
# Rate sentiment with batch processing
def rate_quality(text):
    """Rate service quality from the most frequent sentiment in *text*.

    The text is cut into 512-character slices, each slice is classified by
    the module-level ``sentiment_pipe``, and the model labels are mapped to
    quality ratings. The rating that occurs most often is returned; ties go
    to the rating that appeared first.

    Args:
        text: Transcript to evaluate.

    Returns:
        One of "Very Poor", "Poor", "Neutral", "Good", "Very Good", or
        "Unknown" when the text is empty/blank or a label is unrecognized.
    """
    # Nothing to classify: avoid calling max() on an empty collection.
    if not text or not text.strip():
        return "Unknown"

    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    # A 512-character slice can tokenize to more tokens than the model
    # accepts, so let the tokenizer truncate instead of failing.
    results = sentiment_pipe(chunks, batch_size=4, truncation=True)

    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good",
    }
    processed_results = [label_map.get(res["label"], "Unknown") for res in results]

    # dict.fromkeys keeps first-seen order, so ties resolve deterministically
    # (iterating a set would not).
    return max(dict.fromkeys(processed_results), key=processed_results.count)
85
 
86
  # Streamlit main interface
87
  def main():
 
90
  # Custom CSS styling
91
  st.markdown("""
92
  <style>
 
93
  .header {
94
  background: linear-gradient(45deg, #FF9A6C, #FF6B6B);
95
  border-radius: 15px;
 
98
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
99
  margin-bottom: 2rem;
100
  }
 
 
 
 
 
 
 
 
 
 
101
  </style>
102
  """, unsafe_allow_html=True)
103
 
 
104
  st.markdown("""
105
  <div class="header">
106
  <h1 style='margin:0;'>πŸŽ™οΈ Customer Service Quality Analyzer</h1>
107
+ <p style='color: white;'>Evaluate the service quality with simple uploading!</p>
108
  </div>
109
  """, unsafe_allow_html=True)
110
 
111
+ uploaded_file = st.file_uploader("πŸ“€ Please upload your Cantonese customer service audio file", type=["wav", "mp3", "flac"])
 
112
 
113
  if uploaded_file is not None:
 
 
 
 
 
 
 
114
  temp_audio_path = "uploaded_audio.wav"
115
  with open(temp_audio_path, "wb") as f:
116
  f.write(uploaded_file.getbuffer())
 
118
  progress_bar = st.progress(0)
119
 
120
  # Step 1: Audio transcription
121
+ with st.spinner('πŸ“ Step 1: Transcribing audio, please wait...'):
122
  transcript = transcribe_audio(temp_audio_path)
123
  progress_bar.progress(50)
124
  st.write("**Transcript:**", transcript)
125
 
126
  # Step 2: Sentiment Analysis
127
+ with st.spinner('πŸ§‘β€βš–οΈ Step 2: Analyzing sentiment, please wait...'):
128
  quality_rating = rate_quality(transcript)
129
  progress_bar.progress(100)
130
  st.write("**Sentiment Analysis Result:**", quality_rating)
131
 
 
132
  result_text = f"Transcript:\n{transcript}\n\nSentiment Analysis Result: {quality_rating}"
133
  st.download_button(label="πŸ“₯ Download Analysis Report", data=result_text, file_name="analysis_report.txt")
134
 
135
+ st.markdown("❓If you encounter any issues, please contact customer support: πŸ“§ **abc@hellotoby.com**")
 
136
 
137
  os.remove(temp_audio_path)
138
 
139
  if __name__ == "__main__":
140
+ main()