xujinheng666 commited on
Commit
3521f10
Β·
verified Β·
1 Parent(s): c1fa9f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -89
app.py CHANGED
@@ -1,117 +1,61 @@
1
- import streamlit as st
2
  import torch
3
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torchaudio
5
  import os
6
  import re
7
- import jieba
8
  from difflib import SequenceMatcher
 
9
 
10
  # Device setup
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Load Whisper model for transcription with improved parameters
14
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
15
  language = "zh"
16
  pipe = pipeline(
17
- task="automatic-speech-recognition",
18
- model=MODEL_NAME,
19
- chunk_length_s=60,
20
- device=device,
21
- generate_kwargs={
22
- "no_repeat_ngram_size": 4,
23
- "repetition_penalty": 1.15,
24
- "temperature": 0.5,
25
- "top_p": 0.97,
26
- "top_k": 40,
27
- "max_new_tokens": 300,
28
- "do_sample": True
29
- }
30
  )
31
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
32
 
33
- def is_similar(a, b, threshold=0.8):
34
- return SequenceMatcher(None, a, b).ratio() > threshold
35
 
36
- def remove_repeated_phrases(text):
37
- sentences = re.split(r'(?<=[γ€‚οΌοΌŸ])', text)
38
- cleaned_sentences = []
39
- for i, sentence in enumerate(sentences):
40
- if i == 0 or not is_similar(sentence.strip(), cleaned_sentences[-1].strip()):
41
- cleaned_sentences.append(sentence.strip())
42
- return " ".join(cleaned_sentences)
43
 
44
  def remove_punctuation(text):
45
  return re.sub(r'[^\w\s]', '', text)
46
 
47
  def transcribe_audio(audio_path):
48
- waveform, sample_rate = torchaudio.load(audio_path)
49
- duration = waveform.shape[1] / sample_rate
50
- if duration > 60:
51
- results = []
52
- for start in range(0, int(duration), 55):
53
- end = min(start + 60, int(duration))
54
- chunk = waveform[:, start * sample_rate:end * sample_rate]
55
- if chunk.shape[1] == 0:
56
- continue
57
- temp_filename = f"temp_chunk_{start}.wav"
58
- torchaudio.save(temp_filename, chunk, sample_rate)
59
- if os.path.exists(temp_filename):
60
- try:
61
- result = pipe(temp_filename)["text"]
62
- results.append(remove_punctuation(result))
63
- finally:
64
- os.remove(temp_filename)
65
- return remove_punctuation(remove_repeated_phrases(" ".join(results)))
66
- return remove_punctuation(remove_repeated_phrases(pipe(audio_path)["text"]))
67
-
68
- # Load translation model
69
- tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
70
- model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh").to(device)
71
-
72
- def translate(text):
73
- sentences = [s for s in re.split(r'(?<=[γ€‚οΌοΌŸ])', text) if s]
74
- translations = []
75
- for sentence in sentences:
76
- inputs = tokenizer(sentence, return_tensors="pt").to(device)
77
- outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=5) # Increased max_length to 2000
78
- translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
79
- return " ".join(translations)
80
-
81
- # Load quality rating model
82
- rating_pipe = pipeline("text-classification", model="Leo0129/CustomModel_dianping-chinese")
83
 
84
  def rate_quality(text):
85
- chunks = [text[i:i+512] for i in range(0, len(text), 512)]
86
- results = []
87
- for chunk in chunks:
88
- result = rating_pipe(chunk)[0]
89
- label_map = {"LABEL_0": "Poor", "LABEL_1": "Neutral", "LABEL_2": "Good"}
90
- results.append(label_map.get(result["label"], "Unknown"))
91
- return max(set(results), key=results.count)
92
 
93
  # Streamlit UI
94
- st.set_page_config(page_title="Cantonese Speech Processing", layout="wide")
95
- st.title("🎀 Cantonese Audio Transcription, Translation & Quality Rating")
96
- st.write("Upload an audio file to transcribe, translate, and analyze quality.")
97
-
98
- uploaded_file = st.file_uploader("Upload your audio file (WAV format)", type=["wav"])
99
 
 
100
  if uploaded_file is not None:
101
  with st.spinner("Processing audio..."):
102
- audio_path = "temp_audio.wav"
103
- with open(audio_path, "wb") as f:
104
- f.write(uploaded_file.read())
105
- transcript = transcribe_audio(audio_path)
106
- st.subheader("πŸ“ Transcription")
107
- st.text_area("Transcript", transcript, height=150)
108
-
109
- translated_text = translate(transcript)
110
- st.subheader("🌍 Translation")
111
- st.text_area("Translated Text", translated_text, height=150)
112
-
113
- quality_rating = rate_quality(translated_text)
114
- st.subheader("⭐ Quality Rating")
115
- st.write(f"**Rating:** {quality_rating}")
116
-
117
- st.success("Processing complete!")
 
 
1
  import torch
 
2
  import torchaudio
3
  import os
4
  import re
5
+ import streamlit as st
6
  from difflib import SequenceMatcher
7
+ from transformers import pipeline
8
 
9
  # Device setup
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # Load Whisper model for transcription
13
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
14
  language = "zh"
15
  pipe = pipeline(
16
+ task="automatic-speech-recognition",
17
+ model=MODEL_NAME,
18
+ chunk_length_s=60,
19
+ device=device
 
 
 
 
 
 
 
 
 
20
  )
21
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
22
 
23
+ # Load quality rating model
24
+ rating_pipe = pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis")
25
 
26
+ # Sentiment label mapping
27
+ label_map = {"Negative": "Very Poor", "Neutral": "Neutral", "Positive": "Very Good"}
 
 
 
 
 
28
 
29
  def remove_punctuation(text):
30
  return re.sub(r'[^\w\s]', '', text)
31
 
32
  def transcribe_audio(audio_path):
33
+ transcript = pipe(audio_path)["text"]
34
+ return remove_punctuation(transcript)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def rate_quality(text):
37
+ result = rating_pipe(text)[0]
38
+ return label_map.get(result["label"], "Unknown")
 
 
 
 
 
39
 
40
  # Streamlit UI
41
+ st.set_page_config(page_title="Cantonese Audio Transcription & Analysis", layout="centered")
42
+ st.title("πŸ—£οΈ Cantonese Audio Transcriber & Sentiment Analyzer")
43
+ st.markdown("Upload your Cantonese audio file, and we will transcribe and analyze its sentiment.")
 
 
44
 
45
+ uploaded_file = st.file_uploader("Upload an audio file (WAV, MP3, etc.)", type=["wav", "mp3", "m4a"])
46
  if uploaded_file is not None:
47
  with st.spinner("Processing audio..."):
48
+ temp_audio_path = "temp_audio.wav"
49
+ with open(temp_audio_path, "wb") as f:
50
+ f.write(uploaded_file.getbuffer())
51
+ transcript = transcribe_audio(temp_audio_path)
52
+ sentiment = rate_quality(transcript)
53
+ os.remove(temp_audio_path)
54
+
55
+ st.subheader("Transcription")
56
+ st.text_area("", transcript, height=150)
57
+
58
+ st.subheader("Sentiment Analysis")
59
+ st.markdown(f"### 🎭 Sentiment: **{sentiment}**")
60
+
61
+ st.success("Processing complete! πŸŽ‰")