# Spaces: Running on Zero
import os | |
import re | |
import tempfile | |
import torch | |
import gradio as gr | |
from transformers import pipeline | |
from pydub import AudioSegment | |
from pyannote.audio import Pipeline as DiarizationPipeline | |
import spaces # zeroGPU support | |
from funasr import AutoModel | |
from funasr.utils.postprocess_utils import rich_transcription_postprocess | |
# —————— Model Lists —————— | |
# Hub repo ids offered in the "Whisper Model" dropdown; the first entry is
# used as the dropdown's initial value.
WHISPER_MODELS = [
    # Checkpoints published under the "openai" namespace.
    "openai/whisper-large-v3-turbo",
    "openai/whisper-large-v3",
    "openai/whisper-tiny",
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-base",
    # Community checkpoints (repo names suggest Chinese / zh-TW fine-tunes).
    "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW",
    "Jingmiao/whisper-small-zh_tw",
    "DDTChen/whisper-medium-zh-tw",
    "kimbochen/whisper-small-zh-tw",
    "JacobLinCool/whisper-large-v3-turbo-zh-TW-clean-1",
    "JunWorks/whisper-small-zhTW",
    "WANGTINGTING/whisper-large-v2-zh-TW-vol2",
    "xmzhu/whisper-tiny-zh-TW",
    "ingrenn/whisper-small-common-voice-13-zh-TW",
    "jun-han/whisper-small-zh-TW",
    "xmzhu/whisper-tiny-zh-TW-baseline",
    "JacobLinCool/whisper-large-v3-turbo-common_voice_16_1-zh-TW-2",
    "JacobLinCool/whisper-large-v3-common_voice_19_0-zh-TW-full-1",
    "momo103197/whisper-small-zh-TW-mix",
    "JacobLinCool/whisper-large-v3-turbo-zh-TW-clean-1-merged",
    "JacobLinCool/whisper-large-v2-common_voice_19_0-zh-TW-full-1",
    "kimas1269/whisper-meduim_zhtw",
    "JunWorks/whisper-base-zhTW",
    "JunWorks/whisper-small-zhTW-frozenDecoder",
    "sandy1990418/whisper-large-v3-turbo-zh-tw",
    "JacobLinCool/whisper-large-v3-turbo-common_voice_16_1-zh-TW-pissa-merged",
    "momo103197/whisper-small-zh-TW-16",
    "k1nto/Belle-whisper-large-v3-zh-punct-ct2",
]

# Hub repo ids offered in the "SenseVoice Model" dropdown; the first entry is
# used as the dropdown's initial value.
SENSEVOICE_MODELS = [
    "FunAudioLLM/SenseVoiceSmall",
    "AXERA-TECH/SenseVoice",
    "alextomcat/SenseVoiceSmall",
    "ChenChenyu/SenseVoiceSmall-finetuned",
    "apinge/sensevoice-small",
]
# —————— Language Options —————— | |
# Choices for the "Whisper Language" dropdown. "auto" means no language is
# forced on the pipeline; any other value is passed through as the
# ``language`` generation argument.
WHISPER_LANGUAGES = [
    "auto",
    "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo",
    "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es",
    "et", "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw",
    "he", "hi", "hr", "ht", "hu", "hy", "id", "is", "it", "ja",
    "jw", "ka", "kk", "km", "kn", "ko", "la", "lb", "ln", "lo",
    "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
    "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt",
    "ro", "ru", "sa", "sd", "si", "sk", "sl", "sn", "so", "sq",
    "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tk", "tl",
    "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo", "zh", "yue",
]

# Choices for the "SenseVoice Language" dropdown; the selected value is
# forwarded unchanged as the ``language`` argument of ``model.generate``.
SENSEVOICE_LANGUAGES = [
    "auto",
    "zh",
    "yue",
    "en",
    "ja",
    "ko",
    "nospeech",
]
# —————— Caches —————— | |
# Module-level caches, filled lazily by the get_* helpers so that each model
# is constructed at most once per process.
whisper_pipes: dict = {}  # (model_id, device) -> ASR pipeline
sense_models: dict = {}  # model_id -> FunASR AutoModel
dar_pipe = None  # diarization pipeline, created on first use
# —————— Helpers —————— | |
def get_whisper_pipe(model_id: str, device: int):
    """Return the ASR pipeline for ``(model_id, device)``, building it once.

    Args:
        model_id: Hub repo id passed to ``pipeline`` as ``model``.
        device: Device index passed to ``pipeline`` as ``device``.

    Returns:
        The pipeline stored in ``whisper_pipes`` for this model/device pair.
    """
    cache_key = (model_id, device)
    try:
        return whisper_pipes[cache_key]
    except KeyError:
        pass

    # First request for this combination: build the pipeline and remember it.
    asr = pipeline(
        "automatic-speech-recognition",
        model=model_id,
        device=device,
        chunk_length_s=30,
        stride_length_s=5,
        return_timestamps=False,
    )
    whisper_pipes[cache_key] = asr
    return asr
