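"""Streamlit app: transcribe Cantonese audio with a fine-tuned Whisper model,
then translate the Cantonese transcript into written Chinese with mT5.

Run with: streamlit run app.py
"""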
import os
import re
from difflib import SequenceMatcher

import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Whisper model for transcription; generation parameters are tuned to curb repetition
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=60,
    device=device,
    generate_kwargs={
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.15,
        "temperature": 0.5,
        "top_p": 0.97,
        "top_k": 40,
        "max_new_tokens": 300,
        "do_sample": True,
    },
)
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
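
# Quick sanity check (hypothetical file path): the ASR pipeline takes a path to
# an audio file and returns a dict with a "text" key:
#   pipe("sample.wav")["text"]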

def is_similar(a, b, threshold=0.8):
    """Return True when two strings are near-duplicates (similarity ratio > threshold)."""
    return SequenceMatcher(None, a, b).ratio() > threshold

def remove_repeated_phrases(text):
    """Drop consecutive near-duplicate sentences left behind by overlapping chunks."""
    sentences = [s.strip() for s in re.split(r'(?<=[γ€‚οΌοΌŸ])', text) if s.strip()]
    cleaned_sentences = []
    for sentence in sentences:
        if not cleaned_sentences or not is_similar(sentence, cleaned_sentences[-1]):
            cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)

def transcribe_audio(audio_path):
    """Transcribe a WAV file, splitting long audio into overlapping chunks."""
    waveform, sample_rate = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sample_rate
    if duration > 60:
        results = []
        # 60 s windows with a 55 s stride: the 5 s overlap avoids cutting words
        # at chunk boundaries; remove_repeated_phrases() drops the duplicated text.
        for start in range(0, int(duration), 55):
            end = min(start + 60, int(duration))
            chunk = waveform[:, start * sample_rate:end * sample_rate]
            if chunk.shape[1] == 0:
                continue
            temp_filename = f"temp_chunk_{start}.wav"
            torchaudio.save(temp_filename, chunk, sample_rate)
            try:
                results.append(pipe(temp_filename)["text"])
            finally:
                os.remove(temp_filename)
        return remove_repeated_phrases(" ".join(results))
    return remove_repeated_phrases(pipe(audio_path)["text"])
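# Note: writing each chunk to a temp WAV keeps things simple; a sketch of an
# in-memory alternative (assuming the pipeline's documented raw-input format):
#   mono = chunk.mean(dim=0).numpy()  # downmix to a mono float array
#   text = pipe({"raw": mono, "sampling_rate": sample_rate})["text"]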

# Load the Cantonese-to-Mandarin translation model
tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh").to(device)

def translate(text):
    """Translate Cantonese text into written Chinese, one sentence at a time."""
    sentences = [s for s in re.split(r'(?<=[γ€‚οΌοΌŸ])', text) if s]
    translations = []
    for sentence in sentences:
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        outputs = model.generate(inputs["input_ids"], max_length=1000, num_beams=5)
        translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translations)

# Streamlit UI
st.set_page_config(page_title="Cantonese Speech Processing", layout="wide")
st.title("🎀 Cantonese Audio Transcription & Translation")
st.write("Upload a WAV file to transcribe Cantonese speech and translate it into written Chinese.")

uploaded_file = st.file_uploader("Upload your audio file (WAV format)", type=["wav"])

if uploaded_file is not None:
    with st.spinner("Processing audio..."):
        # Persist the upload so torchaudio can load it from a path.
        audio_path = "temp_audio.wav"
        with open(audio_path, "wb") as f:
            f.write(uploaded_file.read())

        transcript = transcribe_audio(audio_path)
        st.subheader("πŸ“ Transcription")
        st.text_area("Transcript", transcript, height=150)

        translated_text = translate(transcript)
        st.subheader("🌍 Translation")
        st.text_area("Translated Text", translated_text, height=150)

    st.success("Processing complete!")