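"""Gradio app: Insanely Fast Whisper file transcription plus real-time microphone STT."""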
import logging
import os

import gradio as gr
import numpy as np
import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available

from subtitle_manager import Subtitle

logging.basicConfig(level=logging.INFO)
logging.info("Gradio version: %s", gr.__version__)

# Global state
pipe = None
last_model = None
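# `pipe` caches the loaded ASR pipeline and `last_model` records which checkpoint it was
# built from, so the model is only reloaded when the dropdown selection changes.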

def get_language_names():
    """Return the language codes selectable in the Language dropdown."""
    return [
        "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", "br", "bs",
        "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fa", "fi",
        "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy",
        "id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la", "lb",
        "ln", "lo", "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
        "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt", "ro", "ru",
        "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw",
        "ta", "te", "tg", "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi",
        "yi", "yo", "zh"
    ]

def create_pipe(model_id, flash):
    """Build an automatic-speech-recognition pipeline for `model_id`, optionally using Flash Attention 2."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # Fall back to PyTorch SDPA when Flash Attention 2 is disabled or unavailable.
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=device,
        torch_dtype=torch_dtype,
    )

def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe (or translate) the given audio sources and write SRT, VTT, and TXT subtitle files."""
    global last_model, pipe

    progress(0, desc="Loading Audio...")
    # Reload the pipeline only when the selected model changes.
    if last_model != modelName or pipe is None:
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model...")
        pipe = create_pipe(modelName, flash)
        last_model = modelName

    # Collect every provided audio source: uploaded files, URL, and microphone recording.
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files_out = []
    vtt = txt = ""
    for file in progress.tqdm(files, desc="Working..."):
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=batch_size,
            generate_kwargs={
                "language": languageName if languageName != "Automatic Detection" else None,
                "task": task
            },
            return_timestamps=True,
        )

        file_out = os.path.basename(file)
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        with open(file_out + ".srt", 'w', encoding='utf-8') as f:
            f.write(srt)
        with open(file_out + ".vtt", 'w', encoding='utf-8') as f:
            f.write(vtt)
        with open(file_out + ".txt", 'w', encoding='utf-8') as f:
            f.write(txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")
    return files_out, vtt, txt

# Real-time STT
def transcribe_stream(buffer, new_chunk):
    """Accumulate streamed microphone audio and transcribe the full buffer on each new chunk."""
    sr, chunk = new_chunk
    # Down-mix stereo to mono and normalize to float32 in [-1, 1], as the pipeline expects.
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)
    chunk = chunk.astype(np.float32)
    peak = np.max(np.abs(chunk))
    if peak > 0:
        chunk /= peak

    buffer = chunk if buffer is None else np.concatenate([buffer, chunk])

    # The pipeline is preloaded via demo.load(); guard against a not-yet-loaded model.
    if pipe is None:
        return buffer, ""

    text = pipe({"sampling_rate": sr, "raw": buffer})["text"]
    return buffer, text

# Gradio UI
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    gr.Markdown("## 🎙️ Insanely Fast Whisper + Real-time STT")

    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en",
        "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en",
        "distil-whisper/distil-medium.en",
        "openai/whisper-large", "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
    ]

    with gr.Tab("File Transcription"):
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    whisper_models,
                    value="distil-whisper/distil-large-v2",
                    label="Model"
                )
                language_dropdown = gr.Dropdown(
                    ["Automatic Detection"] + sorted(get_language_names()),
                    value="Automatic Detection",
                    label="Language"
                )
                url_input = gr.Text(label="URL (YouTube, etc.)")
                file_input = gr.File(label="Upload Files", file_count="multiple")
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio Input"
                )
                task_dropdown = gr.Dropdown(
                    ["transcribe", "translate"],
                    label="Task",
                    value="transcribe"
                )
                flash_checkbox = gr.Checkbox(label="Flash", info="Use Flash Attention 2")
                chunk_length = gr.Number(label="chunk_length_s", value=30)
                batch_size = gr.Number(label="batch_size", value=24)
                transcribe_button = gr.Button("Transcribe")
            with gr.Column():
                output_files = gr.File(label="Download")
                output_text = gr.Text(label="Transcription")
                output_segments = gr.Text(label="Segments")

        transcribe_button.click(
            fn=transcribe_webui_simple_progress,
            inputs=[
                model_dropdown, language_dropdown, url_input,
                file_input, audio_input, task_dropdown,
                flash_checkbox, chunk_length, batch_size
            ],
            outputs=[output_files, output_text, output_segments]
        )

    with gr.Tab("Real-time Transcription"):
        st_buffer = gr.State()
        mic_rt = gr.Audio(
            sources=["microphone"], type="numpy", streaming=True,
            label="🎤 Speak Now (Live Transcription)"
        )
        txt_rt = gr.Textbox(label="Real-time Transcription")
        mic_rt.stream(
            fn=transcribe_stream,
            inputs=[st_buffer, mic_rt],
            outputs=[st_buffer, txt_rt]
        )

# Preload the default model (for Hugging Face Spaces)
def load_model():
    global pipe, last_model
    last_model = "distil-whisper/distil-large-v2"
    pipe = create_pipe(last_model, flash=False)


demo.load(load_model)

# Launch the app
if __name__ == "__main__":
    demo.launch()
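# Run locally with `python app.py` (or `gradio app.py` for hot reloading during development).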