import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torchaudio
import os
import re
from difflib import SequenceMatcher
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the Cantonese Whisper model with decoding parameters tuned to curb repetition
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=60,
    device=device,
    generate_kwargs={
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.15,
        "temperature": 0.5,
        "top_p": 0.97,
        "top_k": 40,
        "max_new_tokens": 300,
        "do_sample": True,
    },
)
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
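# Note: forced_decoder_ids is the older API for pinning Whisper to a language
# and task. On recent transformers releases the same effect should be possible
# by adding "language": "zh" and "task": "transcribe" to generate_kwargs above
# (untested here; left as a note rather than a change).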
def is_similar(a, b, threshold=0.8):
    """Return True when two strings are near-duplicates (SequenceMatcher ratio)."""
    return SequenceMatcher(None, a, b).ratio() > threshold

def remove_repeated_phrases(text):
    """Drop sentences that are near-duplicates of the previously kept one."""
    sentences = re.split(r'(?<=[γ€‚οΌοΌŸ])', text)
    cleaned_sentences = []
    for i, sentence in enumerate(sentences):
        if i == 0 or not is_similar(sentence.strip(), cleaned_sentences[-1].strip()):
            cleaned_sentences.append(sentence.strip())
    return " ".join(cleaned_sentences)

def remove_punctuation(text):
    """Strip punctuation, keeping word characters and whitespace."""
    return re.sub(r'[^\w\s]', '', text)

def transcribe_audio(audio_path):
    """Transcribe a WAV file, chunking audio longer than 60 s with 5 s overlap."""
    waveform, sample_rate = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sample_rate
    if duration > 60:
        results = []
        # Step by 55 s with 60 s windows so consecutive chunks overlap by 5 s.
        for start in range(0, int(duration), 55):
            end = min(start + 60, int(duration))
            chunk = waveform[:, start * sample_rate:end * sample_rate]
            if chunk.shape[1] == 0:
                continue
            temp_filename = f"temp_chunk_{start}.wav"
            torchaudio.save(temp_filename, chunk, sample_rate)
            try:
                results.append(pipe(temp_filename)["text"])
            finally:
                os.remove(temp_filename)
        # Deduplicate near-identical sentences first, while the punctuation
        # needed for sentence splitting is still present, then strip it.
        return remove_punctuation(remove_repeated_phrases(" ".join(results)))
    return remove_punctuation(remove_repeated_phrases(pipe(audio_path)["text"]))
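# The ASR pipeline resamples file inputs itself, but if audio were ever fed to
# it as a raw array, Whisper expects 16 kHz. A minimal helper sketch for that
# case (hypothetical; not called by this app):
def ensure_16khz(waveform, sample_rate, target_rate=16000):
    # Resample with torchaudio when the source rate differs from 16 kHz.
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        return resampler(waveform), target_rate
    return waveform, sample_rate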
# Load translation model
tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh").to(device)
def translate(text):
    """Translate Cantonese text to Mandarin, sentence by sentence."""
    sentences = [s for s in re.split(r'(?<=[γ€‚οΌοΌŸ])', text) if s]
    translations = []
    for sentence in sentences:
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        outputs = model.generate(inputs["input_ids"], max_length=1000, num_beams=5)
        translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translations)
# Load quality rating model
rating_pipe = pipeline("text-classification", model="Leo0129/CustomModel_dianping-chinese", device=device)
def rate_quality(text):
    """Rate text quality by majority vote over 512-character chunks."""
    label_map = {"LABEL_0": "Poor", "LABEL_1": "Neutral", "LABEL_2": "Good"}
    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    results = []
    for chunk in chunks:
        result = rating_pipe(chunk)[0]
        results.append(label_map.get(result["label"], "Unknown"))
    return max(set(results), key=results.count)
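# Note: max(set(results), key=results.count) picks an arbitrary winner when
# two labels tie; collections.Counter(results).most_common(1) would make the
# tie-breaking order explicit if determinism matters.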
# Streamlit UI
st.set_page_config(page_title="Cantonese Speech Processing", layout="wide")
st.title("🎀 Cantonese Audio Transcription, Translation & Quality Rating")
st.write("Upload an audio file to transcribe, translate, and analyze quality.")
uploaded_file = st.file_uploader("Upload your audio file (WAV format)", type=["wav"])
if uploaded_file is not None:
    with st.spinner("Processing audio..."):
        audio_path = "temp_audio.wav"
        with open(audio_path, "wb") as f:
            f.write(uploaded_file.read())
        transcript = transcribe_audio(audio_path)
        st.subheader("πŸ“ Transcription")
        st.text_area("Transcript", transcript, height=150)
        translated_text = translate(transcript)
        st.subheader("🌍 Translation")
        st.text_area("Translated Text", translated_text, height=150)
        quality_rating = rate_quality(translated_text)
        st.subheader("⭐ Quality Rating")
        st.write(f"**Rating:** {quality_rating}")
    st.success("Processing complete!")
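# Run locally with: streamlit run app.py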