import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torchaudio
import os
import re
import jieba

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Whisper model for transcription
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"
pipe = pipeline(task="automatic-speech-recognition", model=MODEL_NAME, chunk_length_s=60, device=device)
# Pin the decoder prompt to Chinese transcription so Whisper neither
# auto-detects the language nor falls back to its translation task.
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
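
# Transcribe an audio file. Clips longer than 60 s are split into 60 s
# windows on a 50 s stride (10 s of overlap), so speech at a chunk
# boundary is less likely to be cut mid-word.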
def transcribe_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sample_rate
    if duration > 60:
        results = []
        for start in range(0, int(duration), 50):
            end = min(start + 60, int(duration))
            chunk = waveform[:, start * sample_rate:end * sample_rate]
            temp_filename = f"temp_chunk_{start}.wav"
            torchaudio.save(temp_filename, chunk, sample_rate)
            results.append(pipe(temp_filename)["text"])
            os.remove(temp_filename)
        return " ".join(results)
    return pipe(audio_path)["text"]

# Load translation model (Cantonese -> written Chinese)
tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh").to(device)
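
# Split after Chinese sentence-final punctuation (。!?) so each
# sentence is translated on its own.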
def split_sentences(text):
    return [s for s in re.split(r'(?<=[。!?])', text) if s]
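
# Translate sentence by sentence with beam search, then rejoin.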
def translate(text):
    sentences = split_sentences(text)
    translations = []
    for sentence in sentences:
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        outputs = model.generate(inputs["input_ids"], max_length=1000, num_beams=5)
        translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translations)

# Load quality rating model
rating_pipe = pipeline("text-classification", model="Leo0129/CustomModel_dianping-chinese")
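
# Chunk long text on jieba word boundaries, capping each chunk at
# max_length characters as a rough stand-in for the classifier's
# 512-token input limit.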
def split_text(text, max_length=512):
    words = list(jieba.cut(text))
    chunks, current_chunk = [], ""
    for word in words:
        if len(current_chunk) + len(word) < max_length:
            current_chunk += word
        else:
            chunks.append(current_chunk)
            current_chunk = word
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
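
# Classify each chunk, map the raw labels to readable ratings, and
# return the rating that occurs most often across chunks.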
def rate_quality(text):
    label_map = {"LABEL_0": "Poor", "LABEL_1": "Neutral", "LABEL_2": "Good"}
    results = []
    for chunk in split_text(text):
        result = rating_pipe(chunk)[0]
        results.append(label_map.get(result["label"], "Unknown"))
    return max(set(results), key=results.count)  # Most frequent rating wins

# Streamlit UI
st.title("Cantonese Audio Analysis")
st.write("Upload a Cantonese audio file to transcribe, translate, and rate the conversation quality.")
uploaded_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "flac"])
if uploaded_file is not None:
    st.audio(uploaded_file, format="audio/wav")
    # Persist the upload to disk so torchaudio and the ASR pipeline can
    # read it by path.
    temp_audio_path = "uploaded_audio.wav"
    with open(temp_audio_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    st.write("### Processing...")
    transcript = transcribe_audio(temp_audio_path)
    st.write("**Transcript:**", transcript)

    translated_text = translate(transcript)
    st.write("**Translation:**", translated_text)

    quality_rating = rate_quality(translated_text)
    st.write("**Quality Rating:**", quality_rating)

    os.remove(temp_audio_path)
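
# Run locally (assuming the dependencies above are installed):
#   streamlit run app.py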