Leo Liu
Update app.py
795e7cf verified
raw
history blame
4.6 kB
import os
import tempfile
from collections import Counter

import jieba
import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Device setup: automatically selects CUDA or CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Whisper model for Cantonese audio transcription
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource(show_spinner=False)
def _load_asr_pipeline():
    """Build the Cantonese Whisper ASR pipeline once per server process.

    Streamlit re-executes the whole script on every rerun, so an uncached
    ``pipeline(...)`` call at module level would reload the model each time.
    The cache spinner is disabled so that no element is rendered before
    ``st.set_page_config`` is called in ``main``.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        # Whisper's encoder consumes at most 30 s of audio per window; longer
        # windows are truncated by the feature extractor and speech is lost.
        chunk_length_s=30,
        device=device,
    )
    # Pin the decoder to Chinese transcription instead of letting Whisper
    # auto-detect the language/task.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )
    return asr


pipe = _load_asr_pipeline()
# Transcription function (supports long audio)
def transcribe_audio(audio_path, *, chunk_length_s=30):
    """Transcribe an audio file to text with the Whisper ASR pipeline ``pipe``.

    Long recordings are handled by the pipeline's own chunking. The audio is
    cut into ``chunk_length_s``-second windows with overlapping strides, each
    window is decoded, and the overlapping parts are merged. Speech at window
    boundaries is therefore neither dropped nor transcribed twice.

    Args:
        audio_path: Path to an audio file. When given a filename, the
            pipeline decodes it with ffmpeg, so wav/mp3/flac are all accepted.
        chunk_length_s: Window length in seconds. Whisper's encoder sees at
            most 30 s of audio, so values above 30 cause truncation.

    Returns:
        The full transcript as a single string.
    """
    # Passing chunk_length_s at call time overrides whatever value the
    # pipeline was constructed with. No temporary chunk files are needed,
    # so nothing can collide between sessions or leak on error.
    return pipe(audio_path, chunk_length_s=chunk_length_s)["text"]
# Load sentiment analysis model (Custom multilingual sentiment analysis)
@st.cache_resource(show_spinner=False)
def _load_sentiment_pipeline():
    """Build the text-classification pipeline once per server process.

    Streamlit re-executes the whole script on every rerun; caching keeps the
    model from being reloaded each time. The cache spinner is disabled so
    nothing is rendered before ``st.set_page_config`` runs in ``main``.
    """
    # NOTE(review): rate_quality's label map expects this model to emit the
    # labels "Very Negative" ... "Very Positive". Confirm that
    # "CustomModel-multilingual-sentiment-analysis" resolves (local directory
    # or Hub id) and uses that label set.
    return pipeline(
        "text-classification",
        model="CustomModel-multilingual-sentiment-analysis",
        device=device,
    )


sentiment_pipe = _load_sentiment_pipeline()
# Text splitting function (using jieba for Chinese text)
def split_text(text, max_length=512, *, tokenize=None):
    """Split text into chunks of at most ``max_length`` characters.

    Words produced by ``tokenize`` are packed greedily, so chunk boundaries
    fall between words. A single word longer than ``max_length`` is
    hard-split, so no chunk ever exceeds the limit.

    Args:
        text: Text to split.
        max_length: Maximum chunk length, counted in characters (not model
            tokens).
        tokenize: Callable mapping a string to an iterable of words. Defaults
            to ``jieba.cut`` (Chinese word segmentation).

    Returns:
        A list of non-empty chunks whose concatenation equals the
        concatenation of the words from ``tokenize``. Empty input yields ``[]``.

    Raises:
        ValueError: If ``max_length`` is less than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    if tokenize is None:
        tokenize = jieba.cut

    chunks = []
    current = ""
    for word in tokenize(text):
        # Slice over-long words into pieces that each fit in a chunk alone.
        for start in range(0, len(word), max_length):
            piece = word[start:start + max_length]
            if len(current) + len(piece) <= max_length:
                current += piece
            else:
                # `current` is non-empty here because `piece` alone always
                # fits, so no empty chunk is ever emitted.
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks
# Function to rate sentiment quality based on most frequent result
def rate_quality(text):
    """Rate customer-service quality from the sentiment of a transcript.

    The transcript is split into chunks with ``split_text``, and each chunk
    is classified by ``sentiment_pipe``. Each sentiment label is mapped to a
    quality rating, and the most frequent rating is returned. Ties go to the
    rating encountered first, so the result is deterministic.

    Args:
        text: Transcript text.

    Returns:
        One of "Very Poor", "Poor", "Neutral", "Good", "Very Good", or
        "Unknown". "Unknown" is returned when the text yields no chunks or
        the model emits a label outside the map.
    """
    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good",
    }
    chunks = split_text(text)
    if not chunks:
        # An empty transcript (e.g. silent audio) gives no chunks; picking the
        # most common of zero ratings would otherwise raise ValueError.
        return "Unknown"

    ratings = Counter()
    for chunk in chunks:
        # Character-based chunking does not bound the token count exactly
        # (special tokens, sub-word pieces), so let the tokenizer truncate.
        result = sentiment_pipe(chunk, truncation=True)[0]
        ratings[label_map.get(result["label"], "Unknown")] += 1
    # most_common orders equal counts by first occurrence.
    return ratings.most_common(1)[0][0]
