xujinheng666 commited on
Commit
5b9cbca
·
verified ·
1 Parent(s): 80fad5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -83
app.py CHANGED
@@ -5,63 +5,69 @@ import torchaudio
5
  import os
6
  import re
7
  import jieba
 
8
 
9
  # Device setup
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- # Load Whisper model for transcription
13
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
14
  language = "zh"
15
- pipe = pipeline(task="automatic-speech-recognition", model=MODEL_NAME, chunk_length_s=60, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  def transcribe_audio(audio_path):
19
  waveform, sample_rate = torchaudio.load(audio_path)
20
  duration = waveform.shape[1] / sample_rate
21
-
22
  if duration > 60:
23
  results = []
24
- for start in range(0, int(duration), 50):
25
  end = min(start + 60, int(duration))
26
  chunk = waveform[:, start * sample_rate:end * sample_rate]
27
-
28
- if chunk.shape[1] == 0: # Skip empty chunks
29
  continue
30
-
31
  temp_filename = f"temp_chunk_{start}.wav"
32
-
33
- # Save the chunk only if it contains valid data
34
- try:
35
- torchaudio.save(temp_filename, chunk, sample_rate)
36
- if not os.path.exists(temp_filename):
37
- print(f"Error: {temp_filename} was not created!")
38
- continue
39
-
40
- result = pipe(temp_filename)["text"]
41
- results.append(result)
42
-
43
- except Exception as e:
44
- print(f"Error processing {temp_filename}: {e}")
45
-
46
- finally:
47
- if os.path.exists(temp_filename):
48
- os.remove(temp_filename) # Ensure file is deleted safely
49
- else:
50
- print(f"Warning: File {temp_filename} was not found for deletion.")
51
-
52
- return " ".join(results)
53
-
54
- return pipe(audio_path)["text"]
55
 
56
  # Load translation model
57
  tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
58
  model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh").to(device)
59
 
60
- def split_sentences(text):
61
- return [s for s in re.split(r'(?<=[。!?])', text) if s]
62
-
63
  def translate(text):
64
- sentences = split_sentences(text)
65
  translations = []
66
  for sentence in sentences:
67
  inputs = tokenizer(sentence, return_tensors="pt").to(device)
@@ -69,56 +75,24 @@ def translate(text):
69
  translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
70
  return " ".join(translations)
71
 
72
- # Load quality rating model
73
- rating_pipe = pipeline("text-classification", model="Leo0129/CustomModel_dianping-chinese")
74
-
75
- def split_text(text, max_length=512):
76
- words = list(jieba.cut(text))
77
- chunks, current_chunk = [], ""
78
-
79
- for word in words:
80
- if len(current_chunk) + len(word) < max_length:
81
- current_chunk += word
82
- else:
83
- chunks.append(current_chunk)
84
- current_chunk = word
85
-
86
- if current_chunk:
87
- chunks.append(current_chunk)
88
-
89
- return chunks
90
-
91
- def rate_quality(text):
92
- chunks = split_text(text)
93
- results = []
94
-
95
- for chunk in chunks:
96
- result = rating_pipe(chunk)[0]
97
- label_map = {"LABEL_0": "Poor", "LABEL_1": "Neutral", "LABEL_2": "Good"}
98
- results.append(label_map.get(result["label"], "Unknown"))
99
-
100
- return max(set(results), key=results.count) # Return most frequent rating
101
-
102
  # Streamlit UI
103
- st.title("Cantonese Audio Analysis")
104
- st.write("Upload a Cantonese audio file to transcribe, translate, and rate the conversation quality.")
 
105
 
106
- uploaded_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "flac"])
107
 
108
  if uploaded_file is not None:
109
- st.audio(uploaded_file, format="audio/wav")
110
- temp_audio_path = "uploaded_audio.wav"
111
- with open(temp_audio_path, "wb") as f:
112
- f.write(uploaded_file.getbuffer())
113
-
114
- st.write("### Processing...")
115
- transcript = transcribe_audio(temp_audio_path)
116
- st.write("**Transcript:**", transcript)
117
-
118
- translated_text = translate(transcript)
119
- st.write("**Translation:**", translated_text)
120
-
121
- quality_rating = rate_quality(translated_text)
122
- st.write("**Quality Rating:**", quality_rating)
123
-
124
- os.remove(temp_audio_path)
 
