File size: 3,441 Bytes
65628c8
6e645b6
 
65628c8
5a44a9a
 
 
3521f10
6e645b6
0eb093e
 
c74678d
5a44a9a
0eb093e
 
5a44a9a
3521f10
 
65628c8
 
 
92ed395
 
 
 
 
65628c8
5b9cbca
5a44a9a
5b9cbca
9afd3be
3d5104e
65628c8
 
 
9afd3be
65628c8
 
 
 
 
 
 
5b9cbca
5a44a9a
 
2ba44e2
6e645b6
65628c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d5ab3
 
65628c8
 
 
 
 
 
 
9afd3be
5a44a9a
65628c8
 
 
5a44a9a
 
65628c8
 
 
 
 
b20dd94
65628c8
 
 
 
3521f10
65628c8
 
 
3521f10
65628c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import re
import tempfile
from collections import Counter
from difflib import SequenceMatcher

import numpy as np
import streamlit as st
import torch
import torchaudio
from transformers import pipeline

# Device setup: run both models on CUDA when it is available, else on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper checkpoint used for transcription, and the language code used to
# build its forced decoder prompt.
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource
def _load_asr_pipeline():
    """Build the speech-recognition pipeline once per Streamlit server process.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` returns the already-loaded pipeline on
    those reruns instead of reloading the model weights each time.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=30,
        device=device,
        generate_kwargs={
            # n-gram blocking and a repetition penalty discourage repeated
            # phrases; temperature/top_p/top_k are sampling controls.
            "no_repeat_ngram_size": 3,
            "repetition_penalty": 1.3,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 50,
        },
    )
    # Force the decoder prompt to the configured language and plain transcription.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )
    return asr


@st.cache_resource
def _load_rating_pipeline():
    """Build the text-classification pipeline used to rate transcripts (cached across reruns)."""
    return pipeline(
        "text-classification",
        model="Leo0129/CustomModel-multilingual-sentiment-analysis",
        device=device,
    )


# Load Whisper model for transcription
pipe = _load_asr_pipeline()

# Load quality rating model
rating_pipe = _load_rating_pipeline()

def is_similar(a, b, threshold=0.8):
    """Return True when strings *a* and *b* match more closely than *threshold*.

    Closeness is ``difflib.SequenceMatcher.ratio()`` (0.0 = nothing in common,
    1.0 = identical). The comparison is strict: a ratio exactly equal to
    *threshold* counts as not similar.
    """
    matcher = SequenceMatcher(None, a, b)
    score = matcher.ratio()
    return score > threshold

def remove_repeated_phrases(text):
    """Drop sentences that nearly duplicate the sentence kept just before them.

    *text* is split after each sentence terminator (full-width ``。！？`` or
    ASCII ``!?``), with the terminator kept on its sentence. A sentence is
    discarded when ``is_similar`` reports it as a near-copy of the previously
    kept sentence, which removes repeated output from the transcriber.
    Empty fragments, such as the one ``re.split`` yields after a trailing
    terminator, are skipped.

    Returns:
        The kept sentences, each stripped, joined with single spaces.
    """
    kept = []
    for fragment in re.split(r"(?<=[。！？!?])", text):
        sentence = fragment.strip()
        if not sentence:
            # Empty pieces would otherwise add stray spaces to the output.
            continue
        if kept and is_similar(sentence, kept[-1]):
            continue
        kept.append(sentence)
    return " ".join(kept)

def remove_punctuation(text):
    r"""Strip every character that is neither a word character nor whitespace.

    For ``str`` input ``\w`` is Unicode-aware, so CJK characters, letters,
    digits and underscores are kept. ASCII and full-width punctuation
    (e.g. ``，。！？``) is removed, and whitespace is preserved.
    """
    punctuation = re.compile(r'[^\w\s]')
    return punctuation.sub('', text)

def transcribe_audio(audio_path):
    """Transcribe the audio file at *audio_path* and return cleaned text.

    Multi-channel audio is averaged down to mono before transcription.

    Recordings longer than 60 s are transcribed in 55 s windows that start
    every 50 s. Text near a window boundary can therefore be transcribed
    twice. ``remove_repeated_phrases`` drops sentences that nearly duplicate
    the preceding one.

    Returns:
        The transcript as a single string with punctuation removed.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Average channels to mono; the pipeline is given a 1-D array.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    samples = waveform.squeeze(0).numpy()
    total = samples.shape[0]
    duration = total / sample_rate

    if duration <= 60:
        raw_text = pipe({"sampling_rate": sample_rate, "raw": samples})["text"]
        return remove_punctuation(remove_repeated_phrases(raw_text))

    chunk_size = sample_rate * 55
    step_size = sample_rate * 50
    transcripts = []
    for start in range(0, total, step_size):
        chunk = samples[start:start + chunk_size]
        transcripts.append(pipe({"sampling_rate": sample_rate, "raw": chunk})["text"])
        # Stop once a window reaches the end of the recording: the next start
        # would lie entirely inside this window's overlap and only re-transcribe it.
        if start + chunk_size >= total:
            break

    # Keep punctuation until after de-duplication. remove_repeated_phrases
    # splits on sentence terminators, so stripping them first would leave one
    # giant "sentence" and disable cross-chunk de-duplication.
    return remove_punctuation(remove_repeated_phrases(" ".join(transcripts)))

def rate_quality(text):
    """Rate *text* with the classification model and return the majority label.

    How the label is computed:
    - The text is split into 512-character pieces.
    - Each piece is classified by ``rating_pipe``.
    - The model labels are mapped to quality labels ("Very Poor" .. "Very Good").
      Labels outside the map become "Unknown".

    Returns:
        The most frequent mapped label; ties go to the label that occurs
        first. Empty or whitespace-only text returns "Unknown" without
        calling the model.
    """
    if not text or not text.strip():
        return "Unknown"

    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    # 512 characters can still exceed 512 tokens once special tokens are added;
    # truncation keeps every piece within the model's input limit.
    results = rating_pipe(chunks, batch_size=4, truncation=True)

    label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
    processed_results = [label_map.get(res["label"], "Unknown") for res in results]

    # most_common orders equal counts by first occurrence, so the winner is
    # deterministic across runs (set iteration order is not).
    return Counter(processed_results).most_common(1)[0][0]

# Streamlit UI
st.title("Audio Transcription and Quality Rating")

uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if uploaded_file is not None:
    # Play back using the upload's own MIME type instead of assuming WAV.
    st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

    # Write the upload to a unique temporary file (a fixed name would be shared
    # by concurrent sessions). Keep the original extension so an MP3 or FLAC
    # upload is not saved under a .wav name.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getvalue())
        temp_audio_path = tmp.name

    try:
        st.write("Processing audio...")
        transcript = transcribe_audio(temp_audio_path)
    finally:
        # Remove the temporary file even when transcription raises.
        os.remove(temp_audio_path)

    st.subheader("Transcript")
    st.write(transcript)

    st.subheader("Quality Rating")
    if transcript.strip():
        st.write(rate_quality(transcript))
    else:
        # An empty transcript has nothing for the rating model to classify.
        st.info("The transcript is empty, so it cannot be rated.")