# Leo Liu — "Update app.py" (commit 6c203d8, verified)
# Repository file-viewer residue: raw | history | blame | 5.25 kB
import hashlib
import math
import os
import tempfile
from collections import Counter

import jieba
import magic
import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Device setup: automatically selects CUDA or CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Whisper model for Cantonese audio transcription
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource(show_spinner=False)
def _load_asr_pipeline(model_name, lang, device_name):
    """Build the speech-recognition pipeline once and reuse it across reruns.

    Streamlit re-executes this script from top to bottom on every user
    interaction, so an uncached module-level ``pipeline(...)`` call reloaded
    the model each time. ``st.cache_resource`` keeps a single instance per
    server process, shared by all sessions.
    NOTE(review): that shared instance is used by concurrent sessions.

    ``show_spinner=False`` keeps this call from emitting any Streamlit element,
    because ``st.set_page_config`` in ``main`` must be the first one.

    Args:
        model_name: Hugging Face model id of the Whisper checkpoint.
        lang: Language code used to build the forced decoder prompt.
        device_name: ``"cuda"`` or ``"cpu"``.

    Returns:
        The configured automatic-speech-recognition pipeline.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=model_name,
        # Whisper consumes 30-second input windows (its feature extractor
        # pads/truncates to 30 s), so longer chunks risk losing their tail.
        chunk_length_s=30,
        device=device_name,
    )
    # Pin the decoder's language and task prompt tokens (transcribe, not translate).
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=lang, task="transcribe"
    )
    return asr


pipe = _load_asr_pipeline(MODEL_NAME, language, device)
# Transcription function (supports long audio)
def transcribe_audio(audio_path, chunk_seconds=30):
    """Transcribe an audio file, splitting long recordings into windows.

    Recordings no longer than ``chunk_seconds`` are passed straight to
    ``pipe``. Longer recordings are cut along the time axis into consecutive,
    non-overlapping windows of (nearly) equal length, each at most
    ``chunk_seconds`` long. Every window is written to its own temporary WAV
    file and transcribed, and the window transcripts are joined with spaces.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        chunk_seconds: Maximum window length in seconds (must be > 0). The
            default of 30 matches Whisper's 30-second input window.

    Returns:
        The transcript text.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    num_samples = waveform.shape[1]  # axis 1 is time (channels-first layout)
    if num_samples / sample_rate <= chunk_seconds:
        return pipe(audio_path)["text"]

    # Windows do not overlap: audio shared by two windows would be transcribed
    # twice, and the duplicate text would end up in the joined transcript.
    # Working in samples rather than whole seconds keeps the fractional tail.
    max_window = int(chunk_seconds * sample_rate)
    # Spread the samples over the minimum number of windows so the final
    # window is not a tiny fragment; each window stays <= max_window.
    n_windows = math.ceil(num_samples / max_window)
    window = math.ceil(num_samples / n_windows)

    results = []
    for start in range(0, num_samples, window):
        chunk = waveform[:, start:start + window]
        # A unique temporary path per chunk: fixed names in the working
        # directory collide between concurrent sessions. The finally clause
        # removes the file even if saving or transcription raises.
        fd, temp_filename = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            torchaudio.save(temp_filename, chunk, sample_rate)
            results.append(pipe(temp_filename)["text"])
        finally:
            os.remove(temp_filename)
    return " ".join(results)
# Load sentiment analysis model
@st.cache_resource(show_spinner=False)
def _load_sentiment_pipeline(model_name, device_name):
    """Build the text-classification pipeline once and reuse it across reruns.

    Streamlit re-executes this script on every interaction, so an uncached
    module-level ``pipeline(...)`` call reloaded the model each time.
    ``show_spinner=False`` keeps this call from emitting a Streamlit element
    before ``st.set_page_config`` runs in ``main``.

    Args:
        model_name: Hugging Face model id of the sentiment classifier.
        device_name: ``"cuda"`` or ``"cpu"``.

    Returns:
        The text-classification pipeline.
    """
    return pipeline("text-classification", model=model_name, device=device_name)


sentiment_pipe = _load_sentiment_pipeline(
    "Leo0129/CustomModel-multilingual-sentiment-analysis", device
)
# Text splitting function (using jieba for Chinese text)
def split_text(text, max_length=512, tokenize=None):
    """Split text into chunks shorter than ``max_length`` characters.

    The text is segmented into words (``jieba.cut`` by default). Words are then
    packed greedily, in order, into chunks whose character length stays below
    ``max_length``. Words are never split, so a single word of ``max_length``
    characters or more becomes a chunk of its own. No empty chunk is emitted.

    Args:
        text: Text to split.
        max_length: Exclusive upper bound on a chunk's length in characters
            (characters, not model tokens).
        tokenize: Optional callable mapping ``text`` to an iterable of word
            strings; defaults to ``jieba.cut``.

    Returns:
        List of chunks, in order; ``[]`` when there are no words.
    """
    if tokenize is None:
        tokenize = jieba.cut
    chunks = []
    parts, size = [], 0
    for word in tokenize(text):
        if not word:
            continue
        # Flush only when something is buffered; otherwise an over-long first
        # word would emit an empty chunk before itself.
        if parts and size + len(word) >= max_length:
            chunks.append("".join(parts))
            parts, size = [], 0
        parts.append(word)
        size += len(word)
    if parts:
        chunks.append("".join(parts))
    return chunks
