import os
import re
import tempfile
import torch
import gradio as gr
from faster_whisper import WhisperModel
from pydub import AudioSegment
from pyannote.audio import Pipeline as DiarizationPipeline
import opencc
import spaces # zeroGPU support
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from termcolor import cprint
import time
import torchaudio
from pyannote.audio.pipelines.utils.hook import ProgressHook
# —————— Model Lists ——————
WHISPER_MODELS = [
"deepdml/faster-whisper-large-v3-turbo-ct2",
"guillaumekln/faster-whisper-tiny",
"Systran/faster-whisper-large-v3",
"XA9/Belle-faster-whisper-large-v3-zh-punct",
"asadfgglie/faster-whisper-large-v3-zh-TW",
"guillaumekln/faster-whisper-medium",
"guillaumekln/faster-whisper-small",
"guillaumekln/faster-whisper-base",
"Luigi/whisper-small-zh_tw-ct2",
]
SENSEVOICE_MODELS = [
"FunAudioLLM/SenseVoiceSmall",
]
# —————— Language Options ——————
WHISPER_LANGUAGES = [
"auto", "af","am","ar","as","az","ba","be","bg","bn","bo",
"br","bs","ca","cs","cy","da","de","el","en","es","et",
"eu","fa","fi","fo","fr","gl","gu","ha","haw","he","hi",
"hr","ht","hu","hy","id","is","it","ja","jw","ka","kk",
"km","kn","ko","la","lb","ln","lo","lt","lv","mg","mi",
"mk","ml","mn","mr","ms","mt","my","ne","nl","nn","no",
"oc","pa","pl","ps","pt","ro","ru","sa","sd","si","sk",
"sl","sn","so","sq","sr","su","sv","sw","ta","te","tg",
"th","tk","tl","tr","tt","uk","ur","uz","vi","yi","yo",
"zh","yue"
]
SENSEVOICE_LANGUAGES = ["auto", "zh", "yue", "en", "ja", "ko", "nospeech"]
# —————— Caches ——————
whisper_pipes = {}
sense_models = {}
dar_pipe = None
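# OpenCC 's2t' converts Simplified Chinese output to Traditional Chinese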
converter = opencc.OpenCC('s2t')
# —————— Diarization Formatter ——————
def format_diarization_html(snippets):
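    """Render diarized snippets of the form "[SPEAKER] text" as colour-coded HTML lines."""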
palette = ["#e74c3c", "#3498db", "#27ae60", "#e67e22", "#9b59b6", "#16a085", "#f1c40f"]
speaker_colors = {}
html_lines = []
last_spk = None
for s in snippets:
if s.startswith("[") and "]" in s:
spk, txt = s[1:].split("]", 1)
spk, txt = spk.strip(), txt.strip()
else:
spk, txt = "", s.strip()
# hide empty lines
if not txt:
continue
# assign color if new speaker
if spk not in speaker_colors:
speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
color = speaker_colors[spk]
# simplify tag for same speaker
if spk == last_spk:
display = txt
else:
display = f"{spk}: {txt}"
last_spk = spk
        html_lines.append(
            f"<div style='color:{color}; margin:2px 0;'>{display}</div>"
        )
    return "<div>" + "".join(html_lines) + "</div>"
# —————— Helpers ——————
# —————— Faster-Whisper Cache & Factory ——————
_fwhisper_models: dict[tuple[str, str], WhisperModel] = {}
def get_fwhisper_model(model_id: str, device: str) -> WhisperModel:
"""
    Lazily load and cache WhisperModel(model_id) on 'cpu' or 'cuda'.
Uses float16 on GPU and int8 on CPU for speed.
"""
key = (model_id, device)
if key not in _fwhisper_models:
compute_type = "float16" if device.startswith("cuda") else "int8"
_fwhisper_models[key] = WhisperModel(
model_id,
device=device,
compute_type=compute_type,
)
return _fwhisper_models[key]
def get_sense_model(model_id: str, device_str: str):
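    """Lazily load and cache a SenseVoice AutoModel (with FSMN-VAD segmentation) per (model_id, device)."""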
key = (model_id, device_str)
if key not in sense_models:
sense_models[key] = AutoModel(
model=model_id,
vad_model="fsmn-vad",
vad_kwargs={"max_single_segment_time": 300000},
device=device_str,
hub="hf",
)
return sense_models[key]
def get_diarization_pipe():
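    """Lazily load the pyannote speaker-diarization pipeline, falling back to version 2.1 if 3.1 fails to load."""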
global dar_pipe
if dar_pipe is None:
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
try:
dar_pipe = DiarizationPipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=token or True
)
except Exception as e:
print(f"Failed to load pyannote/speaker-diarization-3.1: {e}\nFalling back to pyannote/speaker-diarization@2.1.")
dar_pipe = DiarizationPipeline.from_pretrained(
"pyannote/speaker-diarization@2.1",
use_auth_token=token or True
)
return dar_pipe
# —————— Whisper Transcription ——————
def transcribe_with_fwhisper(model: WhisperModel, audio_path: str, language: str) -> str:
"""
Runs faster-whisper's .transcribe(), then concatenates all segments.
If language == "auto", detection is automatic.
"""
lang_arg = None if language == "auto" else language
segments, _ = model.transcribe(
audio_path,
beam_size=1,
best_of=1,
language=lang_arg,
vad_filter=True,
)
return "".join(seg.text for seg in segments).strip()
def _transcribe_fwhisper_cpu_stream(model_id, language, audio_path, enable_diar):
"""
Generator-based streaming transcription with accumulation using Faster-Whisper on CPU.
Yields (accumulated_text, diar_html) tuples for Gradio streaming.
"""
pipe = get_fwhisper_model(model_id, "cpu")
cprint('Whisper (faster-whisper) using CPU [stream]', 'red')
# Diarization branch: accumulate snippets and yield full HTML each turn
if enable_diar:
diarizer = get_diarization_pipe()
waveform, sample_rate = torchaudio.load(audio_path)
diarizer.to(torch.device('cpu'))
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
# extract segment
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
segments, _ = pipe.transcribe(
tmp.name,
beam_size=1,
best_of=1,
language=None if language == "auto" else language,
vad_filter=True,
)
os.unlink(tmp.name)
text = converter.convert("".join(s.text for s in segments).strip())
snippets.append(f"[{speaker}] {text}")
# yield accumulated diarization HTML
yield "", format_diarization_html(snippets)
return
# Raw transcription: accumulate text segments and yield full transcript
accumulated = []
lang_arg = None if language == "auto" else language
    segments, _ = pipe.transcribe(
        audio_path,
        beam_size=1,
        best_of=1,
        language=lang_arg,
        vad_filter=True,
    )
    for seg in segments:
        txt = converter.convert(seg.text.strip())
        accumulated.append(txt)
        yield "\n".join(accumulated), ""
@spaces.GPU
def _transcribe_fwhisper_gpu_stream(model_id, language, audio_path, enable_diar):
"""
Generator-based streaming transcription with accumulation using Faster-Whisper on CUDA.
Yields (accumulated_text, diar_html) tuples for Gradio streaming.
"""
pipe = get_fwhisper_model(model_id, "cuda")
cprint('Whisper (faster-whisper) using CUDA [stream]', 'green')
# Diarization branch: accumulate snippets and yield full HTML each turn
if enable_diar:
diarizer = get_diarization_pipe()
device = torch.device('cuda')
diarizer.to(device)
waveform, sample_rate = torchaudio.load(audio_path)
waveform = waveform.to(device)
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
segments, _ = pipe.transcribe(
tmp.name,
beam_size=1,
best_of=1,
language=None if language == "auto" else language,
vad_filter=True,
)
os.unlink(tmp.name)
text = converter.convert("".join(s.text for s in segments).strip())
snippets.append(f"[{speaker}] {text}")
yield "", format_diarization_html(snippets)
return
# Raw transcription: accumulate text segments and yield full transcript
accumulated = []
lang_arg = None if language == "auto" else language
    segments, _ = pipe.transcribe(
        audio_path,
        beam_size=1,
        best_of=1,
        language=lang_arg,
        vad_filter=True,
    )
    for seg in segments:
        txt = converter.convert(seg.text.strip())
        accumulated.append(txt)
        yield "\n".join(accumulated), ""
def _transcribe_fwhisper_cpu(model_id, language, audio_path, enable_diar):
model = get_fwhisper_model(model_id, "cpu")
cprint('Whisper (faster-whisper) using CPU', 'red')
# Diarization-only branch
if enable_diar:
diarizer = get_diarization_pipe()
# Pre-loading audio files in memory may result in faster processing
waveform, sample_rate = torchaudio.load(audio_path)
diarizer.to(torch.device('cpu'))
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
txt = transcribe_with_fwhisper(model, tmp.name, language)
os.unlink(tmp.name)
text = converter.convert(txt.strip())
snippets.append(f"[{speaker}] {text}")
return "", format_diarization_html(snippets)
# Raw-only branch
text = transcribe_with_fwhisper(model, audio_path, language)
transcript = converter.convert(text.strip())
return transcript, ""
@spaces.GPU
def _transcribe_fwhisper_gpu(model_id, language, audio_path, enable_diar):
pipe = get_fwhisper_model(model_id, "cuda")
cprint('Whisper (faster-whisper) using CUDA', 'green')
# Diarization-only branch
if enable_diar:
diarizer = get_diarization_pipe()
diarizer.to(torch.device('cuda'))
# Pre-loading audio files in memory may result in faster processing
waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.to(torch.device('cuda'))
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
txt = transcribe_with_fwhisper(pipe, tmp.name, language)
os.unlink(tmp.name)
text = converter.convert(txt.strip())
snippets.append(f"[{speaker}] {text}")
return "", format_diarization_html(snippets)
# Raw-only branch
    text = transcribe_with_fwhisper(pipe, audio_path, language)
transcript = converter.convert(text.strip())
return transcript, ""
def transcribe_fwhisper(model_id, language, audio_path, device_sel, enable_diar):
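    """Dispatch non-streaming Faster-Whisper transcription to GPU if requested and available, else CPU."""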
if device_sel == "GPU" and torch.cuda.is_available():
return _transcribe_fwhisper_gpu(model_id, language, audio_path, enable_diar)
return _transcribe_fwhisper_cpu(model_id, language, audio_path, enable_diar)
def transcribe_fwhisper_stream(model_id, language, audio_path, device_sel, enable_diar):
"""Dispatch to CPU or GPU streaming generators, preserving two-value yields."""
if device_sel == "GPU" and torch.cuda.is_available():
yield from _transcribe_fwhisper_gpu_stream(model_id, language, audio_path, enable_diar)
else:
yield from _transcribe_fwhisper_cpu_stream(model_id, language, audio_path, enable_diar)
# —————— SenseVoice Transcription ——————
def _transcribe_sense_cpu_stream(model_id: str, language: str, audio_path: str,
enable_punct: bool, enable_diar: bool):
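    """Generator-based SenseVoice transcription on CPU; yields (accumulated_text, diar_html) tuples for Gradio streaming."""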
model = get_sense_model(model_id, "cpu")
cprint('SenseVoiceSmall using CPU [stream]', 'red')
if enable_diar:
diarizer = get_diarization_pipe()
diarizer.to(torch.device('cpu'))
waveform, sample_rate = torchaudio.load(audio_path)
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms, end_ms = int(turn.start*1000), int(turn.end*1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
segs = model.generate(input=tmp.name, cache={}, language=language,
use_itn=True, batch_size_s=300,
merge_vad=False, merge_length_s=0)
os.unlink(tmp.name)
txt = rich_transcription_postprocess(segs[0]['text'])
if not enable_punct:
txt = re.sub(r"[^\w\s]", "", txt)
txt = converter.convert(txt)
snippets.append(f"[{speaker}] {txt}")
yield "", format_diarization_html(snippets)
return
segs = model.generate(input=audio_path, cache={}, language=language,
use_itn=True, batch_size_s=300,
merge_vad=False, merge_length_s=0)
accumulated = []
for s in segs:
t = rich_transcription_postprocess(s['text'])
if not enable_punct:
t = re.sub(r"[^\w\s]", "", t)
t = converter.convert(t)
accumulated.append(t)
yield "\n".join(accumulated), ""
def _transcribe_sense_gpu_stream(model_id: str, language: str, audio_path: str,
enable_punct: bool, enable_diar: bool):
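    """Generator-based SenseVoice transcription on CUDA; yields (accumulated_text, diar_html) tuples for Gradio streaming."""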
model = get_sense_model(model_id, "cuda:0")
cprint('SenseVoiceSmall using CUDA [stream]', 'green')
if enable_diar:
diarizer = get_diarization_pipe()
diarizer.to(torch.device('cuda'))
waveform, sample_rate = torchaudio.load(audio_path)
waveform = waveform.to(torch.device('cuda'))
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms, end_ms = int(turn.start*1000), int(turn.end*1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
segs = model.generate(input=tmp.name, cache={}, language=language,
use_itn=True, batch_size_s=300,
merge_vad=False, merge_length_s=0)
os.unlink(tmp.name)
txt = rich_transcription_postprocess(segs[0]['text'])
if not enable_punct:
txt = re.sub(r"[^\w\s]", "", txt)
txt = converter.convert(txt)
snippets.append(f"[{speaker}] {txt}")
yield "", format_diarization_html(snippets)
return
segs = model.generate(input=audio_path, cache={}, language=language,
use_itn=True, batch_size_s=300,
merge_vad=False, merge_length_s=0)
accumulated = []
for s in segs:
t = rich_transcription_postprocess(s['text'])
if not enable_punct:
t = re.sub(r"[^\w\s]", "", t)
t = converter.convert(t)
accumulated.append(t)
yield "\n".join(accumulated), ""
def _transcribe_sense_cpu(model_id: str,
language: str,
audio_path: str,
enable_punct: bool,
enable_diar: bool):
model = get_sense_model(model_id, "cpu")
# Diarization-only branch
if enable_diar:
diarizer = get_diarization_pipe()
diarizer.to(torch.device('cpu'))
# Pre-loading audio files in memory may result in faster processing
waveform, sample_rate = torchaudio.load(audio_path)
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
segs = model.generate(
input=tmp.name,
cache={},
language=language,
use_itn=True,
batch_size_s=300,
merge_vad=False,
merge_length_s=0,
)
os.unlink(tmp.name)
txt = rich_transcription_postprocess(segs[0]['text'])
if not enable_punct:
txt = re.sub(r"[^\w\s]", "", txt)
txt = converter.convert(txt)
snippets.append(f"[{speaker}] {txt}")
return "", format_diarization_html(snippets)
# Raw-only branch
segs = model.generate(
input=audio_path,
cache={},
language=language,
use_itn=True,
batch_size_s=300,
merge_vad=True,
merge_length_s=15,
)
text = rich_transcription_postprocess(segs[0]['text'])
if not enable_punct:
text = re.sub(r"[^\w\s]", "", text)
text = converter.convert(text)
return text, ""
@spaces.GPU
def _transcribe_sense_gpu(model_id: str,
language: str,
audio_path: str,
enable_punct: bool,
enable_diar: bool):
model = get_sense_model(model_id, "cuda:0")
# Diarization-only branch
if enable_diar:
diarizer = get_diarization_pipe()
diarizer.to(torch.device('cuda'))
# Pre-loading audio files in memory may result in faster processing
waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.to(torch.device('cuda'))
with ProgressHook() as hook:
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
snippets = []
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
segs = model.generate(
input=tmp.name,
cache={},
language=language,
use_itn=True,
batch_size_s=300,
merge_vad=False,
merge_length_s=0,
)
os.unlink(tmp.name)
txt = rich_transcription_postprocess(segs[0]['text'])
if not enable_punct:
txt = re.sub(r"[^\w\s]", "", txt)
txt = converter.convert(txt)
snippets.append(f"[{speaker}] {txt}")
return "", format_diarization_html(snippets)
# Raw-only branch
segs = model.generate(
input=audio_path,
cache={},
language=language,
use_itn=True,
batch_size_s=300,
merge_vad=True,
merge_length_s=15,
)
text = rich_transcription_postprocess(segs[0]['text'])
if not enable_punct:
text = re.sub(r"[^\w\s]", "", text)
text = converter.convert(text)
return text, ""
def transcribe_sense(model_id: str,
language: str,
audio_path: str,
enable_punct: bool,
enable_diar: bool,
device_sel: str):
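    """Dispatch non-streaming SenseVoice transcription to GPU if requested and available, else CPU."""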
if device_sel == "GPU" and torch.cuda.is_available():
return _transcribe_sense_gpu(model_id, language, audio_path, enable_punct, enable_diar)
return _transcribe_sense_cpu(model_id, language, audio_path, enable_punct, enable_diar)
def transcribe_sense_stream(model_id: str,
                            language: str,
                            audio_path: str,
                            enable_punct: bool,
                            enable_diar: bool,
                            device_sel: str):
    """Dispatch to CPU or GPU SenseVoice streaming generators, preserving two-value yields."""
    if device_sel == "GPU" and torch.cuda.is_available():
        yield from _transcribe_sense_gpu_stream(model_id, language, audio_path, enable_punct, enable_diar)
    else:
        yield from _transcribe_sense_cpu_stream(model_id, language, audio_path, enable_punct, enable_diar)
# —————— Gradio UI ——————
DEMO_CSS = """
.diar {
min-height: 100px !important;
max-height: 300px;
overflow-y: auto;
padding: 8px;
border: 1px solid #444;
}
"""
Demo = gr.Blocks(css=DEMO_CSS)
with Demo:
gr.Markdown("## Whisper vs. SenseVoice (…)")
audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
examples = gr.Examples(
examples=[["interview.mp3"], ["news.mp3"]],
inputs=[audio_input],
label="Example Audio Files"
)
# ────────────────────────────────────────────────────────────────
# 1) CONTROL PANELS (still side-by-side)
with gr.Row():
with gr.Column():
gr.Markdown("### Faster-Whisper ASR")
whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
whisper_lang = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto", label="Whisper Language")
device_radio = gr.Radio(choices=["GPU","CPU"], value="GPU", label="Device")
diar_check = gr.Checkbox(label="Enable Diarization", value=True)
btn_w = gr.Button("Transcribe with Faster-Whisper")
with gr.Column():
gr.Markdown("### FunASR SenseVoice ASR")
sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
device_radio_s = gr.Radio(choices=["GPU","CPU"], value="GPU", label="Device")
punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
diar_s_chk = gr.Checkbox(label="Enable Diarization", value=True)
btn_s = gr.Button("Transcribe with SenseVoice")
# ────────────────────────────────────────────────────────────────
# 2) SHARED TRANSCRIPT ROW (aligned side-by-side)
with gr.Row():
with gr.Column():
gr.Markdown("### Faster-Whisper Output")
out_w = gr.Textbox(label="Raw Transcript", visible=False)
out_w_d = gr.HTML(label="Diarized Transcript", elem_classes=["diar"])
with gr.Column():
gr.Markdown("### SenseVoice Output")
out_s = gr.Textbox(label="Raw Transcript", visible=False)
out_s_d = gr.HTML(label="Diarized Transcript", elem_classes=["diar"])
# ────────────────────────────────────────────────────────────────
# 3) WIRING UP TOGGLES & BUTTONS
# toggle raw ↔ diarized for each system
diar_check.change(lambda e: gr.update(visible=not e), diar_check, out_w)
diar_check.change(lambda e: gr.update(visible=e), diar_check, out_w_d)
diar_s_chk.change(lambda e: gr.update(visible=not e), diar_s_chk, out_s)
diar_s_chk.change(lambda e: gr.update(visible=e), diar_s_chk, out_s_d)
# wire the callbacks into those shared boxes
btn_w.click(
fn=transcribe_fwhisper_stream,
inputs=[whisper_dd, whisper_lang, audio_input, device_radio, diar_check],
outputs=[out_w, out_w_d]
)
btn_s.click(
        fn=transcribe_sense_stream,
inputs=[sense_dd, sense_lang, audio_input, punct_chk, diar_s_chk, device_radio_s],
outputs=[out_s, out_s_d]
)
if __name__ == "__main__":
Demo.launch()