# Streamlit main interface
def _render_styles():
    """Inject the custom CSS (Comic Neue font, .header and .subtitle classes)."""
    st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Comic+Neue:wght@700&display=swap');
.header {
background: linear-gradient(45deg, #FF9A6C, #FF6B6B);
border-radius: 15px;
padding: 2rem;
text-align: center;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
}
.subtitle {
font-family: 'Comic Neue', cursive;
color: #4B4B4B;
font-size: 1.2rem;
margin: 1rem 0;
padding: 1rem;
background: rgba(255,255,255,0.9);
border-radius: 10px;
border-left: 5px solid #FF6B6B;
}
</style>
""", unsafe_allow_html=True)


def _render_header():
    """Render the gradient banner with the app title and tagline."""
    st.markdown("""
<div class="header">
<h1 style='margin:0;'>🎙️ Customer Service Quality Analyzer</h1>
<p style='color: white; font-size: 1.2rem;'>Evaluate the service quality with simple-uploading!</p>
</div>
""", unsafe_allow_html=True)


def _analyze_upload(uploaded_file):
    """Play back, transcribe and rate one uploaded recording.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.
    """
    # Use the browser-reported MIME type so mp3/flac uploads are not labelled
    # as wav; fall back to wav if the browser sent none.
    st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

    # Persist the upload to a uniquely named temp file that keeps the original
    # extension. A fixed filename in the working directory would be clobbered
    # when two sessions upload at the same time.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_audio_path = tmp.name

    try:
        progress_bar = st.progress(0)
        status_container = st.empty()

        # Step 1: Audio transcription
        status_container.info("📝 **Step 1/2**: Transcribing audio...")
        transcript = transcribe_audio(temp_audio_path)
        progress_bar.progress(50)
        st.write("**Transcript:**", transcript)

        if not transcript.strip():
            # Nothing to classify; rating an empty transcript is meaningless.
            status_container.warning("No speech was detected in the uploaded audio.")
            return

        # Step 2: Sentiment Analysis
        status_container.info("🧑‍⚖️ **Step 2/2**: Evaluating sentiment quality...")
        quality_rating = rate_quality(transcript)
        progress_bar.progress(100)
        st.write("**Sentiment Rating:**", quality_rating)
        status_container.success("✅ Analysis complete!")
    finally:
        # Runs even when transcription or classification raises, so the temp
        # file is never leaked.
        os.remove(temp_audio_path)


def main():
    """Configure the page, draw the UI, and analyze an uploaded audio file."""
    # set_page_config must be the first Streamlit command of the run.
    st.set_page_config(page_title="Customer Service Quality Analyzer", page_icon="🎙️")
    _render_styles()
    _render_header()

    # Audio file uploader
    uploaded_file = st.file_uploader("👉🏻 Upload your Cantonese audio file here...", type=["wav", "mp3", "flac"])
    if uploaded_file is not None:
        _analyze_upload(uploaded_file)


if __name__ == "__main__":
    main()