|
import streamlit as st |
|
import torch |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torchaudio |
|
import os |
|
|
|
@st.cache_resource(show_spinner=False)
def _load_model_bundle():
    """Load the ASR, translation and rating models once per server process.

    ``st.cache_resource`` shares the returned objects across all sessions and
    reruns, so opening or refreshing the page does not reload the weights.

    Returns:
        A tuple ``(transcription_pipe, translation_tokenizer,
        translation_model, rating_pipe)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transcription_pipe = pipeline(
        task="automatic-speech-recognition",
        model="alvanlii/whisper-small-cantonese",
        chunk_length_s=60,
        device=device,
    )
    # Pin Whisper's decoder prompt to Chinese transcription instead of letting
    # it auto-detect the language / task.
    transcription_pipe.model.config.forced_decoder_ids = (
        transcription_pipe.tokenizer.get_decoder_prompt_ids(language="zh", task="transcribe")
    )

    translation_tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
    # Kept on the default device (CPU) so CPU tensors produced by the
    # tokenizer can be fed to it directly.
    translation_model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh")

    # Same device choice as the ASR pipeline; the pipeline moves its own
    # inputs to that device.
    rating_pipe = pipeline(
        "text-classification",
        model="jackietung/bert-base-chinese-finetuned-sentiment",
        device=device,
    )

    return transcription_pipe, translation_tokenizer, translation_model, rating_pipe


def load_models():
    """Expose the shared models through ``st.session_state``.

    Sets ``transcription_pipe``, ``translation_tokenizer``,
    ``translation_model`` and ``rating_pipe`` on the session state. The
    objects themselves come from a process-wide cache, so only the first
    call in the process actually downloads/loads weights.
    """
    (
        st.session_state.transcription_pipe,
        st.session_state.translation_tokenizer,
        st.session_state.translation_model,
        st.session_state.rating_pipe,
    ) = _load_model_bundle()
|
|
|
def transcribe_audio(audio_path):
    """Transcribe the audio file at ``audio_path`` with the session's ASR pipeline.

    Returns:
        The ``"text"`` field of the pipeline output.
    """
    asr_output = st.session_state.transcription_pipe(audio_path)
    transcript = asr_output["text"]
    return transcript
|
|
|
def translate_text(text):
    """Translate ``text`` with the session's mT5 translation model.

    The model is ``botisan-ai/mt5-translate-yue-zh`` (Cantonese -> Chinese,
    per its name). Decoding uses beam search (5 beams, max length 1000).

    Args:
        text: Source text, typically the ASR transcript.

    Returns:
        The decoded translation with special tokens stripped, or ``""`` when
        ``text`` is empty, ``None`` or whitespace-only (the model is not run).
    """
    # Nothing to translate: skip an expensive beam search over an empty prompt.
    if not text or not text.strip():
        return ""

    tokenizer = st.session_state.translation_tokenizer
    model = st.session_state.translation_model

    # Put the inputs on whatever device the model lives on, and forward the
    # attention mask together with the ids (``**inputs``).
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=1000, num_beams=5)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def rate_quality(text): |
|
result = st.session_state.rating_pipe(text)[0] |
|
label = result["label"] |
|
label_map = {"負面": "Poor", "中性": "Average", "正面": "Good"} |
|
return label_map.get(label, "Unknown") |
|
|
|
def main():
    """Render the page: upload audio, then show transcript, translation and rating."""
    import tempfile  # only needed here, for the per-run upload file

    st.title("Audio Processing & Conversation Quality Rating")

    # Load models the first time this session renders.
    if "transcription_pipe" not in st.session_state:
        with st.spinner("Loading models..."):
            load_models()

    uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
    if uploaded_file is None:
        return

    with st.spinner("Processing audio..."):
        # A unique file per run: a fixed name would be shared (and clobbered)
        # by concurrent sessions. Keep the upload's own extension.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
        fd, file_path = tempfile.mkstemp(suffix=suffix)
        try:
            with os.fdopen(fd, "wb") as f:
                # getvalue() returns the whole upload regardless of the
                # stream's current read position.
                f.write(uploaded_file.getvalue())
            transcript = transcribe_audio(file_path)
        finally:
            # Always clean up, even if writing or transcription raises.
            os.remove(file_path)

        translation = translate_text(transcript)
        rating = rate_quality(translation)

    st.subheader("Transcription")
    st.write(transcript)

    st.subheader("Translation (Chinese)")
    st.write(translation)

    st.subheader("Conversation Quality Rating")
    st.write(rating)


# Script entry point; `streamlit run <this file>` executes the module as __main__.
if __name__ == "__main__":
    main()
|
|