add speaker diarization to SenseVoice transcript
app.py CHANGED
@@ -55,8 +55,7 @@ SENSEVOICE_MODELS = [
 
 # —————— Language Options ——————
 WHISPER_LANGUAGES = [
-    "auto",
-    "af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca",
+    "auto", "af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca",
     "cs","cy","da","de","el","en","es","et","eu","fa","fi","fo","fr",
     "gl","gu","ha","haw","he","hi","hr","ht","hu","hy","id","is","it",
     "ja","jw","ka","kk","km","kn","ko","la","lb","ln","lo","lt","lv",
@@ -111,7 +110,7 @@ def get_diarization_pipe():
 
 # —————— Transcription Functions ——————
 def transcribe_whisper(model_id: str, language: str, audio_path: str, device_sel: str, enable_diar: bool):
-    # select device
+    # select device for Whisper
     use_gpu = (device_sel == "GPU" and torch.cuda.is_available())
     device = 0 if use_gpu else -1
     pipe = get_whisper_pipe(model_id, device)
@@ -122,7 +121,7 @@ def transcribe_whisper(model_id: str, language: str, audio_path: str, device_sel
     result = pipe(audio_path, generate_kwargs={"language": language})
     transcript = result.get("text", "").strip()
     diar_text = ""
-    # optional diarization
+    # optional diarization for Whisper
     if enable_diar:
         diarizer = get_diarization_pipe()
         diarization = diarizer(audio_path)
@@ -144,9 +143,50 @@ def transcribe_whisper(model_id: str, language: str, audio_path: str, device_sel
     return transcript, diar_text
 
 @spaces.GPU
-def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct: bool):
+def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct: bool, enable_diar: bool):
     model = get_sense_model(model_id)
-
+    # if no diarization, full file
+    if not enable_diar:
+        segments = model.generate(
+            input=audio_path,
+            cache={},
+            language=language,
+            use_itn=True,
+            batch_size_s=300,
+            merge_vad=True,
+            merge_length_s=15,
+        )
+        text = rich_transcription_postprocess(segments[0]['text'])
+        if not enable_punct:
+            text = re.sub(r"[^\w\s]", "", text)
+        return text, ""
+    # with diarization: split by speaker
+    diarizer = get_diarization_pipe()
+    diarization = diarizer(audio_path)
+    speaker_snippets = []
+    for turn, _, speaker in diarization.itertracks(yield_label=True):
+        start_ms = int(turn.start * 1000)
+        end_ms = int(turn.end * 1000)
+        segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+            segment.export(tmp.name, format="wav")
+        segments = model.generate(
+            input=tmp.name,
+            cache={},
+            language=language,
+            use_itn=True,
+            batch_size_s=300,
+            merge_vad=False,
+            merge_length_s=0,
+        )
+        os.unlink(tmp.name)
+        txt = rich_transcription_postprocess(segments[0]['text'])
+        if not enable_punct:
+            txt = re.sub(r"[^\w\s]", "", txt)
+        speaker_snippets.append(f"[{speaker}] {txt}")
+    full_text = "\n".join(speaker_snippets)
+    # also return full non-diarized transcript for comparison
+    segments_full = model.generate(
         input=audio_path,
         cache={},
         language=language,
@@ -155,15 +195,15 @@ def transcribe_sense(model_id: str, language: str, audio_path: str, enable_punct
         merge_vad=True,
         merge_length_s=15,
     )
-
+    text_full = rich_transcription_postprocess(segments_full[0]['text'])
     if not enable_punct:
-
-    return
+        text_full = re.sub(r"[^\w\s]", "", text_full)
+    return text_full, full_text
 
 # —————— Gradio UI ——————
 demo = gr.Blocks()
 with demo:
-    gr.Markdown("## Whisper vs. SenseVoice
+    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Speaker Diarization)")
 
     audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
 
@@ -176,12 +216,12 @@ with demo:
         device_radio = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
         diar_check = gr.Checkbox(label="Enable Speaker Diarization", value=False)
        whisper_btn = gr.Button("Transcribe with Whisper")
-        out_whisper = gr.Textbox(label="
-
+        out_whisper = gr.Textbox(label="Transcript")
+        out_whisper_diar = gr.Textbox(label="Diarized Transcript")
         whisper_btn.click(
             fn=transcribe_whisper,
             inputs=[whisper_dd, whisper_lang, audio_input, device_radio, diar_check],
-            outputs=[out_whisper,
+            outputs=[out_whisper, out_whisper_diar]
         )
 
     # SenseVoice column
@@ -190,12 +230,14 @@ with demo:
         sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
        sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
        punct = gr.Checkbox(label="Enable Punctuation", value=True)
+        diar_sense = gr.Checkbox(label="Enable Speaker Diarization", value=False)
         sense_btn = gr.Button("Transcribe with SenseVoice")
-        out_sense = gr.Textbox(label="
+        out_sense = gr.Textbox(label="Transcript")
+        out_sense_diar = gr.Textbox(label="Diarized Transcript")
         sense_btn.click(
             fn=transcribe_sense,
-            inputs=[sense_dd, sense_lang, audio_input, punct],
-            outputs=[out_sense]
+            inputs=[sense_dd, sense_lang, audio_input, punct, diar_sense],
+            outputs=[out_sense, out_sense_diar]
         )
 
 if __name__ == "__main__":
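
Both transcription paths call get_diarization_pipe(), but the diff does not show its body. A minimal sketch of what such a helper typically looks like, assuming pyannote.audio with the gated pyannote/speaker-diarization-3.1 model and an HF_TOKEN environment variable (both assumptions, not taken from this commit):

import os

import torch
from pyannote.audio import Pipeline

_diar_pipe = None

def get_diarization_pipe():
    # Load the pyannote speaker-diarization pipeline once and reuse it.
    # Model name and token handling are assumptions, not shown in the diff.
    global _diar_pipe
    if _diar_pipe is None:
        _diar_pipe = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.environ.get("HF_TOKEN"),
        )
        if torch.cuda.is_available():
            _diar_pipe.to(torch.device("cuda"))
    return _diar_pipe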
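
For reference, the new transcribe_sense signature returns a pair: the plain transcript and the per-speaker transcript (empty when diarization is disabled). A hypothetical standalone call, with "FunAudioLLM/SenseVoiceSmall" standing in for whatever SENSEVOICE_MODELS[0] resolves to in app.py:

from app import transcribe_sense

plain, diarized = transcribe_sense(
    model_id="FunAudioLLM/SenseVoiceSmall",
    language="auto",
    audio_path="meeting.wav",
    enable_punct=True,
    enable_diar=True,
)
print(plain)     # single transcript of the whole file
print(diarized)  # one "[SPEAKER_XX] ..." line per diarization turn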
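
One design note on the diarized branch: AudioSegment.from_file(audio_path) runs inside the speaker-turn loop, so the source file is decoded once per turn. If that becomes slow on long recordings, the decode could be hoisted out of the loop; a small sketch with a hypothetical helper (slice_speaker_turns is not part of the commit):

from pydub import AudioSegment

def slice_speaker_turns(audio_path, diarization):
    # Decode the source file once, then slice per diarization turn.
    audio = AudioSegment.from_file(audio_path)
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        yield speaker, audio[int(turn.start * 1000):int(turn.end * 1000)]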