# Function to rate sentiment quality based on most frequent result
def rate_quality(text):
    """Rate service quality as the most frequent sentiment across text chunks.

    ``text`` is split with ``split_text``. Each chunk is classified by
    ``sentiment_pipe``, and its top label is mapped to a quality rating
    ("Very Poor", "Poor", "Neutral", "Good", "Very Good", or "Unknown" for an
    unmapped label).

    Args:
        text: Transcript to rate.

    Returns:
        The most common rating, with ties going to the rating seen first.
        Returns "Unknown" when ``text`` yields no chunks (e.g. an empty
        transcript).
    """
    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good",
    }
    chunks = split_text(text)
    if not chunks:
        # Nothing to classify; a majority vote over zero results is undefined.
        return "Unknown"
    # split_text bounds chunks by characters, not model tokens. Once special
    # tokens are added, a near-limit chunk can exceed the model's maximum input
    # length, so let the tokenizer truncate instead of failing.
    ratings = Counter(
        label_map.get(sentiment_pipe(chunk, truncation=True)[0]["label"], "Unknown")
        for chunk in chunks
    )
    # most_common orders equal counts by first occurrence, so ties resolve
    # deterministically. Iterating a set of strings does not: its order can
    # change between runs under hash randomization.
    return ratings.most_common(1)[0][0]
# Streamlit main interface
def _transcribe_and_rate(audio_bytes, suffix):
    """Transcribe uploaded audio and rate its sentiment, showing progress.

    The bytes are written to a uniquely named temporary file, so concurrent
    sessions never share a path. The file is deleted even if a model call
    raises.

    Args:
        audio_bytes: Raw content of the uploaded file.
        suffix: File extension for the temporary file (e.g. ``".mp3"``), so
            the audio loader sees the real container format.

    Returns:
        ``(transcript, quality_rating)``.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_bytes)
        temp_audio_path = tmp.name
    try:
        progress_bar = st.progress(0)
        # Step 1: Audio transcription
        with st.spinner('📝 Step 1: Transcribing audio, please wait... '):
            transcript = transcribe_audio(temp_audio_path)
        progress_bar.progress(50)
        st.write("**Transcript:**", transcript)
        # Step 2: Sentiment Analysis
        with st.spinner('🧑‍⚖️ Step 2: Analyzing sentiment, please wait... '):
            quality_rating = rate_quality(transcript)
        progress_bar.progress(100)
        st.write("**Sentiment Analysis Result:**", quality_rating)
    finally:
        os.remove(temp_audio_path)
    return transcript, quality_rating


def main():
    """Render the analyzer page and analyze an uploaded customer-service recording.

    Shows the page styling and header and accepts a WAV/MP3/FLAC upload. Files
    whose libmagic MIME type is not ``audio/*`` are rejected. Otherwise the
    audio is transcribed, its sentiment is rated, and a downloadable report is
    offered.

    Results are stored in ``st.session_state`` under a SHA-256 of the file
    content. Streamlit reruns (e.g. after clicking the download button) reuse
    them instead of re-running the models.
    """
    st.set_page_config(page_title="Customer Service Quality Analyzer", page_icon="🎙️")
    # Custom CSS styling
    st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Comic+Neue:wght@700&display=swap');
.header {
background: linear-gradient(45deg, #FF9A6C, #FF6B6B);
border-radius: 15px;
padding: 2rem;
text-align: center;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
}
.subtitle {
font-family: 'Comic Neue', cursive;
color: #4B4B4B;
font-size: 1.2rem;
margin: 1rem 0;
padding: 1rem;
background: rgba(255,255,255,0.9);
border-radius: 10px;
border-left: 5px solid #FF6B6B;
}
</style>
""", unsafe_allow_html=True)
    # Header
    st.markdown("""
<div class="header">
<h1 style='margin:0;'>🎙️ Customer Service Quality Analyzer</h1>
<p style='color: white; font-size: 1.2rem;'>Evaluate the service quality with simple uploading!</p>
</div>
""", unsafe_allow_html=True)
    # Audio file uploader
    uploaded_file = st.file_uploader("📀 Please upload your Cantonese customer service audio file", type=["wav", "mp3", "flac"])
    if uploaded_file is None:
        return

    audio_bytes = uploaded_file.getvalue()
    file_type = magic.from_buffer(audio_bytes, mime=True)
    if not file_type.startswith("audio/"):
        st.error("⚠️ Sorry, the uploaded file format is not supported. Please upload an audio file.")
        return
    # Play back with the detected MIME type rather than always claiming WAV.
    st.audio(audio_bytes, format=file_type)

    # Streamlit reruns the whole script on every interaction, including the
    # download button below, and the uploader keeps its file across reruns.
    # Cache the results per file content so a rerun does not re-transcribe.
    cache_key = "analysis:" + hashlib.sha256(audio_bytes).hexdigest()
    if cache_key in st.session_state:
        transcript, quality_rating = st.session_state[cache_key]
        st.write("**Transcript:**", transcript)
        st.write("**Sentiment Analysis Result:**", quality_rating)
    else:
        # Keep the real extension so the loader does not mistake mp3/flac for WAV.
        suffix = os.path.splitext(uploaded_file.name)[1].lower() or ".wav"
        transcript, quality_rating = _transcribe_and_rate(audio_bytes, suffix)
        st.session_state[cache_key] = (transcript, quality_rating)

    # Download analysis results
    result_text = f"Transcript:\n{transcript}\n\nSentiment Analysis Result: {quality_rating}"
    st.download_button(label="📥 Download Analysis Report", data=result_text, file_name="analysis_report.txt")
    # Customer support info
    st.markdown("❓If you encounter any issues, please contact customer support: 📧 **support@hellotoby.com**")


if __name__ == "__main__":
    main()