# Leo Liu
# Update app.py
# c875150 verified
# raw
# history blame
# 6.04 kB
import math
import os
import re
import tempfile
from collections import Counter

import jieba
import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# Device setup: use CUDA when it is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper model used for audio transcription (Cantonese fine-tune).
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource(show_spinner=False)
def _load_asr_pipeline(model_name, target_device, lang):
    """Build the Whisper speech-recognition pipeline once per process.

    Streamlit re-executes this script on every user interaction. Caching the
    pipeline with ``st.cache_resource`` avoids reloading the model from disk
    each time. ``show_spinner=False`` keeps the cached call from rendering an
    element before ``st.set_page_config`` runs in ``main``.

    Args:
        model_name: Hugging Face model id of the Whisper checkpoint.
        target_device: Device string passed to ``pipeline`` ("cuda" or "cpu").
        lang: Language code used to build the forced decoder prompt.

    Returns:
        The configured ASR pipeline.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=model_name,
        chunk_length_s=60,
        device=target_device,
    )
    # Force transcription (not translation) in the given language.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=lang, task="transcribe"
    )
    return asr


pipe = _load_asr_pipeline(MODEL_NAME, device, language)
def transcribe_audio(audio_path):
    """Transcribe an audio file with the module-level Whisper pipeline ``pipe``.

    Audio of at most 60 seconds is passed to the pipeline as-is. Longer audio
    is cut into consecutive, non-overlapping chunks of (nearly) equal length,
    each at most 60 seconds. Each chunk is written to a unique temporary WAV
    file and transcribed, and the chunk transcripts are joined with single
    spaces.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The transcript text.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    max_frames = 60 * sample_rate
    total_frames = waveform.shape[1]
    if total_frames <= max_frames:
        return pipe(audio_path)["text"]

    # Chunks must not overlap: the transcripts are simply concatenated, so any
    # shared audio would appear twice in the output. Equal-sized chunks cover
    # every sample (including a fractional last second) and avoid a tiny
    # trailing chunk.
    n_chunks = math.ceil(total_frames / max_frames)
    chunk_frames = math.ceil(total_frames / n_chunks)
    results = []
    for start in range(0, total_frames, chunk_frames):
        chunk = waveform[:, start:start + chunk_frames]
        # A unique temp file per chunk prevents collisions between concurrent
        # sessions. The finally clause removes it even if transcription fails.
        fd, temp_filename = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            torchaudio.save(temp_filename, chunk, sample_rate)
            results.append(pipe(temp_filename)["text"])
        finally:
            os.remove(temp_filename)
    return " ".join(results)
# Translation model (Cantonese -> written Chinese).
TRANSLATION_MODEL_NAME = "botisan-ai/mt5-translate-yue-zh"


@st.cache_resource(show_spinner=False)
def _load_translation_model(model_name, target_device):
    """Load the seq2seq translation tokenizer and model once per process.

    Cached with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    model instead of reading it from disk again.

    Args:
        model_name: Hugging Face model id of the translation checkpoint.
        target_device: Device string the model is moved to.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(target_device)
    return tok, mdl


tokenizer, model = _load_translation_model(TRANSLATION_MODEL_NAME, device)
def split_sentences(text):
    """Split *text* into sentences after sentence-ending punctuation.

    Both Chinese full-width (。！？) and ASCII (! ?) terminators are
    recognised. The terminator stays attached to its sentence. Surrounding
    whitespace is stripped, and blank pieces (such as the gap between
    space-joined chunks) are dropped.

    Args:
        text: Input text.

    Returns:
        List of non-empty sentences in their original order.
    """
    pieces = re.split(r"(?<=[。！？!?])", text)
    return [piece.strip() for piece in pieces if piece.strip()]
def translate(text):
    """Translate *text* sentence by sentence with the module-level seq2seq model.

    The text is split with ``split_sentences``. Each sentence is translated
    independently using beam search (5 beams, ``max_length=1000``).

    Args:
        text: Source text (the transcript).

    Returns:
        The translated sentences joined with single spaces. Returns an empty
        string when *text* contains no sentences.
    """
    translations = []
    for sentence in split_sentences(text):
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        # Pass the whole encoding (input_ids + attention_mask) as documented,
        # instead of letting generate() infer the mask.
        outputs = model.generate(**inputs, max_length=1000, num_beams=5)
        translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translations)
# Quality-rating model used to classify the translated conversation.
RATING_MODEL_NAME = "Leo0129/CustomModel_dianping-chinese"


@st.cache_resource(show_spinner=False)
def _load_rating_pipeline(model_name, target_device):
    """Build the text-classification pipeline once per process.

    Cached with ``st.cache_resource`` so Streamlit reruns do not reload it. The
    pipeline is placed on the same device as the other models.

    Args:
        model_name: Hugging Face model id of the classifier.
        target_device: Device string passed to ``pipeline``.

    Returns:
        The text-classification pipeline.
    """
    return pipeline("text-classification", model=model_name, device=target_device)


rating_pipe = _load_rating_pipeline(RATING_MODEL_NAME, device)
def split_text(text, max_length=512):
    """Split *text* into chunks shorter than *max_length* characters.

    Words produced by ``jieba.cut`` are packed greedily into the current chunk
    until adding the next word would reach *max_length*; that word then starts
    a new chunk. A single word that is itself *max_length* characters or
    longer becomes its own (oversized) chunk. Empty chunks are never emitted.

    Args:
        text: Text to split.
        max_length: Exclusive upper bound on chunk length, in characters.

    Returns:
        List of non-empty chunks. Returns an empty list when *text* yields no
        words.
    """
    chunks, current_chunk = [], ""
    for word in jieba.cut(text):
        # Only flush a non-empty chunk. Otherwise a long first word would
        # produce an empty "" chunk.
        if current_chunk and len(current_chunk) + len(word) >= max_length:
            chunks.append(current_chunk)
            current_chunk = word
        else:
            current_chunk += word
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
def rate_quality(text):
    """Rate the quality of *text* by majority vote over its chunks.

    The text is split with ``split_text``, and each chunk is classified by
    ``rating_pipe``. The most frequent mapped label is returned. Ties go to
    the label encountered first (``Counter.most_common`` ordering), so the
    result does not depend on string hash randomisation.

    Args:
        text: Text to rate (the translated transcript).

    Returns:
        "Poor", "Neutral" or "Good". Returns "Unknown" for an unrecognised
        model label, or when there is no text to rate.
    """
    label_map = {"LABEL_0": "Poor", "LABEL_1": "Neutral", "LABEL_2": "Good"}
    chunks = split_text(text)
    if not chunks:
        # E.g. silent audio gives an empty transcript; there is nothing to vote on.
        return "Unknown"
    results = []
    for chunk in chunks:
        # truncation=True: a chunk bounded in *characters* can still exceed the
        # model's maximum input length in *tokens* once special tokens are added
        # (NOTE(review): presumably a 512-token BERT-style limit — confirm).
        result = rating_pipe(chunk, truncation=True)[0]
        results.append(label_map.get(result["label"], "Unknown"))
    return Counter(results).most_common(1)[0][0]
def _inject_styles():
    """Inject the page's custom CSS (Comic Neue font, gradient header, subtitle box)."""
    # The markdown literal is kept unindented: 4+ leading spaces would make
    # Markdown render the HTML as a code block.
    st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Comic+Neue:wght@700&display=swap');
