import os
import torch
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import soundfile as sf
import numpy as np
import tempfile
# Initialize FastAPI app
app = FastAPI()
# Initialize model and processor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "nyrahealth/CrisperWhisper"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps='word',
    torch_dtype=torch_dtype,
    device=device,
)
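# Note: chunk_length_s=30 makes the pipeline split long recordings into 30-second windows,
# and return_timestamps='word' requests per-word start/end times, which the pause-adjustment
# helper below relies on.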
def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly between adjacent words.
    """
    adjusted_chunks = pipeline_output["chunks"].copy()
    for i in range(len(adjusted_chunks) - 1):
        current_chunk = adjusted_chunks[i]
        next_chunk = adjusted_chunks[i + 1]
        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]
        pause_duration = next_start - current_end
        if pause_duration > 0:
            # Give each word at most half the threshold; shorter pauses are split in half.
            if pause_duration > split_threshold:
                distribute = split_threshold / 2
            else:
                distribute = pause_duration / 2
            # Extend the current word's end and pull the next word's start toward the pause.
            adjusted_chunks[i]["timestamp"] = (current_start, current_end + distribute)
            adjusted_chunks[i + 1]["timestamp"] = (next_start - distribute, next_end)
    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output
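# Illustrative example (hypothetical numbers): with the default split_threshold of 0.12,
# adjacent words ("good", (0.50, 1.00)) and ("morning", (1.30, 1.80)) have a 0.30 s pause,
# which exceeds the threshold, so each word absorbs split_threshold / 2 = 0.06 s and the
# chunks become ("good", (0.50, 1.06)) and ("morning", (1.24, 1.80)). Pauses at or below
# the threshold are simply split in half between the two words.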
def process_audio(audio_path):
    """Process audio file and return transcription with timestamps"""
    try:
        # Read audio file
        audio_data, sample_rate = sf.read(audio_path)
        # Convert to mono if stereo
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        # Process with pipeline
        result = pipe({"array": audio_data, "sampling_rate": sample_rate})
        # Adjust pauses
        adjusted_result = adjust_pauses_for_hf_pipeline_output(result)
        return adjusted_result
    except Exception as e:
        return {"error": str(e)}
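# Expected shape of the (adjusted) result, sketched here as a comment; the exact words and
# timestamps are of course audio-dependent:
# {
#     "text": "full transcription ...",
#     "chunks": [
#         {"text": "full", "timestamp": (0.0, 0.32)},
#         {"text": "transcription", "timestamp": (0.32, 0.9)},
#         ...
#     ],
# }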
# FastAPI endpoint
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Save uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
            content = await file.read()
            temp_file.write(content)
            temp_file_path = temp_file.name
        # Process the audio
        result = process_audio(temp_file_path)
        # Clean up temporary file
        os.unlink(temp_file_path)
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
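# Example client call (a sketch, assuming the server is running locally on port 7860 and
# "sample.wav" is a hypothetical audio file):
#
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.wav"
#
# or with the requests library:
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       response = requests.post("http://localhost:7860/transcribe", files={"file": f})
#   print(response.json())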
# Gradio interface
def gradio_transcribe(audio):
    if audio is None:
        return "Please upload an audio file"
    result = process_audio(audio)
    return result
# Create Gradio interface
demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Audio (MP3 or WAV)"),
    outputs=gr.JSON(label="Transcription with Timestamps"),
    title="CrisperWhisper Audio Transcription",
    description="Upload an audio file to get transcription with word-level timestamps",
    examples=[],
    allow_flagging="never"
)
# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)