# Hugging Face file-view header (scraped): uploaded by xujinheng666,
# commit 7ea48f6 "Update app.py" (verified), file size 3.32 kB.
import os
import re
import tempfile
from collections import Counter
from difflib import SequenceMatcher

import numpy as np
import streamlit as st
import torch
import torchaudio
from transformers import pipeline
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# Whisper model used for transcription, and the language it is forced to decode in
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource
def _load_pipelines():
    """Build the speech-recognition and sentiment pipelines once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    ``st.cache_resource`` both models would be reloaded on each rerun.

    Returns:
        A ``(asr_pipe, sentiment_pipe)`` tuple of transformers pipelines.
    """
    asr_pipe = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=10,
        device=device,
        generate_kwargs={
            "no_repeat_ngram_size": 3,
            "repetition_penalty": 1.15,
            # Decoding is greedy (do_sample=False), so sampling-only knobs such as
            # temperature / top_k / top_p would be ignored (and warned about);
            # they are intentionally omitted.
            "max_new_tokens": 200,
            "do_sample": False,
        },
    )
    # Pin Whisper's decoder prompt to `language` + "transcribe" (not translate).
    asr_pipe.model.config.forced_decoder_ids = asr_pipe.tokenizer.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )
    sentiment_pipe = pipeline(
        "text-classification",
        model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis",
        device=device,
    )
    return asr_pipe, sentiment_pipe


# Module-level names kept for the rest of the script.
pipe, rating_pipe = _load_pipelines()
def is_similar(a, b, threshold=0.8):
    """Return True when *a* and *b* are more than *threshold* alike.

    Similarity is difflib's ``SequenceMatcher.ratio()`` (0.0 to 1.0). The
    comparison is strict: a ratio exactly equal to *threshold* is not similar.
    """
    similarity = SequenceMatcher(None, a, b).ratio()
    return similarity > threshold
def remove_repeated_phrases(text):
    """Collapse runs of consecutive near-duplicate sentences in *text*.

    The text is split right after every sentence terminator: full-width
    ``。！？`` and ASCII ``!?``. Each piece is stripped and empty pieces are
    dropped. A piece is kept only when it is not similar (per ``is_similar``)
    to the most recently kept piece. Kept pieces are joined with single spaces.

    Args:
        text: Transcript text, still containing its sentence punctuation.

    Returns:
        The de-duplicated text.
    """
    # Full-width ！ and ？ are what Chinese/Cantonese transcripts normally use;
    # without them such sentences were never split apart.
    sentences = re.split(r'(?<=[。！？!?])', text)
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            # split() leaves an empty tail when text ends with a terminator;
            # keeping it would add a stray trailing space to the output.
            continue
        if not cleaned_sentences or not is_similar(sentence, cleaned_sentences[-1]):
            cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)
# Any character that is neither a (Unicode) word character nor whitespace.
_NON_WORD_NON_SPACE = re.compile(r'[^\w\s]')


def remove_punctuation(text):
    """Delete every character of *text* that is neither a word char nor whitespace.

    Letters, digits, CJK characters, underscores and whitespace survive.
    Punctuation and symbols, including full-width CJK punctuation, are removed.
    """
    return _NON_WORD_NON_SPACE.sub('', text)
def transcribe_audio(audio_path):
    """Transcribe the audio file at *audio_path* and return punctuation-free text.

    Multi-channel audio is averaged down to mono. Recordings up to 60 s are sent
    to the ASR pipeline in one call. Longer recordings are transcribed in 55 s
    windows advanced by 50 s (5 s overlap), and the window texts are joined with
    spaces. In both cases consecutive near-duplicate sentences are then removed,
    and finally all punctuation is stripped.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The cleaned transcript string.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:
        # Average all channels into a single mono channel.
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze(0).numpy()
    total_samples = waveform.shape[0]
    duration = total_samples / sample_rate

    if duration <= 60:
        raw_text = pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
    else:
        chunk_size = sample_rate * 55
        step_size = sample_rate * 50
        results = []
        for start in range(0, total_samples, step_size):
            chunk = waveform[start:start + chunk_size]
            results.append(pipe({"sampling_rate": sample_rate, "raw": chunk})["text"])
            if start + chunk_size >= total_samples:
                # This window already reached the end of the audio; any later
                # window would only re-transcribe samples covered here.
                break
        raw_text = " ".join(results)

    # De-duplicate while sentence punctuation is still present (the sentence
    # splitter depends on it), then strip punctuation once at the end.
    return remove_punctuation(remove_repeated_phrases(raw_text))
def rate_quality(text):
    """Rate *text* by majority vote of the sentiment model over 512-character chunks.

    Args:
        text: The transcript to rate.

    Returns:
        One of "Very Poor", "Poor", "Neutral", "Good" or "Very Good". Returns
        "Unknown" when the winning model label is not in the mapping, or when
        *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        # Nothing to classify; voting over zero results would raise ValueError.
        return "Unknown"
    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    # 512 characters can tokenize to more than the model's 512-token limit,
    # so let the tokenizer truncate instead of failing.
    results = rating_pipe(chunks, batch_size=4, truncation=True)
    label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
    processed_results = [label_map.get(res["label"], "Unknown") for res in results]
    # Counter.most_common breaks ties by first occurrence, so the result is
    # deterministic. The old max(set(...)) depended on per-process string hashing.
    return Counter(processed_results).most_common(1)[0][0]
# Streamlit UI
st.title("Audio Transcription & Quality Rating")
uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if uploaded_file is not None:
    # Play the upload with its real MIME type (it may be mp3/flac, not wav).
    st.audio(uploaded_file, format=uploaded_file.type or 'audio/wav')
    # Save to a unique temp file that keeps the original extension. This way the
    # audio decoder sees the right format, concurrent sessions don't overwrite
    # each other, and the file is removed afterwards.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
        temp_path = tmp.name
    try:
        with st.spinner("Processing audio..."):
            transcript = transcribe_audio(temp_path)
    finally:
        os.remove(temp_path)
    st.subheader("Transcript")
    st.write(transcript)
    quality_rating = rate_quality(transcript)
    st.subheader("Quality Rating")
    st.write(quality_rating)