.header {
background: linear-gradient(45deg, #FF9A6C, #FF6B6B);
border-radius: 15px;
padding: 2rem;
text-align: center;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
}
.subtitle {
font-family: 'Comic Neue', cursive;
color: #4B4B4B;
font-size: 1.2rem;
margin: 1rem 0;
padding: 1rem;
background: rgba(255,255,255,0.9);
border-radius: 10px;
border-left: 5px solid #FF6B6B;
}
</style>
""", unsafe_allow_html=True)


def _render_header():
    """Render the gradient page header with title and tagline."""
    st.markdown("""
<div class="header">
<h1 style='margin:0;'>🎙️ Customer Service Quality Analyzer</h1>
<p style='color: white; font-size: 1.2rem;'>Evaluate the service quality with simple-uploading!</p>
</div>
""", unsafe_allow_html=True)


def _analyze_upload(uploaded_file):
    """Play the upload, then transcribe, translate and rate it, showing each result.

    The upload is written to a unique temporary file, which keeps the original
    extension, and is removed afterwards even if a step raises.
    """
    # Play the uploaded audio with its real MIME type (it may be mp3 or flac).
    st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

    # A unique temp file avoids clobbering between concurrent sessions. Keeping
    # the suffix helps decoders that may use the extension to detect the format.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Progress bar and status area for the three steps.
        progress_bar = st.progress(0)
        status_container = st.empty()

        # Step 1: transcribe the audio.
        status_container.info("📝 **Step 1/3**: Transcribing audio...")
        transcript = transcribe_audio(temp_audio_path)
        progress_bar.progress(33)
        st.write("**Transcript:**", transcript)

        # Step 2: translate the transcript.
        status_container.info("📚 **Step 2/3**: Translating transcript...")
        translated_text = translate(transcript)
        progress_bar.progress(66)
        st.write("**Translation:**", translated_text)

        # Step 3: rate the quality of the translated conversation.
        status_container.info("🧑‍⚖️ **Step 3/3**: Evaluating audio quality...")
        quality_rating = rate_quality(translated_text)
        progress_bar.progress(100)
        st.write("**Quality Rating:**", quality_rating)

        # All steps finished; drop the now-stale "Step 3/3" status message.
        status_container.empty()
    finally:
        os.remove(temp_audio_path)


def main():
    """Render the Customer Service Quality Analyzer page and process uploads."""
    # Page title and browser-tab icon; called before any other Streamlit output.
    st.set_page_config(page_title="Customer Service Quality Analyzer", page_icon="🎙️")
    _inject_styles()
    _render_header()
    # Accept wav, mp3 and flac uploads.
    uploaded_file = st.file_uploader("👉🏻 Upload your Cantonese audio file here...", type=["wav", "mp3", "flac"])
    if uploaded_file is not None:
        _analyze_upload(uploaded_file)
# Script entry point: render the app when this file is executed as the main module.
if __name__ == "__main__":
    main()