File size: 4,490 Bytes
6e645b6
 
 
 
 
0eb093e
2ffc60e
5b9cbca
6e645b6
0eb093e
 
c74678d
5b9cbca
0eb093e
 
5b9cbca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0eb093e
6e645b6
5b9cbca
 
 
 
 
 
 
 
 
 
 
2ba44e2
18d5ab3
2ba44e2
6e645b6
0eb093e
 
 
 
5b9cbca
2ffc60e
0eb093e
5b9cbca
26db416
0eb093e
5b9cbca
 
 
 
2ba44e2
5b9cbca
 
2ba44e2
 
6e645b6
0eb093e
 
 
 
 
5b9cbca
0eb093e
 
 
 
 
 
 
18d5ab3
 
 
 
 
 
 
 
 
 
 
 
0eb093e
5b9cbca
18d5ab3
5b9cbca
0eb093e
5b9cbca
0eb093e
 
5b9cbca
 
 
 
 
 
 
 
 
 
 
 
18d5ab3
 
 
 
5b9cbca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import re
import tempfile
from collections import Counter
from difflib import SequenceMatcher

import jieba
import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Whisper model for transcription with improved parameters
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"
# NOTE(review): Streamlit re-executes this whole script on every rerun, so
# this pipeline is rebuilt on each widget interaction. Consider moving model
# loading into a function decorated with `st.cache_resource` (check the
# installed Streamlit version and the set_page_config ordering rules).
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    # Whisper consumes fixed 30 s windows: its feature extractor pads or
    # truncates every chunk to 30 s, so any chunk_length_s above 30 silently
    # discards the audio after the first 30 s of each chunk.
    chunk_length_s=30,
    device=device,
    generate_kwargs={
        "no_repeat_ngram_size": 4,  # block any 4-gram from repeating
        "repetition_penalty": 1.15,
        # NOTE(review): with do_sample=True the transcript can differ between
        # runs on the same audio; greedy/beam decoding is the usual choice for
        # ASR. Kept as-is — confirm this is intended.
        "temperature": 0.5,
        "top_p": 0.97,
        "top_k": 40,
        "max_new_tokens": 300,
        "do_sample": True,
    },
)
# Pin the decoder prompt to `language` with the "transcribe" task.
# NOTE(review): newer transformers releases read these settings from the
# generation config / generate kwargs rather than model.config — verify this
# still takes effect with the installed version.
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")

def is_similar(a, b, threshold=0.8):
    """Return True if *a* and *b* are more than *threshold* alike.

    Similarity is ``difflib.SequenceMatcher(None, a, b).ratio()``, a value
    in [0, 1]; the comparison is strict (``>``).
    """
    return SequenceMatcher(None, a, b).ratio() > threshold


def remove_repeated_phrases(text):
    """Drop sentences that nearly repeat the previously kept sentence.

    *text* is split after every sentence terminator — full-width 。！？ as
    well as half-width !? — and each piece is stripped of surrounding
    whitespace. Empty pieces (such as the one after a trailing terminator)
    are discarded. A piece is skipped when ``is_similar`` judges it a
    near-duplicate of the last sentence kept, so only consecutive repeats
    are collapsed.

    Args:
        text: Text to clean.

    Returns:
        The kept sentences joined with single spaces ("" for empty input).
    """
    kept = []
    # Zero-width split after each terminator keeps the punctuation attached
    # to its sentence.
    for piece in re.split(r"(?<=[。！？!?])", text):
        sentence = piece.strip()
        if not sentence:
            continue
        if kept and is_similar(sentence, kept[-1]):
            continue  # near-duplicate of the sentence just before it
        kept.append(sentence)
    return " ".join(kept)

def remove_punctuation(text):
    r"""Return *text* with every non-word, non-whitespace character removed.

    For ``str`` patterns ``\w`` is Unicode-aware, so CJK characters, letters,
    digits and underscores survive, while full-width (，。！？) and
    half-width punctuation and other symbols are deleted.
    """
    non_word = re.compile(r'[^\w\s]')
    return non_word.sub('', text)

def transcribe_audio(audio_path, *, chunk_s=60, step_s=55):
    """Transcribe the audio file at *audio_path* and return cleaned text.

    Clips no longer than *chunk_s* seconds go to ``pipe`` directly. Longer
    clips are cut into windows of *chunk_s* seconds starting every *step_s*
    seconds (so consecutive windows overlap by ``chunk_s - step_s``). Each
    window is written to a private temporary WAV file and transcribed, and
    windowing stops once a window reaches the end of the clip.

    Repeated sentences are collapsed with ``remove_repeated_phrases``
    *before* punctuation is stripped, because that function splits on
    sentence-ending punctuation.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        chunk_s: Window length in seconds for long clips.
        step_s: Seconds between window starts; must be in ``(0, chunk_s]``
            so that no audio falls between windows.

    Returns:
        The transcript with punctuation removed.

    Raises:
        ValueError: If *step_s* is not in ``(0, chunk_s]``.
    """
    if not 0 < step_s <= chunk_s:
        raise ValueError("step_s must be in (0, chunk_s]")

    waveform, sample_rate = torchaudio.load(audio_path)
    total = waveform.shape[1]  # number of samples per channel
    window = int(chunk_s * sample_rate)
    if total <= window:
        return remove_punctuation(remove_repeated_phrases(pipe(audio_path)["text"]))

    step = int(step_s * sample_rate)
    texts = []
    # A private temp directory avoids file-name clashes between concurrent
    # sessions and is removed together with its contents afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        chunk_path = os.path.join(tmp_dir, "chunk.wav")
        # Slice in samples (not whole seconds) so the final partial second
        # of audio is not dropped.
        for start in range(0, total, step):
            end = min(start + window, total)
            torchaudio.save(chunk_path, waveform[:, start:end], sample_rate)
            # Keep punctuation for now; it is needed for sentence splitting.
            texts.append(pipe(chunk_path)["text"])
            if end == total:
                break  # this window already reaches the end of the clip
    return remove_punctuation(remove_repeated_phrases(" ".join(texts)))

