import os
import gradio as gr
import torch
import numpy as np
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available
from subtitle_manager import Subtitle  # local helper that formats SRT/VTT/TXT output
import logging

print(gr.__version__)
logging.basicConfig(level=logging.INFO)

# Global state
pipe = None
last_model = None

def get_language_names():
    # Short language codes accepted by the multilingual Whisper checkpoints.
    return [
        "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", "br", "bs",
        "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fa", "fi",
        "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy",
        "id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la", "lb",
        "ln", "lo", "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
        "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt", "ro", "ru",
        "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw",
        "ta", "te", "tg", "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi",
        "yi", "yo", "zh"
    ]

def create_pipe(model_id, flash):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=device,
        torch_dtype=torch_dtype,
    )
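
# A minimal usage sketch for create_pipe outside the UI (the model ID and
# file name below are illustrative, not part of the app):
#   asr = create_pipe("openai/whisper-tiny", flash=False)
#   print(asr("sample.wav", return_timestamps=True)["text"])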

def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    global last_model, pipe

    progress(0, desc="Loading Audio...")
    if last_model != modelName or pipe is None:
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model...")
        pipe = create_pipe(modelName, flash)
        last_model = modelName
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files_out = []
    vtt = txt = ""  # returned even when no inputs were provided
    for file in progress.tqdm(files, desc="Working..."):
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=batch_size,
            generate_kwargs={
                "language": languageName if languageName != "Automatic Detection" else None,
                "task": task
            },
            return_timestamps=True,
        )
        file_out = os.path.basename(file)  # robust against Windows-style separators
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        with open(file_out + ".srt", 'w', encoding='utf-8') as f:
            f.write(srt)
        with open(file_out + ".vtt", 'w', encoding='utf-8') as f:
            f.write(vtt)
        with open(file_out + ".txt", 'w', encoding='utf-8') as f:
            f.write(txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")
    return files_out, vtt, txt
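
# Assumed interface of the local subtitle_manager module: Subtitle(fmt)
# exposes get_subtitle(chunks) -> str, where chunks is the list of
# {"timestamp": (start, end), "text": ...} dicts that the pipeline returns
# when return_timestamps=True.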

# Realtime STT
def transcribe_stream(buffer, new_chunk):
    if pipe is None:  # model may still be loading on the first page visit
        return buffer, ""
    sr, chunk = new_chunk
    if chunk.ndim > 1:  # downmix stereo to mono
        chunk = chunk.mean(axis=1)
    chunk = chunk.astype(np.float32)
    peak = np.max(np.abs(chunk))
    if peak > 0:  # peak-normalize to [-1, 1]
        chunk /= peak
    # Accumulate audio and re-transcribe the whole buffer on every new chunk.
    buffer = chunk if buffer is None else np.concatenate([buffer, chunk])
    text = pipe({"sampling_rate": sr, "raw": buffer})["text"]
    return buffer, text
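
# A quick offline sketch of the streaming callback (assumes load_model() has
# already run; the synthetic tone merely stands in for a microphone chunk):
#   sr = 16000
#   tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
#   buf, text = transcribe_stream(None, (sr, tone))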

# Gradio UI
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    gr.Markdown("## 🎙️ Insanely Fast Whisper + Real-time STT")
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en",
        "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en",
        "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
    ]
with gr.Tab("File Transcription"): | |
with gr.Row(): | |
with gr.Column(): | |
model_dropdown = gr.Dropdown( | |
whisper_models, | |
value="distil-whisper/distil-large-v2", | |
label="Model" | |
) | |
language_dropdown = gr.Dropdown( | |
["Automatic Detection"] + sorted(get_language_names()), | |
value="Automatic Detection", | |
label="Language" | |
) | |
url_input = gr.Text(label="URL (YouTube, etc.)") | |
file_input = gr.File(label="Upload Files", file_count="multiple") | |
audio_input = gr.Audio( | |
sources=["upload", "microphone"], | |
type="filepath", | |
label="Audio Input" | |
) | |
task_dropdown = gr.Dropdown( | |
["transcribe", "translate"], | |
label="Task", | |
value="transcribe" | |
) | |
flash_checkbox = gr.Checkbox(label='Flash', info='Use Flash Attention 2') | |
chunk_length = gr.Number(label='chunk_length_s', value=30) | |
batch_size = gr.Number(label='batch_size', value=24) | |
transcribe_button = gr.Button("Transcribe") | |
with gr.Column(): | |
output_files = gr.File(label="Download") | |
output_text = gr.Text(label="Transcription") | |
output_segments = gr.Text(label="Segments") | |
        transcribe_button.click(
            fn=transcribe_webui_simple_progress,
            inputs=[
                model_dropdown, language_dropdown, url_input,
                file_input, audio_input, task_dropdown,
                flash_checkbox, chunk_length, batch_size
            ],
            outputs=[output_files, output_text, output_segments]
        )
with gr.Tab("Real-time Transcription"): | |
st_buffer = gr.State() | |
mic_rt = gr.Audio( | |
sources=["microphone"], type="numpy", streaming=True, | |
label="🎤 Speak Now (Live Transcription)" | |
) | |
txt_rt = gr.Textbox(label="Real-time Transcription") | |
mic_rt.stream( | |
fn=transcribe_stream, | |
inputs=[st_buffer, mic_rt], | |
outputs=[st_buffer, txt_rt] | |
) | |

# Preload the default model for Hugging Face Spaces
def load_model():
    global pipe, last_model
    if pipe is None:  # demo.load fires on every page visit; load only once
        last_model = "distil-whisper/distil-large-v2"
        pipe = create_pipe(last_model, flash=False)

demo.load(load_model)

# Launch the app
if __name__ == "__main__":
    demo.launch()