import spaces
import torch
import gradio as gr
import tempfile
import os
import uuid
import scipy.io.wavfile
import time
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# Set environment variable to skip flash attention issues
os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves memory on GPU; fall back to float32 on CPU, where half
# precision is poorly supported
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
MODEL_NAME = "openai/whisper-large-v3-turbo"
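# whisper-large-v3-turbo is a pruned, fine-tuned variant of large-v3 with a
# much smaller decoder, trading a little accuracy for much faster generation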
# Load model - removed attn_implementation for compatibility
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
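# AutoProcessor bundles the tokenizer (text side) and the feature extractor
# (log-mel spectrogram side); both are passed to the pipeline explicitly
# below because we hand it a preloaded model object rather than a model id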
# Create pipeline with explicit tokenizer and feature extractor
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=10,
    torch_dtype=torch_dtype,
    device=device,
)
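# chunk_length_s=10 lets the pipeline split audio longer than Whisper's
# 30-second window into chunks and stitch the results back together.
# Quick sanity check (a sketch; assumes a local mono WAV file "sample.wav"):
#   sr, y = scipy.io.wavfile.read("sample.wav")
#   y = y.astype(np.float32) / np.abs(y).max()
#   print(pipe({"sampling_rate": sr, "raw": y})["text"])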
@spaces.GPU
def stream_transcribe(stream, new_chunk):
    start_time = time.time()
    try:
        sr, y = new_chunk
        # Convert to mono if stereo
        if y.ndim > 1:
            y = y.mean(axis=1)
        # Normalize to [-1, 1]; skip silent chunks to avoid division by zero
        y = y.astype(np.float32)
        peak = np.max(np.abs(y))
        if peak > 0:
            y /= peak
        # Accumulate audio across chunks and re-transcribe the full stream
        if stream is not None:
            stream = np.concatenate([stream, y])
        else:
            stream = y
        transcription = pipe({"sampling_rate": sr, "raw": stream})["text"]
        end_time = time.time()
        latency = end_time - start_time
        return stream, transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during transcription: {e}")
        return stream, str(e), "Error"

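# File-based path: Gradio hands uploads over as (sample_rate, ndarray); we
# round-trip through a WAV file so the pipeline's own loader handles
# decoding and resampling to the 16 kHz the model expects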
@spaces.GPU
def transcribe(inputs, previous_transcription):
    start_time = time.time()
    try:
        # Write the upload to a uniquely named file in the system temp dir
        filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.wav")
        sample_rate, audio_data = inputs
        scipy.io.wavfile.write(filename, sample_rate, audio_data)
        transcription = pipe(filename)["text"]
        previous_transcription += transcription
        # Clean up the temporary file
        os.remove(filename)
        end_time = time.time()
        latency = end_time - start_time
        return previous_transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during transcription: {e}")
        return previous_transcription, "Error"

@spaces.GPU
def translate_and_transcribe(inputs, previous_transcription, source_language):
    start_time = time.time()
    try:
        filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.wav")
        sample_rate, audio_data = inputs
        scipy.io.wavfile.write(filename, sample_rate, audio_data)
        # Whisper's "translate" task always targets English; the "language"
        # kwarg hints at the *source* language of the audio
        translation = pipe(filename, generate_kwargs={"task": "translate", "language": source_language})["text"]
        previous_transcription += translation
        # Clean up the temporary file
        os.remove(filename)
        end_time = time.time()
        latency = end_time - start_time
        return previous_transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during translation and transcription: {e}")
        return previous_transcription, "Error"

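# NOTE: translate_and_transcribe is not wired into the tabs below; it is a
# ready-made callback in case a translation tab is added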
def clear():
    return ""


def clear_state():
    return None

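# --- UI ---
# gr.State keeps the accumulated waveform per user session so the streaming
# callback can re-transcribe the whole utterance on each new chunk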
with gr.Blocks() as microphone:
    with gr.Column():
        gr.Markdown(
            f"# Realtime Whisper Large V3 Turbo\n"
            f"Transcribe audio in real time. This demo uses the checkpoint "
            f"[{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n"
            "Note: the first chunk takes about 5 seconds while the model warms up; "
            "later chunks stream with much lower latency."
        )
        with gr.Row():
            input_audio_microphone = gr.Audio(sources=["microphone"], streaming=True)
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            clear_button = gr.Button("Clear Output")
        state = gr.State()
        input_audio_microphone.stream(
            stream_transcribe,
            [state, input_audio_microphone],
            [state, output, latency_textbox],
            time_limit=30,
            stream_every=2,
            concurrency_limit=None,
        )
        clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
with gr.Blocks() as file:
    with gr.Column():
        gr.Markdown(
            f"# Whisper Large V3 Turbo: File Transcription\n"
            f"Transcribe an uploaded audio file. This demo uses the checkpoint "
            f"[{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers."
        )
        with gr.Row():
            input_audio_upload = gr.Audio(sources=["upload"], type="numpy")
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Output")
        submit_button.click(
            transcribe,
            [input_audio_upload, output],
            [output, latency_textbox],
            concurrency_limit=None,
        )
        clear_button.click(clear, outputs=[output])
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])

demo.launch()