5
  import os
6
  import re
7
  import jieba
8
+ from difflib import SequenceMatcher
9
 
10
  # Device setup
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Load Whisper model for transcription with improved parameters
14
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
15
  language = "zh"
16
+ pipe = pipeline(
17
+ task="automatic-speech-recognition",
18
+ model=MODEL_NAME,
19
+ chunk_length_s=60,
20
+ device=device,
21
+ generate_kwargs={
22
+ "no_repeat_ngram_size": 4,
23
+ "repetition_penalty": 1.15,
24
+ "temperature": 0.5,
25
+ "top_p": 0.97,
26
+ "top_k": 40,
27
+ "max_new_tokens": 300,
28
+ "do_sample": True
29
+ }
30
+ )
31
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
32
 
33
+ def is_similar(a, b, threshold=0.8):
34
+ return SequenceMatcher(None, a, b).ratio() > threshold
35
+
36
+ def remove_repeated_phrases(text):
37
+ sentences = re.split(r'(?<=[。!?])', text)
38
+ cleaned_sentences = []
39
+ for i, sentence in enumerate(sentences):
40
+ if i == 0 or not is_similar(sentence.strip(), cleaned_sentences[-1].strip()):
41
+ cleaned_sentences.append(sentence.strip())
42
+ return " ".join(cleaned_sentences)
43
+
44
  def transcribe_audio(audio_path):
45
  waveform, sample_rate = torchaudio.load(audio_path)
46
  duration = waveform.shape[1] / sample_rate
 
47
  if duration > 60:
48
  results = []
49
+ for start in range(0, int(duration), 55):
50
  end = min(start + 60, int(duration))
51
  chunk = waveform[:, start * sample_rate:end * sample_rate]
52
+ if chunk.shape[1] == 0:
 
53
  continue
 
54
  temp_filename = f"temp_chunk_{start}.wav"
55
+ torchaudio.save(temp_filename, chunk, sample_rate)
56
+ if os.path.exists(temp_filename):
57
+ try:
58
+ result = pipe(temp_filename)["text"]
59
+ results.append(result)
60
+ finally:
61
+ os.remove(temp_filename)
62
+ return remove_repeated_phrases(" ".join(results))
63
+ return remove_repeated_phrases(pipe(audio_path)["text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # Load translation model
66
  tokenizer = AutoTokenizer.from_pretrained("botisan-ai/mt5-translate-yue-zh")
67
  model = AutoModelForSeq2SeqLM.from_pretrained("botisan-ai/mt5-translate-yue-zh").to(device)
68
 
 
 
 
69
  def translate(text):
70
+ sentences = [s for s in re.split(r'(?<=[。!?])', text) if s]
71
  translations = []
72
  for sentence in sentences:
73
  inputs = tokenizer(sentence, return_tensors="pt").to(device)
 
75
  translations.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
76
  return " ".join(translations)
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # Streamlit UI
79
+ st.set_page_config(page_title="Cantonese Speech Processing", layout="wide")
80
+ st.title("🎤 Cantonese Audio Transcription & Translation")
81
+ st.write("Upload an audio file to transcribe, translate, and analyze quality.")
82
 
83
+ uploaded_file = st.file_uploader("Upload your audio file (WAV format)", type=["wav"])
84
 
85
  if uploaded_file is not None:
86
+ with st.spinner("Processing audio..."):
87
+ audio_path = "temp_audio.wav"
88
+ with open(audio_path, "wb") as f:
89
+ f.write(uploaded_file.read())
90
+ transcript = transcribe_audio(audio_path)
91
+ st.subheader("📝 Transcription")
92
+ st.text_area("Transcript", transcript, height=150)
93
+
94
+ translated_text = translate(transcript)
95
+ st.subheader("🌍 Translation")
96
+ st.text_area("Translated Text", translated_text, height=150)
97
+
98
+ st.success("Processing complete!")