File size: 2,615 Bytes
6e645b6 22b7073 6e645b6 c42d4ba 6e645b6 c42d4ba 3f91072 6e645b6 22b7073 6e645b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
import tempfile

import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
@st.cache_resource(show_spinner=False)
def _load_shared_models():
    """Load all models once per server process and share them across sessions.

    ``st.cache_resource`` keeps the returned objects alive for the whole
    process, so a new browser session does not re-download / re-initialise
    the (large) Whisper, mT5 and BERT weights.

    Returns:
        dict mapping the ``st.session_state`` key names used by this app to
        the loaded pipeline / tokenizer / model objects.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transcription_pipe = pipeline(
        task="automatic-speech-recognition",
        model="alvanlii/whisper-small-cantonese",
        chunk_length_s=60,
        device=device,
    )
    # Pin decoding to Chinese transcription instead of letting Whisper pick
    # the language/task itself.
    transcription_pipe.model.config.forced_decoder_ids = (
        transcription_pipe.tokenizer.get_decoder_prompt_ids(
            language="zh", task="transcribe"
        )
    )

    translation_tokenizer = AutoTokenizer.from_pretrained(
        "botisan-ai/mt5-translate-yue-zh"
    )
    translation_model = AutoModelForSeq2SeqLM.from_pretrained(
        "botisan-ai/mt5-translate-yue-zh"
    )

    rating_pipe = pipeline(
        "text-classification",
        model="jackietung/bert-base-chinese-finetuned-sentiment",
    )

    return {
        "transcription_pipe": transcription_pipe,
        "translation_tokenizer": translation_tokenizer,
        "translation_model": translation_model,
        "rating_pipe": rating_pipe,
    }


def load_models():
    """Expose the shared models through ``st.session_state``.

    Sets ``transcription_pipe``, ``translation_tokenizer``,
    ``translation_model`` and ``rating_pipe`` on the current session. The
    underlying objects come from a process-wide cache, so only the first
    call in the process actually loads weights.
    """
    for key, value in _load_shared_models().items():
        st.session_state[key] = value
def transcribe_audio(audio_path):
    """Run the session's speech-recognition pipeline on a file.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        The ``"text"`` entry of the ASR pipeline's result.
    """
    asr_result = st.session_state.transcription_pipe(audio_path)
    transcript = asr_result["text"]
    return transcript
def translate_text(text):
    """Translate *text* with the session's seq2seq translation model.

    The model id loaded for this app ("mt5-translate-yue-zh") suggests a
    Cantonese -> Chinese direction; confirm against the model card.

    Args:
        text: Source text, typically the ASR transcript.

    Returns:
        The first beam-search output, decoded with special tokens removed.
    """
    tokenizer = st.session_state.translation_tokenizer
    model = st.session_state.translation_model
    # Place the encoded batch on the model's device so this keeps working if
    # the model is moved to a GPU. Unpacking the whole encoding also passes
    # the attention mask to generate() instead of dropping it.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=1000, num_beams=5)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def rate_quality(text):
    """Classify *text* with the session's sentiment pipeline.

    Args:
        text: Text to rate (the translated transcript).

    Returns:
        "Poor", "Average" or "Good" for the labels "負面", "中性" and "正面".
        Any other label yields "Unknown". This assumes the model emits those
        Chinese label strings (TODO: confirm against the model config).
    """
    # BERT-style encoders have a fixed maximum sequence length; long
    # transcripts would otherwise make the forward pass fail, so truncate.
    result = st.session_state.rating_pipe(text, truncation=True)[0]
    label = result["label"]
    label_map = {"負面": "Poor", "中性": "Average", "正面": "Good"}
    return label_map.get(label, "Unknown")
def main():
    """Render the app: load models, accept an audio upload, show results.

    The uploaded audio is written to a per-run temporary file. That file is
    transcribed, the transcript is translated, and the translation is rated.
    The temporary file is always deleted afterwards.
    """
    st.title("Audio Processing & Conversation Quality Rating")

    # Populate the session's model handles on the first run of a session.
    if "transcription_pipe" not in st.session_state:
        with st.spinner("Loading models..."):
            load_models()

    uploaded_file = st.file_uploader(
        "Upload an audio file", type=["wav", "mp3", "m4a"]
    )
    if uploaded_file is None:
        return

    with st.spinner("Processing audio..."):
        # A unique temp file per run, instead of a fixed shared
        # "temp_audio.wav", so concurrent sessions cannot overwrite each
        # other. Keep the upload's own extension rather than mislabeling
        # mp3/m4a files as .wav.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # getvalue() returns the full upload regardless of the buffer's
            # current read position, unlike read().
            tmp.write(uploaded_file.getvalue())
            file_path = tmp.name
        try:
            transcript = transcribe_audio(file_path)
            translation = translate_text(transcript)
            rating = rate_quality(translation)
        finally:
            # Clean up even if one of the model calls raises.
            os.remove(file_path)

    st.subheader("Transcription")
    st.write(transcript)
    st.subheader("Translation (Chinese)")
    st.write(translation)
    st.subheader("Conversation Quality Rating")
    st.write(rating)
# Run the app only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()
|