def get_sense_model(model_id: str):
    """Return the FunASR model for *model_id*, constructing it on first use.

    The model is placed on ``cuda:0`` when CUDA is available at construction
    time and on the CPU otherwise, and is then kept in ``sense_models``.

    Args:
        model_id: Hub repo id passed to ``AutoModel`` as ``model``.

    Returns:
        The cached ``AutoModel`` instance for *model_id*.
    """
    try:
        return sense_models[model_id]
    except KeyError:
        pass

    if torch.cuda.is_available():
        target_device = "cuda:0"
    else:
        target_device = "cpu"

    loaded = AutoModel(
        model=model_id,
        vad_model="fsmn-vad",
        vad_kwargs={"max_single_segment_time": 300000},
        device=target_device,
        hub="hf",
    )
    sense_models[model_id] = loaded
    return loaded
def get_diarization_pipe():
    """Return the shared speaker-diarization pipeline, loading it on first use.

    The access token is read from the ``HF_TOKEN`` or ``HUGGINGFACE_TOKEN``
    environment variable. When neither is set, ``True`` is passed as
    ``use_auth_token`` (presumably falling back to a locally stored Hugging
    Face login -- confirm against the installed pyannote version).

    Returns:
        The loaded pipeline, also stored in the module-level ``dar_pipe``.

    Raises:
        RuntimeError: If ``from_pretrained`` returned ``None``, i.e. the
            pipeline could not be loaded.
    """
    global dar_pipe
    if dar_pipe is None:
        # Pull token from environment (HF_TOKEN or HUGGINGFACE_TOKEN)
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        loaded = DiarizationPipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=token or True,
        )
        # from_pretrained may return None instead of raising when the gated
        # model cannot be downloaded; fail here with an actionable message
        # rather than later with "'NoneType' object is not callable".
        if loaded is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'. "
                "Set HF_TOKEN (or HUGGINGFACE_TOKEN) to a token whose account "
                "has accepted the model's user conditions."
            )
        dar_pipe = loaded
    return dar_pipe
# —————— Transcription Functions —————— | |
def _whisper_text(pipe, path: str, language: str) -> str:
    """Run *pipe* on the audio file at *path* and return its stripped text."""
    if language == "auto":
        output = pipe(path)
    else:
        output = pipe(path, generate_kwargs={"language": language})
    return output.get("text", "").strip()


def transcribe_whisper(model_id: str,
                       language: str,
                       audio_path: str,
                       device_sel: str,
                       enable_diar: bool):
    """Transcribe an audio file with Whisper, optionally per speaker turn.

    Args:
        model_id: Hub repo id of the Whisper checkpoint to use.
        language: ``"auto"`` to not force a language, otherwise the value
            passed to the pipeline as the ``language`` generation argument.
        audio_path: Path of the audio file to transcribe.
        device_sel: ``"GPU"`` to use device 0 when CUDA is available; any
            other value (or no CUDA) selects the CPU (device -1).
        enable_diar: When true, also transcribe each diarized speaker turn.

    Returns:
        ``(transcript, diarized)``: the full transcript, and one
        ``"[speaker] text"`` line per speaker turn joined with newlines
        (empty string when diarization is disabled).

    Raises:
        gr.Error: If no audio was provided.
    """
    if not audio_path:
        # Gradio passes None when nothing was uploaded or recorded.
        raise gr.Error("Please upload or record an audio clip first.")

    # select device: 0 for GPU, -1 for CPU
    use_gpu = device_sel == "GPU" and torch.cuda.is_available()
    pipe = get_whisper_pipe(model_id, 0 if use_gpu else -1)

    # full transcription
    transcript = _whisper_text(pipe, audio_path, language)
    if not enable_diar:
        return transcript, ""

    # optional speaker diarization
    diary = get_diarization_pipe()(audio_path)
    # Decode the source audio once instead of once per speaker turn.
    audio = AudioSegment.from_file(audio_path)
    snippets = []
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms, end_ms = int(turn.start * 1000), int(turn.end * 1000)
        # Reserve a path and close the handle before exporting to it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            audio[start_ms:end_ms].export(tmp_path, format="wav")
            text = _whisper_text(pipe, tmp_path, language)
        finally:
            # Remove the clip even when export or recognition fails.
            os.unlink(tmp_path)
        snippets.append(f"[{speaker}] {text}")
    return transcript, "\n".join(snippets)
