Luigi committed
Commit 9e56b98 · 1 Parent(s): 4cb4a31

Revert "allow cpu spk diarazation"

This reverts commit 4cb4a31ed709eb1408bfa0bf05b28f493f0a24c6.

Files changed (1):
  1. app.py +96 -40
app.py CHANGED

```diff
@@ -51,13 +51,14 @@ WHISPER_LANGUAGES = [
 
 SENSEVOICE_LANGUAGES = ["auto", "zh", "yue", "en", "ja", "ko", "nospeech"]
 
-# —————— Caches & Converter ——————
+# —————— Caches ——————
 whisper_pipes = {}
 sense_models = {}
 dar_pipe = None
+
 converter = opencc.OpenCC('s2t')
 
-# —————— Helper Functions ——————
+# —————— Helpers ——————
 def get_whisper_pipe(model_id: str, device: int):
     key = (model_id, device)
     if key not in whisper_pipes:
```
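For context: the hunk only touches the section headers around the cache dicts, and the body of `get_whisper_pipe` lies outside the diff. A minimal sketch of the lazy per-`(model_id, device)` cache this pattern implies, assuming the elided body builds a `transformers` ASR pipeline:

```python
# Sketch only: the real body of get_whisper_pipe is not shown in this diff,
# so the pipeline() construction below is an assumption.
from transformers import pipeline

whisper_pipes = {}

def get_whisper_pipe(model_id: str, device: int):
    key = (model_id, device)
    if key not in whisper_pipes:
        # transformers convention: device=-1 is CPU, device=0 is the first CUDA GPU
        whisper_pipes[key] = pipeline(
            "automatic-speech-recognition", model=model_id, device=device
        )
    return whisper_pipes[key]
```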
```diff
@@ -102,14 +103,13 @@ def get_diarization_pipe():
         )
     return dar_pipe
 
+
 # —————— Whisper Transcription ——————
-def _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar, diar_device):
+def _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar):
     pipe = get_whisper_pipe(model_id, -1)
+    # Diarization-only branch
     if enable_diar:
         diarizer = get_diarization_pipe()
-        # Move diarization pipeline to correct device
-        dev = torch.device('cuda') if diar_device == 'GPU' and torch.cuda.is_available() else torch.device('cpu')
-        diarizer.to(dev)
         diary = diarizer(audio_path)
         snippets = []
         for turn, _, speaker in diary.itertracks(yield_label=True):
```
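This is the heart of the revert: the removed lines picked a torch device from the `diar_device` radio and moved the pyannote pipeline there before running it. A self-contained sketch of that reverted pattern (`run_diarization` is a hypothetical wrapper name, not from app.py):

```python
import torch
from pyannote.audio import Pipeline

def run_diarization(diarizer: Pipeline, audio_path: str, diar_device: str):
    # Reverted behavior: honor the "GPU"/"CPU" radio, falling back to CPU
    # when CUDA is absent. pyannote pipelines support .to(torch.device).
    dev = torch.device("cuda") if diar_device == "GPU" and torch.cuda.is_available() else torch.device("cpu")
    diarizer.to(dev)
    return diarizer(audio_path)
```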
```diff
@@ -123,17 +123,18 @@ def _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar, diar_device):
             text = converter.convert(out.get("text", "").strip())
             snippets.append(f"[{speaker}] {text}")
         return "", "\n".join(snippets)
+    # Raw-only branch
     result = pipe(audio_path) if language == "auto" else pipe(audio_path, generate_kwargs={"language": language})
     transcript = converter.convert(result.get("text", "").strip())
     return transcript, ""
 
 
-def _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar, diar_device):
+@spaces.GPU(duration=100)
+def _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar):
     pipe = get_whisper_pipe(model_id, 0)
+    # Diarization-only branch
     if enable_diar:
         diarizer = get_diarization_pipe()
-        dev = torch.device('cuda') if diar_device == 'GPU' and torch.cuda.is_available() else torch.device('cpu')
-        diarizer.to(dev)
         diary = diarizer(audio_path)
         snippets = []
         for turn, _, speaker in diary.itertracks(yield_label=True):
```
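The restored `@spaces.GPU(duration=100)` decorator is the Hugging Face ZeroGPU mechanism: the Space holds a GPU only while the decorated function runs, and `duration` extends the allocation window to roughly 100 seconds per call. A minimal sketch (`gpu_task` and its body are illustrative, not from app.py):

```python
import spaces
import torch

@spaces.GPU(duration=100)  # request a ZeroGPU slot for the length of the call
def gpu_task(x: torch.Tensor) -> torch.Tensor:
    # CUDA is only guaranteed to be visible inside the decorated function
    return (x.to("cuda") * 2).cpu()
```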
```diff
@@ -147,23 +148,28 @@ def _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar, diar_device):
             text = converter.convert(out.get("text", "").strip())
             snippets.append(f"[{speaker}] {text}")
         return "", "\n".join(snippets)
+    # Raw-only branch
     result = pipe(audio_path) if language == "auto" else pipe(audio_path, generate_kwargs={"language": language})
     transcript = converter.convert(result.get("text", "").strip())
     return transcript, ""
 
 
-def transcribe_whisper(model_id, language, audio_path, device_sel, diar_device, enable_diar):
+def transcribe_whisper(model_id, language, audio_path, device_sel, enable_diar):
     if device_sel == "GPU" and torch.cuda.is_available():
-        return _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar, diar_device)
-    return _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar, diar_device)
+        return _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar)
+    return _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar)
+
 
 # —————— SenseVoice Transcription ——————
-def _transcribe_sense_cpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
+def _transcribe_sense_cpu(model_id: str,
+                          language: str,
+                          audio_path: str,
+                          enable_punct: bool,
+                          enable_diar: bool):
     model = get_sense_model(model_id, "cpu")
+    # Diarization-only branch
    if enable_diar:
         diarizer = get_diarization_pipe()
-        dev = torch.device('cuda') if diar_device == 'GPU' and torch.cuda.is_available() else torch.device('cpu')
-        diarizer.to(dev)
         diary = diarizer(audio_path)
         snippets = []
         for turn, _, speaker in diary.itertracks(yield_label=True):
@@ -172,8 +178,15 @@ def _transcribe_sense_cpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
             segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                 segment.export(tmp.name, format="wav")
-                segs = model.generate(input=tmp.name, cache={}, language=language, use_itn=True,
-                                      batch_size_s=300, merge_vad=False, merge_length_s=0)
+                segs = model.generate(
+                    input=tmp.name,
+                    cache={},
+                    language=language,
+                    use_itn=True,
+                    batch_size_s=300,
+                    merge_vad=False,
+                    merge_length_s=0,
+                )
             os.unlink(tmp.name)
             txt = rich_transcription_postprocess(segs[0]['text'])
             if not enable_punct:
```
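Both SenseVoice branches transcribe one diarized turn at a time by slicing the input with pydub and round-tripping through a temporary WAV. A sketch of that step; the `start_ms`/`end_ms` computation is elided from the hunk, so deriving it from the pyannote turn boundaries is an assumption:

```python
import tempfile
from pydub import AudioSegment

def export_turn(audio_path: str, turn) -> str:
    # Assumed: pyannote turn boundaries are in seconds; pydub slices by milliseconds.
    start_ms, end_ms = int(turn.start * 1000), int(turn.end * 1000)
    segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        segment.export(tmp.name, format="wav")
    return tmp.name  # caller feeds this to model.generate(), then os.unlink()s it
```

Exporting with `delete=False` and unlinking after `generate()` is what lets the model read the file by path after the `with` block closes the handle.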
```diff
@@ -181,8 +194,16 @@ def _transcribe_sense_cpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
             txt = converter.convert(txt)
             snippets.append(f"[{speaker}] {txt}")
         return "", "\n".join(snippets)
-    segs = model.generate(input=audio_path, cache={}, language=language, use_itn=True,
-                          batch_size_s=300, merge_vad=True, merge_length_s=15)
+    # Raw-only branch
+    segs = model.generate(
+        input=audio_path,
+        cache={},
+        language=language,
+        use_itn=True,
+        batch_size_s=300,
+        merge_vad=True,
+        merge_length_s=15,
+    )
     text = rich_transcription_postprocess(segs[0]['text'])
     if not enable_punct:
         text = re.sub(r"[^\w\s]", "", text)
@@ -190,12 +211,16 @@ def _transcribe_sense_cpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
     return text, ""
 
 
-def _transcribe_sense_gpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
+@spaces.GPU(duration=100)
+def _transcribe_sense_gpu(model_id: str,
+                          language: str,
+                          audio_path: str,
+                          enable_punct: bool,
+                          enable_diar: bool):
     model = get_sense_model(model_id, "cuda:0")
+    # Diarization-only branch
     if enable_diar:
         diarizer = get_diarization_pipe()
-        dev = torch.device('cuda') if diar_device == 'GPU' and torch.cuda.is_available() else torch.device('cpu')
-        diarizer.to(dev)
         diary = diarizer(audio_path)
         snippets = []
         for turn, _, speaker in diary.itertracks(yield_label=True):
@@ -204,8 +229,15 @@ def _transcribe_sense_gpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
             segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                 segment.export(tmp.name, format="wav")
-                segs = model.generate(input=tmp.name, cache={}, language=language, use_itn=True,
-                                      batch_size_s=300, merge_vad=False, merge_length_s=0)
+                segs = model.generate(
+                    input=tmp.name,
+                    cache={},
+                    language=language,
+                    use_itn=True,
+                    batch_size_s=300,
+                    merge_vad=False,
+                    merge_length_s=0,
+                )
             os.unlink(tmp.name)
             txt = rich_transcription_postprocess(segs[0]['text'])
             if not enable_punct:
@@ -213,8 +245,16 @@ def _transcribe_sense_gpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
             txt = converter.convert(txt)
             snippets.append(f"[{speaker}] {txt}")
         return "", "\n".join(snippets)
-    segs = model.generate(input=audio_path, cache={}, language=language, use_itn=True,
-                          batch_size_s=300, merge_vad=True, merge_length_s=15)
+    # Raw-only branch
+    segs = model.generate(
+        input=audio_path,
+        cache={},
+        language=language,
+        use_itn=True,
+        batch_size_s=300,
+        merge_vad=True,
+        merge_length_s=15,
+    )
     text = rich_transcription_postprocess(segs[0]['text'])
     if not enable_punct:
         text = re.sub(r"[^\w\s]", "", text)
```
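The reflowed `model.generate(...)` calls are standard FunASR SenseVoice usage: for whole-file input, `merge_vad=True` coalesces short VAD segments up to `merge_length_s` seconds, while the per-turn path disables merging. A standalone sketch, assuming `get_sense_model` wraps `funasr.AutoModel` (its body is outside this diff) and using an assumed model id:

```python
import re
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

model = AutoModel(model="FunAudioLLM/SenseVoiceSmall", device="cpu")  # model id assumed
segs = model.generate(input="news.mp3", cache={}, language="auto", use_itn=True,
                      batch_size_s=300, merge_vad=True, merge_length_s=15)
text = rich_transcription_postprocess(segs[0]["text"])  # strips SenseVoice <|...|> tags
text = re.sub(r"[^\w\s]", "", text)  # the enable_punct=False path in app.py
```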
```diff
@@ -222,51 +262,67 @@ def _transcribe_sense_gpu(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, diar_device: str):
     return text, ""
 
 
-def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool, device_sel: str, diar_device: str):
+def transcribe_sense(model_id: str,
+                     language: str,
+                     audio_path: str,
+                     enable_punct: bool,
+                     enable_diar: bool,
+                     device_sel: str):
     if device_sel == "GPU" and torch.cuda.is_available():
-        return _transcribe_sense_gpu(model_id, language, audio_path, enable_punct, enable_diar, diar_device)
-    return _transcribe_sense_cpu(model_id, language, audio_path, enable_punct, enable_diar, diar_device)
+        return _transcribe_sense_gpu(model_id, language, audio_path, enable_punct, enable_diar)
+    return _transcribe_sense_cpu(model_id, language, audio_path, enable_punct, enable_diar)
+
 
 # —————— Gradio UI ——————
 Demo = gr.Blocks()
 with Demo:
-    gr.Markdown("## Whisper vs. SenseVoice (Language, ASR & Diarization Devices)")
+    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Diarization with Simplified→Traditional Chinese)")
+
     audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
+
+    # Examples
     examples = gr.Examples(
-        examples=[["interview.mp3"],["news.mp3"]],
-        inputs=[audio_input], label="Example Audio Files"
+        examples=[
+            ["interview.mp3"],
+            ["news.mp3"]
+        ],
+        inputs=[audio_input],
+        label="Example Audio Files"
     )
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("### Whisper ASR")
             whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
             whisper_lang = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto", label="Whisper Language")
-            asr_device = gr.Radio(choices=["GPU","CPU"], value="GPU", label="ASR Device")
-            diar_device = gr.Radio(choices=["GPU","CPU"], value="CPU", label="Diarization Device")
+            device_radio = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
             diar_check = gr.Checkbox(label="Enable Diarization", value=True)
             out_w = gr.Textbox(label="Transcript", visible=False)
             out_w_d = gr.Textbox(label="Diarized Transcript", visible=True)
+            # Toggle visibility based on checkbox
             diar_check.change(lambda e: gr.update(visible=not e), inputs=diar_check, outputs=out_w)
             diar_check.change(lambda e: gr.update(visible=e), inputs=diar_check, outputs=out_w_d)
             btn_w = gr.Button("Transcribe with Whisper")
             btn_w.click(fn=transcribe_whisper,
-                        inputs=[whisper_dd,whisper_lang,audio_input,asr_device,diar_device,diar_check],
-                        outputs=[out_w,out_w_d])
+                        inputs=[whisper_dd, whisper_lang, audio_input, device_radio, diar_check],
+                        outputs=[out_w, out_w_d])
+
         with gr.Column():
             gr.Markdown("### FunASR SenseVoice ASR")
             sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
             sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
-            asr_device_s = gr.Radio(choices=["GPU","CPU"], value="GPU", label="ASR Device")
-            diar_device_s = gr.Radio(choices=["GPU","CPU"], value="CPU", label="Diarization Device")
+            device_radio_sense = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
             punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
             diar_s_chk = gr.Checkbox(label="Enable Diarization", value=True)
             out_s = gr.Textbox(label="Transcript", visible=False)
             out_s_d = gr.Textbox(label="Diarized Transcript", visible=True)
+            # Toggle visibility
             diar_s_chk.change(lambda e: gr.update(visible=not e), inputs=diar_s_chk, outputs=out_s)
             diar_s_chk.change(lambda e: gr.update(visible=e), inputs=diar_s_chk, outputs=out_s_d)
             btn_s = gr.Button("Transcribe with SenseVoice")
             btn_s.click(fn=transcribe_sense,
-                        inputs=[sense_dd,sense_lang,audio_input,punct_chk,diar_s_chk,asr_device_s,diar_device_s],
-                        outputs=[out_s,out_s_d])
+                        inputs=[sense_dd, sense_lang, audio_input, punct_chk, diar_s_chk, device_radio_sense],
+                        outputs=[out_s, out_s_d])
+
 if __name__ == "__main__":
     Demo.launch()
```
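The paired `.change()` hooks on each checkbox are how the UI swaps between the plain and diarized transcript boxes: each hook returns a `gr.update(visible=...)` patch for exactly one output. A runnable reduction of that wiring:

```python
import gradio as gr

with gr.Blocks() as demo:
    diar = gr.Checkbox(label="Enable Diarization", value=True)
    plain = gr.Textbox(label="Transcript", visible=False)
    diarized = gr.Textbox(label="Diarized Transcript", visible=True)
    # One event per output: exactly one of the two boxes is visible at a time.
    diar.change(lambda e: gr.update(visible=not e), inputs=diar, outputs=plain)
    diar.change(lambda e: gr.update(visible=e), inputs=diar, outputs=diarized)

demo.launch()
```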
 