# Page header captured with this file:
#   author: xujinheng666
#   commit: 6e645b6 "Create app.py" (verified)
#   size:   2.91 kB
import os
import re
import tempfile

import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
def load_models():
    """Load every model the app uses and store it in ``st.session_state``.

    Sets the following session-state entries:

    * ``transcription_pipe`` -- ASR pipeline (alvanlii/whisper-small-cantonese),
      pinned to Chinese transcription.
    * ``translation_tokenizer`` / ``translation_model`` --
      Helsinki-NLP/opus-mt-zh-en tokenizer and seq2seq model.
    * ``summary_pipe`` -- summarization pipeline (facebook/bart-large-cnn).
    * ``rating_pipe`` -- text-classification pipeline
      (cardiffnlp/twitter-roberta-base-sentiment-latest).

    ``transcription_pipe`` is assigned last on purpose. ``main`` uses its
    presence as the "models are loaded" flag. If loading fails part-way, the
    flag stays unset and the next rerun retries the whole load, instead of
    skipping it and later hitting a missing model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transcription_pipe = pipeline(
        task="automatic-speech-recognition",
        model="alvanlii/whisper-small-cantonese",
        chunk_length_s=60,
        device=device,
    )
    # Pin the decoder prompt to Chinese transcription (no translation).
    transcription_pipe.model.config.forced_decoder_ids = (
        transcription_pipe.tokenizer.get_decoder_prompt_ids(
            language="zh", task="transcribe"
        )
    )

    translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
    translation_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")

    # "summarization" is the registered pipeline task identifier;
    # "text-summarization" is not, and makes pipeline() raise.
    summary_pipe = pipeline(
        "summarization", model="facebook/bart-large-cnn", device=device
    )
    rating_pipe = pipeline(
        "text-classification",
        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
        device=device,
    )

    st.session_state.translation_tokenizer = translation_tokenizer
    st.session_state.translation_model = translation_model
    st.session_state.summary_pipe = summary_pipe
    st.session_state.rating_pipe = rating_pipe
    # Written last: this key is the "loaded" flag checked by main().
    st.session_state.transcription_pipe = transcription_pipe
def transcribe_audio(audio_path):
    """Transcribe the audio file at *audio_path* and return the text.

    Runs the ASR pipeline stored as ``st.session_state.transcription_pipe``
    and returns the ``"text"`` field of its result.
    """
    asr_result = st.session_state.transcription_pipe(audio_path)
    return asr_result["text"]
def _chunk_for_translation(text, max_chars):
    """Split *text* into stripped chunks of at most *max_chars* characters.

    Splits preferentially right after sentence-ending punctuation or newlines.
    A single sentence longer than *max_chars* is cut into fixed-size pieces.
    Blank chunks are dropped.
    """
    if not text:
        return []
    chunks = []
    current = ""
    # Zero-width split *after* each delimiter keeps the punctuation attached
    # to the sentence it ends.
    for sentence in re.split(r"(?<=[。！？!?；;\n])", text):
        for start in range(0, len(sentence), max_chars):
            piece = sentence[start:start + max_chars]
            if len(current) + len(piece) > max_chars:
                if current.strip():
                    chunks.append(current.strip())
                current = ""
            current += piece
    if current.strip():
        chunks.append(current.strip())
    return chunks


def translate_text(text, max_chunk_chars=200):
    """Translate *text* (Chinese) to English with the model in ``st.session_state``.

    The text is split into chunks of at most *max_chunk_chars* characters,
    preferring sentence boundaries. Each chunk is translated on its own and
    the results are joined with single spaces. Translating the whole
    transcript in one call overflows the model's input length for long audio.

    Args:
        text: Source text, typically the output of ``transcribe_audio``.
        max_chunk_chars: Maximum number of characters per translated chunk.

    Returns:
        The English translation, or an empty string for blank input.

    Raises:
        ValueError: if *max_chunk_chars* is less than 1.
    """
    if max_chunk_chars < 1:
        raise ValueError("max_chunk_chars must be at least 1")
    tokenizer = st.session_state.translation_tokenizer
    model = st.session_state.translation_model
    # Generating past the model's position-embedding table would fail, so
    # cap the original 1000-token limit by it when the config exposes one.
    max_length = min(1000, getattr(model.config, "max_position_embeddings", None) or 1000)

    translations = []
    for chunk in _chunk_for_translation(text, max_chunk_chars):
        # truncation=True is a last-resort guard in case a chunk still
        # tokenizes past the tokenizer's maximum input length.
        inputs = tokenizer(chunk, return_tensors="pt", truncation=True).to(model.device)
        # **inputs also passes the attention mask, not just input_ids.
        outputs = model.generate(**inputs, max_length=max_length, num_beams=5)
        translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translations)
def summarize_text(text):
    """Summarize *text* with the pipeline stored as ``st.session_state.summary_pipe``.

    Inputs longer than the model's maximum input length are truncated instead
    of failing during generation.

    Returns:
        The ``"summary_text"`` of the first result, or an empty string when
        *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        return ""
    result = st.session_state.summary_pipe(text, truncation=True)
    return result[0]["summary_text"]
def rate_quality(text):
    """Rate conversation quality from the sentiment of *text*.

    Runs the classifier stored as ``st.session_state.rating_pipe`` and maps
    its top label to "Poor", "Average" or "Good". Any unrecognized label
    yields "Unknown". Over-long inputs are truncated rather than raising.
    """
    result = st.session_state.rating_pipe(text, truncation=True)[0]
    # The configured "-latest" checkpoint reports named labels
    # (negative/neutral/positive), while older checkpoints report
    # LABEL_0..LABEL_2. Accept both forms, case-insensitively.
    label_map = {
        "label_0": "Poor",
        "negative": "Poor",
        "label_1": "Average",
        "neutral": "Average",
        "label_2": "Good",
        "positive": "Good",
    }
    return label_map.get(str(result["label"]).lower(), "Unknown")
def main():
    """Render the app: upload audio, then show transcript, translation, summary and rating."""
    st.title("Audio Processing & Conversation Quality Rating")
    # load_models() stores the models in st.session_state; this key is its
    # "loaded" flag, so loading happens once per session.
    if "transcription_pipe" not in st.session_state:
        with st.spinner("Loading models..."):
            load_models()

    uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
    if uploaded_file is None:
        return

    with st.spinner("Processing audio..."):
        # A unique temp file per run: a fixed name would be shared (and
        # clobbered) by concurrent sessions.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
        # delete=False + closing before use lets the ASR pipeline reopen the
        # file by path on every platform; cleanup happens in `finally`.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # getvalue() returns the whole upload regardless of the buffer's
            # current read position, unlike read().
            tmp.write(uploaded_file.getvalue())
            file_path = tmp.name
        try:
            transcript = transcribe_audio(file_path)
        finally:
            # Always remove the temp file, even if transcription fails.
            os.remove(file_path)
        translation = translate_text(transcript)
        summary = summarize_text(translation)
        rating = rate_quality(translation)

    st.subheader("Transcription")
    st.write(transcript)
    st.subheader("Translation (English)")
    st.write(translation)
    st.subheader("Summary")
    st.write(summary)
    st.subheader("Conversation Quality Rating")
    st.write(rating)
if __name__ == "__main__":
    # Script entry point: build and render the Streamlit UI.
    main()