import os
import re
import tempfile
import torch
import gradio as gr
from transformers import pipeline
from faster_whisper import WhisperModel
from pydub import AudioSegment
from pyannote.audio import Pipeline as DiarizationPipeline
import opencc
import spaces  # ZeroGPU support
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from termcolor import cprint
# —————— Model Lists ——————
WHISPER_MODELS = [
    "guillaumekln/faster-whisper-tiny",
    "Systran/faster-whisper-large-v3",
    "deepdml/faster-whisper-large-v3-turbo-ct2",
    "XA9/Belle-faster-whisper-large-v3-zh-punct",
    "asadfgglie/faster-whisper-large-v3-zh-TW",
    "guillaumekln/faster-whisper-medium",
    "guillaumekln/faster-whisper-small",
    "guillaumekln/faster-whisper-base",
]

SENSEVOICE_MODELS = [
    "FunAudioLLM/SenseVoiceSmall",
    "AXERA-TECH/SenseVoice",
    "alextomcat/SenseVoiceSmall",
    "ChenChenyu/SenseVoiceSmall-finetuned",
    "apinge/sensevoice-small",
]
# —————— Language Options ——————
WHISPER_LANGUAGES = [
    "auto", "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo",
    "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et",
    "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi",
    "hr", "ht", "hu", "hy", "id", "is", "it", "ja", "jw", "ka", "kk",
    "km", "kn", "ko", "la", "lb", "ln", "lo", "lt", "lv", "mg", "mi",
    "mk", "ml", "mn", "mr", "ms", "mt", "my", "ne", "nl", "nn", "no",
    "oc", "pa", "pl", "ps", "pt", "ro", "ru", "sa", "sd", "si", "sk",
    "sl", "sn", "so", "sq", "sr", "su", "sv", "sw", "ta", "te", "tg",
    "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo",
    "zh", "yue",
]
SENSEVOICE_LANGUAGES = ["auto", "zh", "yue", "en", "ja", "ko", "nospeech"]
# —————— Caches ——————
whisper_pipes = {}
sense_models = {}
dar_pipe = None
converter = opencc.OpenCC('s2t')
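# OpenCC 's2t' converts Simplified to Traditional Chinese, e.g.:
#   converter.convert("简体中文")  # -> "簡體中文"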
# —————— Diarization Formatter ——————
def format_diarization_html(snippets):
    palette = ["#e74c3c", "#3498db", "#27ae60", "#e67e22", "#9b59b6", "#16a085", "#f1c40f"]
    speaker_colors = {}
    html_lines = []
    last_spk = None
    for s in snippets:
        if s.startswith("[") and "]" in s:
            spk, txt = s[1:].split("]", 1)
            spk, txt = spk.strip(), txt.strip()
        else:
            spk, txt = "", s.strip()
        # hide empty lines
        if not txt:
            continue
        # assign color if new speaker
        if spk not in speaker_colors:
            speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
        color = speaker_colors[spk]
        # simplify tag for same speaker
        if spk == last_spk:
            display = txt
        else:
            display = f"<strong>{spk}:</strong> {txt}"
        last_spk = spk
        html_lines.append(
            f"<p style='margin:4px 0; font-family:monospace; color:{color};'>{display}</p>"
        )
    return "<div>" + "".join(html_lines) + "</div>"
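# Illustrative call (speaker labels are hypothetical; pyannote emits similar ones):
#   format_diarization_html([
#       "[SPEAKER_00] hello there",
#       "[SPEAKER_00] how are you",
#       "[SPEAKER_01] fine, thanks",
#   ])
# Consecutive snippets from the same speaker keep their color but omit the
# repeated "SPEAKER_00:" prefix.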
# —————— Helpers ——————
def get_whisper_pipe(model_id: str, device: int):
    # NOTE: transformers-pipeline variant; the Gradio handlers below transcribe
    # via faster-whisper (get_fwhisper_model) instead.
    key = (model_id, device)
    if key not in whisper_pipes:
        whisper_pipes[key] = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device=device,
            chunk_length_s=30,
            stride_length_s=5,
            return_timestamps=False,
        )
    return whisper_pipes[key]
# —————— Faster-Whisper Cache & Factory ——————
_fwhisper_models: dict[tuple[str, str], WhisperModel] = {}

def get_fwhisper_model(model_id: str, device: str) -> WhisperModel:
    """
    Lazily load and cache WhisperModel(model_id) on 'cpu' or 'cuda'.
    Uses float16 on GPU and int8 on CPU for speed.
    """
    key = (model_id, device)
    if key not in _fwhisper_models:
        compute_type = "float16" if device.startswith("cuda") else "int8"
        _fwhisper_models[key] = WhisperModel(
            model_id,
            device=device,
            compute_type=compute_type,
        )
    return _fwhisper_models[key]
def get_sense_model(model_id: str, device_str: str):
    key = (model_id, device_str)
    if key not in sense_models:
        sense_models[key] = AutoModel(
            model=model_id,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 300000},  # ms, i.e. 5-minute cap per VAD segment
            device=device_str,
            hub="hf",
        )
    return sense_models[key]
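# Usage sketch (illustrative; "sample.wav" is a placeholder path, and the
# kwargs mirror the generate() calls used below):
#   m = get_sense_model("FunAudioLLM/SenseVoiceSmall", "cpu")
#   segs = m.generate(input="sample.wav", cache={}, language="auto",
#                     use_itn=True, batch_size_s=300)
#   print(rich_transcription_postprocess(segs[0]["text"]))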
def get_diarization_pipe():
    global dar_pipe
    if dar_pipe is None:
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        try:
            dar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token or True,
            )
        except Exception as e:
            print(f"Failed to load pyannote/speaker-diarization-3.1: {e}\n"
                  "Falling back to pyannote/speaker-diarization@2.1.")
            dar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=token or True,
            )
        dar_pipe.to(torch.device("cpu"))  # diarize on CPU in both cases
    return dar_pipe
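# Note: both pyannote checkpoints are gated on the Hugging Face Hub; the
# pipeline loads only if you have accepted their terms and HF_TOKEN (or
# HUGGINGFACE_TOKEN) carries a valid token.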
# —————— Whisper Transcription ——————
def transcribe_with_fwhisper(model: WhisperModel, audio_path: str, language: str) -> str:
    """
    Runs faster-whisper's .transcribe(), then concatenates all segments.
    If language == "auto", detection is automatic.
    """
    lang_arg = None if language == "auto" else language
    segments, _ = model.transcribe(
        audio_path,
        beam_size=5,
        best_of=5,
        language=lang_arg,
    )
    return "".join(seg.text for seg in segments).strip()
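# Usage sketch ("sample.wav" is a placeholder path):
#   model = get_fwhisper_model("guillaumekln/faster-whisper-tiny", "cpu")
#   print(transcribe_with_fwhisper(model, "sample.wav", "auto"))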
def _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar):
    model = get_fwhisper_model(model_id, "cpu")
    cprint('Whisper (faster-whisper) using CPU', 'red')
    # Diarization-only branch
    if enable_diar:
        diarizer = get_diarization_pipe()
        diary = diarizer(audio_path)
        audio = AudioSegment.from_file(audio_path)  # load once, slice per turn
        snippets = []
        for turn, _, speaker in diary.itertracks(yield_label=True):
            start_ms = int(turn.start * 1000)
            end_ms = int(turn.end * 1000)
            segment = audio[start_ms:end_ms]
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                segment.export(tmp.name, format="wav")
            txt = transcribe_with_fwhisper(model, tmp.name, language)
            os.unlink(tmp.name)
            text = converter.convert(txt.strip())
            snippets.append(f"[{speaker}] {text}")
        return "", format_diarization_html(snippets)
    # Raw-only branch
    text = transcribe_with_fwhisper(model, audio_path, language)
    transcript = converter.convert(text.strip())
    return transcript, ""
@spaces.GPU  # ZeroGPU: request a GPU for this call (assumed intent of the `spaces` import)
def _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar):
    pipe = get_fwhisper_model(model_id, "cuda")
    cprint('Whisper (faster-whisper) using CUDA', 'green')
    # Diarization-only branch
    if enable_diar:
        diarizer = get_diarization_pipe()
        diary = diarizer(audio_path)
        audio = AudioSegment.from_file(audio_path)  # load once, slice per turn
        snippets = []
        for turn, _, speaker in diary.itertracks(yield_label=True):
            start_ms = int(turn.start * 1000)
            end_ms = int(turn.end * 1000)
            segment = audio[start_ms:end_ms]
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                segment.export(tmp.name, format="wav")
            txt = transcribe_with_fwhisper(pipe, tmp.name, language)
            os.unlink(tmp.name)
            text = converter.convert(txt.strip())
            snippets.append(f"[{speaker}] {text}")
        return "", format_diarization_html(snippets)
    # Raw-only branch
    text = transcribe_with_fwhisper(pipe, audio_path, language)  # fixed: was tmp.name, which is undefined here
    transcript = converter.convert(text.strip())
    return transcript, ""
def transcribe_whisper(model_id, language, audio_path, device_sel, enable_diar):
    if device_sel == "GPU" and torch.cuda.is_available():
        return _transcribe_whisper_gpu(model_id, language, audio_path, enable_diar)
    return _transcribe_whisper_cpu(model_id, language, audio_path, enable_diar)
# —————— SenseVoice Transcription ——————
def _transcribe_sense_cpu(model_id: str,
                          language: str,
                          audio_path: str,
                          enable_punct: bool,
                          enable_diar: bool):
    model = get_sense_model(model_id, "cpu")
    # Diarization-only branch
    if enable_diar:
        diarizer = get_diarization_pipe()
        diary = diarizer(audio_path)
        audio = AudioSegment.from_file(audio_path)  # load once, slice per turn
        snippets = []
        for turn, _, speaker in diary.itertracks(yield_label=True):
            start_ms = int(turn.start * 1000)
            end_ms = int(turn.end * 1000)
            segment = audio[start_ms:end_ms]
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                segment.export(tmp.name, format="wav")
            segs = model.generate(
                input=tmp.name,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=300,
                merge_vad=False,
                merge_length_s=0,
            )
            os.unlink(tmp.name)
            txt = rich_transcription_postprocess(segs[0]['text'])
            if not enable_punct:
                # strip punctuation; \w is Unicode-aware, so CJK characters survive
                txt = re.sub(r"[^\w\s]", "", txt)
            txt = converter.convert(txt)
            snippets.append(f"[{speaker}] {txt}")
        return "", format_diarization_html(snippets)
    # Raw-only branch
    segs = model.generate(
        input=audio_path,
        cache={},
        language=language,
        use_itn=True,
        batch_size_s=300,
        merge_vad=True,
        merge_length_s=15,
    )
    text = rich_transcription_postprocess(segs[0]['text'])
    if not enable_punct:
        text = re.sub(r"[^\w\s]", "", text)
    text = converter.convert(text)
    return text, ""
@spaces.GPU  # ZeroGPU: request a GPU for this call (assumed intent of the `spaces` import)
def _transcribe_sense_gpu(model_id: str,
                          language: str,
                          audio_path: str,
                          enable_punct: bool,
                          enable_diar: bool):
    model = get_sense_model(model_id, "cuda:0")
    # Diarization-only branch
    if enable_diar:
        diarizer = get_diarization_pipe()
        diary = diarizer(audio_path)
        audio = AudioSegment.from_file(audio_path)  # load once, slice per turn
        snippets = []
        for turn, _, speaker in diary.itertracks(yield_label=True):
            start_ms = int(turn.start * 1000)
            end_ms = int(turn.end * 1000)
            segment = audio[start_ms:end_ms]
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                segment.export(tmp.name, format="wav")
            segs = model.generate(
                input=tmp.name,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=300,
                merge_vad=False,
                merge_length_s=0,
            )
            os.unlink(tmp.name)
            txt = rich_transcription_postprocess(segs[0]['text'])
            if not enable_punct:
                txt = re.sub(r"[^\w\s]", "", txt)
            txt = converter.convert(txt)
            snippets.append(f"[{speaker}] {txt}")
        return "", format_diarization_html(snippets)
    # Raw-only branch
    segs = model.generate(
        input=audio_path,
        cache={},
        language=language,
        use_itn=True,
        batch_size_s=300,
        merge_vad=True,
        merge_length_s=15,
    )
    text = rich_transcription_postprocess(segs[0]['text'])
    if not enable_punct:
        text = re.sub(r"[^\w\s]", "", text)
    text = converter.convert(text)
    return text, ""
def transcribe_sense(model_id: str,
                     language: str,
                     audio_path: str,
                     enable_punct: bool,
                     enable_diar: bool,
                     device_sel: str):
    if device_sel == "GPU" and torch.cuda.is_available():
        return _transcribe_sense_gpu(model_id, language, audio_path, enable_punct, enable_diar)
    return _transcribe_sense_cpu(model_id, language, audio_path, enable_punct, enable_diar)
# —————— Gradio UI ——————
DEMO_CSS = """
.diar {
    min-height: 100px !important;
    max-height: 300px;
    overflow-y: auto;
    padding: 8px;
    border: 1px solid #444;
}
"""

Demo = gr.Blocks(css=DEMO_CSS)
with Demo:
    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Diarization with Simplified→Traditional Chinese)")
    audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
    # Examples
    examples = gr.Examples(
        examples=[
            ["interview.mp3"],
            ["news.mp3"],
        ],
        inputs=[audio_input],
        label="Example Audio Files",
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Whisper ASR")
            whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
            whisper_lang = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto", label="Whisper Language")
            device_radio = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
            diar_check = gr.Checkbox(label="Enable Diarization", value=True)
            out_w = gr.Textbox(label="Transcript", visible=False)
            out_w_d = gr.HTML(label="Diarized Transcript", visible=True, elem_classes=["diar"])
            # Toggle visibility based on checkbox
            diar_check.change(lambda e: gr.update(visible=not e), inputs=diar_check, outputs=out_w)
            diar_check.change(lambda e: gr.update(visible=e), inputs=diar_check, outputs=out_w_d)
            btn_w = gr.Button("Transcribe with Whisper")
            btn_w.click(fn=transcribe_whisper,
                        inputs=[whisper_dd, whisper_lang, audio_input, device_radio, diar_check],
                        outputs=[out_w, out_w_d])
        with gr.Column():
            gr.Markdown("### FunASR SenseVoice ASR")
            sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
            sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
            device_radio_sense = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
            punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
            diar_s_chk = gr.Checkbox(label="Enable Diarization", value=True)
            out_s = gr.Textbox(label="Transcript", visible=False)
            out_s_d = gr.HTML(label="Diarized Transcript", visible=True, elem_classes=["diar"])
            # Toggle visibility
            diar_s_chk.change(lambda e: gr.update(visible=not e), inputs=diar_s_chk, outputs=out_s)
            diar_s_chk.change(lambda e: gr.update(visible=e), inputs=diar_s_chk, outputs=out_s_d)
            btn_s = gr.Button("Transcribe with SenseVoice")
            btn_s.click(fn=transcribe_sense,
                        inputs=[sense_dd, sense_lang, audio_input, punct_chk, diar_s_chk, device_radio_sense],
                        outputs=[out_s, out_s_d])

if __name__ == "__main__":
    Demo.launch()