def _sense_text(model, path: str, language: str, enable_punct: bool,
                *, merge_vad: bool, merge_length_s: int) -> str:
    """Recognize the audio file at *path* and return the post-processed text."""
    segs = model.generate(
        input=path,
        cache={},
        language=language,
        use_itn=True,
        batch_size_s=300,
        merge_vad=merge_vad,
        merge_length_s=merge_length_s,
    )
    # An empty result list (e.g. no speech detected in a short clip) would
    # make segs[0] raise IndexError; treat it as empty text instead.
    if not segs:
        return ""
    text = rich_transcription_postprocess(segs[0]['text'])
    if not enable_punct:
        # Keep only word characters and whitespace.
        text = re.sub(r"[^\w\s]", "", text)
    return text


def transcribe_sense(model_id: str,
                     language: str,
                     audio_path: str,
                     enable_punct: bool,
                     enable_diar: bool):
    """Transcribe an audio file with SenseVoice, optionally per speaker turn.

    Args:
        model_id: Hub repo id of the SenseVoice model to use.
        language: Value passed to ``model.generate`` as ``language``.
        audio_path: Path of the audio file to transcribe.
        enable_punct: When false, every character that is neither a word
            character nor whitespace is removed from the output.
        enable_diar: When true, also transcribe each diarized speaker turn.

    Returns:
        ``(transcript, diarized)``: the full transcript, and one
        ``"[speaker] text"`` line per speaker turn joined with newlines
        (empty string when diarization is disabled).

    Raises:
        gr.Error: If no audio was provided.
    """
    if not audio_path:
        # Gradio passes None when nothing was uploaded or recorded.
        raise gr.Error("Please upload or record an audio clip first.")

    model = get_sense_model(model_id)

    # The full transcript uses the same settings with or without diarization.
    full = _sense_text(model, audio_path, language, enable_punct,
                       merge_vad=True, merge_length_s=15)
    # no diarization
    if not enable_diar:
        return full, ""

    # with diarization
    diary = get_diarization_pipe()(audio_path)
    # Decode the source audio once instead of once per speaker turn.
    audio = AudioSegment.from_file(audio_path)
    snippets = []
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms, end_ms = int(turn.start * 1000), int(turn.end * 1000)
        # Reserve a path and close the handle before exporting to it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            audio[start_ms:end_ms].export(tmp_path, format="wav")
            txt = _sense_text(model, tmp_path, language, enable_punct,
                              merge_vad=False, merge_length_s=0)
        finally:
            # Remove the clip even when export or recognition fails.
            os.unlink(tmp_path)
        snippets.append(f"[{speaker}] {txt}")
    return full, "\n".join(snippets)
# —————— Gradio UI —————— | |
# Two-column page: Whisper on the left, SenseVoice on the right, both fed by
# the same audio input. Each button returns a plain transcript and a
# diarized transcript.
with gr.Blocks() as demo:
    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Diarization)")
    # Handed to the callbacks as a file path.
    audio_input = gr.Audio(
        sources=["upload", "microphone"],
        type="filepath",
        label="Audio Input",
    )

    with gr.Row():
        # Left column: Whisper controls and outputs.
        with gr.Column():
            gr.Markdown("### Whisper ASR")
            whisper_dd = gr.Dropdown(
                choices=WHISPER_MODELS,
                value=WHISPER_MODELS[0],
                label="Whisper Model",
            )
            whisper_lang = gr.Dropdown(
                choices=WHISPER_LANGUAGES,
                value="auto",
                label="Whisper Language",
            )
            device_radio = gr.Radio(
                choices=["GPU", "CPU"],
                value="GPU",
                label="Device",
            )
            diar_check = gr.Checkbox(label="Enable Diarization")
            btn_w = gr.Button("Transcribe with Whisper")
            out_w = gr.Textbox(label="Transcript")
            out_w_d = gr.Textbox(label="Diarized Transcript")
            # Input order matches the parameters of transcribe_whisper.
            btn_w.click(
                fn=transcribe_whisper,
                inputs=[
                    whisper_dd,
                    whisper_lang,
                    audio_input,
                    device_radio,
                    diar_check,
                ],
                outputs=[out_w, out_w_d],
            )

        # Right column: SenseVoice controls and outputs.
        with gr.Column():
            gr.Markdown("### FunASR SenseVoice ASR")
            sense_dd = gr.Dropdown(
                choices=SENSEVOICE_MODELS,
                value=SENSEVOICE_MODELS[0],
                label="SenseVoice Model",
            )
            sense_lang = gr.Dropdown(
                choices=SENSEVOICE_LANGUAGES,
                value="auto",
                label="SenseVoice Language",
            )
            punct = gr.Checkbox(label="Enable Punctuation", value=True)
            diar_s = gr.Checkbox(label="Enable Diarization")
            btn_s = gr.Button("Transcribe with SenseVoice")
            out_s = gr.Textbox(label="Transcript")
            out_s_d = gr.Textbox(label="Diarized Transcript")
            # Input order matches the parameters of transcribe_sense.
            btn_s.click(
                fn=transcribe_sense,
                inputs=[sense_dd, sense_lang, audio_input, punct, diar_s],
                outputs=[out_s, out_s_d],
            )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()