Luigi committed on
Commit 38f97a7 · 1 Parent(s): 95897a7

add speaker diarization to SenseVoice transcript

Files changed (1): app.py (+58 −16)
app.py CHANGED
@@ -55,8 +55,7 @@ SENSEVOICE_MODELS = [
 
 # —————— Language Options ——————
 WHISPER_LANGUAGES = [
-    "auto",
-    "af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca",
+    "auto", "af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca",
     "cs","cy","da","de","el","en","es","et","eu","fa","fi","fo","fr",
     "gl","gu","ha","haw","he","hi","hr","ht","hu","hy","id","is","it",
     "ja","jw","ka","kk","km","kn","ko","la","lb","ln","lo","lt","lv",
@@ -111,7 +110,7 @@ def get_diarization_pipe():
 
 # —————— Transcription Functions ——————
 def transcribe_whisper(model_id: str, language: str, audio_path: str, device_sel: str, enable_diar: bool):
-    # select device
+    # select device for Whisper
     use_gpu = (device_sel == "GPU" and torch.cuda.is_available())
     device = 0 if use_gpu else -1
     pipe = get_whisper_pipe(model_id, device)
@@ -122,7 +121,7 @@ def transcribe_whisper(model_id: str, language: str, audio_path: str, device_sel
     result = pipe(audio_path, generate_kwargs={"language": language})
     transcript = result.get("text", "").strip()
     diar_text = ""
-    # optional diarization
+    # optional diarization for Whisper
     if enable_diar:
         diarizer = get_diarization_pipe()
         diarization = diarizer(audio_path)
@@ -144,9 +143,50 @@ def transcribe_whisper(model_id: str, language: str, audio_path: str, device_sel
     return transcript, diar_text
 
 @spaces.GPU
-def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct: bool):
+def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool):
     model = get_sense_model(model_id)
-    segments = model.generate(
+    # if no diarization, full file
+    if not enable_diar:
+        segments = model.generate(
+            input=audio_path,
+            cache={},
+            language=language,
+            use_itn=True,
+            batch_size_s=300,
+            merge_vad=True,
+            merge_length_s=15,
+        )
+        text = rich_transcription_postprocess(segments[0]['text'])
+        if not enable_punct:
+            text = re.sub(r"[^\w\s]", "", text)
+        return text, ""
+    # with diarization: split by speaker
+    diarizer = get_diarization_pipe()
+    diarization = diarizer(audio_path)
+    speaker_snippets = []
+    for turn, _, speaker in diarization.itertracks(yield_label=True):
+        start_ms = int(turn.start * 1000)
+        end_ms = int(turn.end * 1000)
+        segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+            segment.export(tmp.name, format="wav")
+        segments = model.generate(
+            input=tmp.name,
+            cache={},
+            language=language,
+            use_itn=True,
+            batch_size_s=300,
+            merge_vad=False,
+            merge_length_s=0,
+        )
+        os.unlink(tmp.name)
+        txt = rich_transcription_postprocess(segments[0]['text'])
+        if not enable_punct:
+            txt = re.sub(r"[^\w\s]", "", txt)
+        speaker_snippets.append(f"[{speaker}] {txt}")
+    full_text = "\n".join(speaker_snippets)
+    # also return full non-diarized transcript for comparison
+    segments_full = model.generate(
         input=audio_path,
         cache={},
         language=language,
@@ -155,15 +195,15 @@ def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct
         merge_vad=True,
         merge_length_s=15,
     )
-    text = rich_transcription_postprocess(segments[0]['text'])
+    text_full = rich_transcription_postprocess(segments_full[0]['text'])
     if not enable_punct:
-        text = re.sub(r"[^\w\s]", "", text)
-    return text
+        text_full = re.sub(r"[^\w\s]", "", text_full)
+    return text_full, full_text
 
 # —————— Gradio UI ——————
 demo = gr.Blocks()
 with demo:
-    gr.Markdown("## Whisper vs. SenseVoice Transcription (with Language, Device & Diarization)")
+    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Speaker Diarization)")
 
     audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
 
@@ -176,12 +216,12 @@ with demo:
             device_radio = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
             diar_check = gr.Checkbox(label="Enable Speaker Diarization", value=False)
             whisper_btn = gr.Button("Transcribe with Whisper")
-            out_whisper = gr.Textbox(label="Whisper Transcript")
-            out_diar = gr.Textbox(label="Diarized Transcript (Whisper)")
+            out_whisper = gr.Textbox(label="Transcript")
+            out_whisper_diar = gr.Textbox(label="Diarized Transcript")
             whisper_btn.click(
                 fn=transcribe_whisper,
                 inputs=[whisper_dd, whisper_lang, audio_input, device_radio, diar_check],
-                outputs=[out_whisper, out_diar]
+                outputs=[out_whisper, out_whisper_diar]
             )
 
         # SenseVoice column
@@ -190,12 +230,14 @@ with demo:
             sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
             sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
            punct = gr.Checkbox(label="Enable Punctuation", value=True)
+            diar_sense = gr.Checkbox(label="Enable Speaker Diarization", value=False)
             sense_btn = gr.Button("Transcribe with SenseVoice")
-            out_sense = gr.Textbox(label="SenseVoice Transcript")
+            out_sense = gr.Textbox(label="Transcript")
+            out_sense_diar = gr.Textbox(label="Diarized Transcript")
             sense_btn.click(
                 fn=transcribe_sense,
-                inputs=[sense_dd, sense_lang, audio_input, punct],
-                outputs=[out_sense]
+                inputs=[sense_dd, sense_lang, audio_input, punct, diar_sense],
+                outputs=[out_sense, out_sense_diar]
             )
 
 if __name__ == "__main__":
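
For readers skimming the diff: the heart of the change is a diarize-then-transcribe loop in transcribe_sense. pyannote produces speaker turns, each turn is cut out of the audio with pydub and written to a temporary WAV file, and SenseVoice transcribes that clip so each line can be tagged with its speaker. The sketch below is an illustrative, stripped-down version of that pattern, not the exact code in app.py: transcribe_clip is a hypothetical stand-in for the SenseVoice call (model.generate followed by rich_transcription_postprocess), and the diarization pipeline is passed in rather than built via get_diarization_pipe().

# Minimal sketch of the per-speaker transcription pattern added in this commit.
# Assumes pyannote.audio and pydub are installed; `transcribe_clip` is a placeholder
# for the SenseVoice call, not a function that exists in app.py.
import os
import tempfile

from pydub import AudioSegment
from pyannote.audio import Pipeline


def diarized_transcript(audio_path: str, diarizer: Pipeline, transcribe_clip) -> str:
    """Return one "[SPEAKER_XX] text" line per speaker turn."""
    diarization = diarizer(audio_path)          # speaker turns for the whole file
    audio = AudioSegment.from_file(audio_path)  # decode once, slice per turn
    lines = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        clip = audio[int(turn.start * 1000):int(turn.end * 1000)]  # pydub slices in ms
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            clip.export(tmp.name, format="wav")
        try:
            lines.append(f"[{speaker}] {transcribe_clip(tmp.name)}")
        finally:
            os.unlink(tmp.name)
    return "\n".join(lines)

One small difference from the diff: here the audio is decoded once and sliced per turn, whereas app.py currently calls AudioSegment.from_file inside the loop, re-reading the file for every speaker segment.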