# Load translation model: tokenizer and seq2seq model come from the same
# checkpoint (presumably Cantonese -> Chinese, per its name), and the model is
# placed on `device`.
_TRANSLATION_CHECKPOINT = "botisan-ai/mt5-translate-yue-zh"
tokenizer = AutoTokenizer.from_pretrained(_TRANSLATION_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_TRANSLATION_CHECKPOINT)
model = model.to(device)

def translate(text, *, num_beams=5, max_length=1000):
    """Translate *text* sentence by sentence with the module-level mT5 model.

    *text* is split after each sentence terminator (full-width 。！？ or
    half-width !?). Each non-blank, stripped piece is translated on its own
    with beam search, and the results are joined with single spaces.

    NOTE(review): if the caller has already stripped punctuation, the whole
    input becomes one segment; very long inputs may then exceed what the
    model handles well — consider length-based splitting.

    Args:
        text: Source text.
        num_beams: Beam width passed to ``model.generate``.
        max_length: Maximum generated length per segment.

    Returns:
        The space-joined translations ("" when *text* has no content).
    """
    sentences = [s.strip() for s in re.split(r"(?<=[。！？!?])", text)]
    translations = []
    for sentence in sentences:
        if not sentence:
            continue  # skip blank pieces instead of translating whitespace
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        # Pass the attention mask along with the ids rather than the ids alone.
        outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
        translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translations)

# Load quality rating model: a text-classification pipeline whose labels are
# mapped to ratings in rate_quality() (presumably fine-tuned on Chinese
# Dianping reviews, per its name — confirm).
rating_pipe = pipeline(
    task="text-classification",
    model="Leo0129/CustomModel_dianping-chinese",
)

def rate_quality(text, *, chunk_size=512):
    """Rate *text* as "Poor", "Neutral" or "Good" by majority vote.

    The text is cut into *chunk_size*-character pieces and each piece is
    classified by ``rating_pipe``. The most frequent rating wins; ties go to
    the rating seen first, so the result is deterministic.

    Args:
        text: Text to rate.
        chunk_size: Number of characters per classifier call.

    Returns:
        "Poor", "Neutral", "Good", or "Unknown" (for a label outside the
        mapping, or for empty/whitespace-only *text*).
    """
    # NOTE(review): LABEL_0/1/2 -> Poor/Neutral/Good is taken from the original
    # code; confirm against the model's id2label config.
    label_map = {"LABEL_0": "Poor", "LABEL_1": "Neutral", "LABEL_2": "Good"}
    if not text.strip():
        return "Unknown"  # nothing to classify (and max() of no votes would raise)

    votes = Counter()
    for i in range(0, len(text), chunk_size):
        # truncation=True: a chunk of chunk_size characters may still exceed
        # the model's token limit once special tokens are added.
        result = rating_pipe(text[i:i + chunk_size], truncation=True)[0]
        votes[label_map.get(result["label"], "Unknown")] += 1
    # most_common() orders equal counts by first occurrence; the previous
    # max(set(...)) tie-break depended on hash-randomized set order.
    return votes.most_common(1)[0][0]

# Streamlit UI
st.set_page_config(page_title="Cantonese Speech Processing", layout="wide")
st.title("🎤 Cantonese Audio Transcription, Translation & Quality Rating")
st.write("Upload an audio file to transcribe, translate, and analyze quality.")

uploaded_file = st.file_uploader("Upload your audio file (WAV format)", type=["wav"])

# NOTE(review): Streamlit reruns this script on every widget interaction
# (including edits to the text areas below), which repeats the whole
# pipeline; consider caching results, e.g. in st.session_state.
if uploaded_file is not None:
    with st.spinner("Processing audio..."):
        # Write the upload to a uniquely named temp file: a fixed name in the
        # working directory would be shared by concurrent sessions and was
        # never cleaned up.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            # getvalue() returns the full upload regardless of read position.
            tmp.write(uploaded_file.getvalue())
            audio_path = tmp.name
        try:
            transcript = transcribe_audio(audio_path)
        finally:
            os.remove(audio_path)
        st.subheader("📝 Transcription")
        st.text_area("Transcript", transcript, height=150)

        translated_text = translate(transcript)
        st.subheader("🌍 Translation")
        st.text_area("Translated Text", translated_text, height=150)

        quality_rating = rate_quality(translated_text)
        st.subheader("⭐ Quality Rating")
        st.write(f"**Rating:** {quality_rating}")

        st.success("Processing complete!")