"""Streamlit app: transcribe uploaded audio, translate it, and rate the result.

Pipeline: speech recognition -> text translation -> text classification.
All models are loaded once per Streamlit session and kept in
``st.session_state``.
"""

import os
import tempfile

import streamlit as st
import torch
import torchaudio  # imported by the original file; kept for compatibility
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

ASR_MODEL_NAME = "alvanlii/whisper-small-cantonese"
TRANSLATION_MODEL_NAME = "botisan-ai/mt5-translate-yue-zh"
RATING_MODEL_NAME = "uer/roberta-base-finetuned-jd-binary-chinese"


def load_models():
    """Load every model and store it in ``st.session_state``.

    Populates the keys ``transcription_pipe``, ``translation_tokenizer``,
    ``translation_model`` and ``rating_pipe``.

    Everything is built into local variables first and published only after
    all loads succeeded. ``main`` treats ``transcription_pipe`` as the
    "models are loaded" marker, so it is assigned last: a failure while
    loading a later model no longer leaves a half-initialised session that
    is never retried.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transcription_pipe = pipeline(
        task="automatic-speech-recognition",
        model=ASR_MODEL_NAME,
        chunk_length_s=60,
        device=device,
    )
    # Pin the decoder prompt to language="zh", task="transcribe".
    transcription_pipe.model.config.forced_decoder_ids = (
        transcription_pipe.tokenizer.get_decoder_prompt_ids(
            language="zh", task="transcribe"
        )
    )

    translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_NAME)
    translation_model = AutoModelForSeq2SeqLM.from_pretrained(
        TRANSLATION_MODEL_NAME
    )
    rating_pipe = pipeline("text-classification", model=RATING_MODEL_NAME)

    st.session_state.translation_tokenizer = translation_tokenizer
    st.session_state.translation_model = translation_model
    st.session_state.rating_pipe = rating_pipe
    # Published last on purpose -- see the docstring.
    st.session_state.transcription_pipe = transcription_pipe


def transcribe_audio(audio_path):
    """Return the text recognised in the audio file at ``audio_path``.

    Uses the speech-recognition pipeline stored in ``st.session_state``.
    """
    result = st.session_state.transcription_pipe(audio_path)
    return result["text"]


def translate_text(text):
    """Translate ``text`` with the session's seq2seq model.

    Returns the decoded first generated sequence, or an empty string when
    ``text`` is empty or whitespace-only (there is nothing to translate).
    """
    if not text or not text.strip():
        return ""

    tokenizer = st.session_state.translation_tokenizer
    model = st.session_state.translation_model

    inputs = tokenizer(text, return_tensors="pt")
    # Pass the attention mask explicitly instead of letting generate() infer it.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=1000,
        num_beams=5,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def rate_quality(text):
    """Map the classifier's label for ``text`` to a human-readable rating.

    Returns ``"Poor"``, ``"Average"`` or ``"Good"``, or ``"Unknown"`` when the
    label is not in the map or ``text`` is empty.
    """
    if not text or not text.strip():
        return "Unknown"

    # truncation=True keeps long translations within the classifier's
    # maximum input length instead of failing inside the model.
    result = st.session_state.rating_pipe(text, truncation=True)[0]

    # Keep only the part before any "(" -- presumably the model's labels
    # carry a parenthesised suffix; verify against the model card.
    label = result["label"].split("(")[0].strip().lower()
    label_map = {"negative": "Poor", "neutral": "Average", "positive": "Good"}
    return label_map.get(label, "Unknown")


def main():
    """Render the page: load models, accept an upload, show the results."""
    st.title("Audio Processing & Conversation Quality Rating")

    if "transcription_pipe" not in st.session_state:
        with st.spinner("Loading models..."):
            load_models()

    uploaded_file = st.file_uploader(
        "Upload an audio file", type=["wav", "mp3", "m4a"]
    )
    if uploaded_file is None:
        return

    with st.spinner("Processing audio..."):
        # A unique temp file per run: a fixed name in the working directory
        # would be overwritten by concurrent sessions. The upload's own
        # extension is kept so mp3/m4a data is not saved under a .wav name.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # getvalue() returns the full payload regardless of read position.
            tmp.write(uploaded_file.getvalue())
            file_path = tmp.name

        try:
            transcript = transcribe_audio(file_path)
            translation = translate_text(transcript)
            rating = rate_quality(translation)
        finally:
            # Always clean up, including when a model call raises.
            os.remove(file_path)

    st.subheader("Transcription")
    st.write(transcript)
    st.subheader("Translation (Chinese)")
    st.write(translation)
    st.subheader("Conversation Quality Rating")
    st.write(rating)


if __name__ == "__main__":
    main()