"""FastAPI + Gradio app serving CrisperWhisper speech-to-text with
word-level timestamps."""

import os
import tempfile

import gradio as gr
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

app = FastAPI()

# Use the GPU with fp16 when available; otherwise fall back to CPU with fp32.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "nyrahealth/CrisperWhisper"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps="word",
    torch_dtype=torch_dtype,
    device=device,
)
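
# For reference, the pipeline output with return_timestamps="word" follows the
# standard Hugging Face ASR pipeline shape -- a sketch, not real output:
#
#   {
#       "text": "hello world",
#       "chunks": [
#           {"text": "hello", "timestamp": (0.0, 0.42)},
#           {"text": "world", "timestamp": (0.58, 1.01)},
#       ],
#   }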


def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly
    between adjacent words.
    """
    adjusted_chunks = pipeline_output["chunks"].copy()

    for i in range(len(adjusted_chunks) - 1):
        current_chunk = adjusted_chunks[i]
        next_chunk = adjusted_chunks[i + 1]

        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]

        # The pipeline can emit None for a word cut off at the end of the
        # audio; skip such pairs rather than doing arithmetic on None.
        if current_end is None or next_start is None:
            continue

        pause_duration = next_start - current_end

        if pause_duration > 0:
            # Redistribute at most `split_threshold` seconds of the pause;
            # anything beyond that stays as a gap between the two words.
            if pause_duration > split_threshold:
                distribute = split_threshold / 2
            else:
                distribute = pause_duration / 2

            adjusted_chunks[i]["timestamp"] = (current_start, current_end + distribute)
            adjusted_chunks[i + 1]["timestamp"] = (next_start - distribute, next_end)

    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output
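
# Worked example: with the default 0.12 s threshold, a 0.30 s pause gives each
# neighbouring word 0.06 s (half the threshold) and leaves a 0.18 s gap, while
# a 0.10 s pause is split entirely, 0.05 s to each side, leaving no gap.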


def process_audio(audio_path):
    """Process an audio file and return the transcription with timestamps."""
    try:
        # Read as float32; soundfile defaults to float64.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")

        # Downmix multi-channel audio to mono.
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)

        # The pipeline resamples to the model's expected 16 kHz when the input
        # rate differs (this path requires torchaudio).
        result = pipe({"array": audio_data, "sampling_rate": sample_rate})

        adjusted_result = adjust_pauses_for_hf_pipeline_output(result)

        return adjusted_result
    except Exception as e:
        return {"error": str(e)}


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Persist the upload to a temp file so soundfile can read it by path.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=os.path.splitext(file.filename)[1]
        ) as temp_file:
            content = await file.read()
            temp_file.write(content)
            temp_file_path = temp_file.name

        try:
            result = process_audio(temp_file_path)
        finally:
            # Clean up the temp file even if transcription fails.
            os.unlink(temp_file_path)

        # process_audio reports its own failures as {"error": ...}; surface
        # them as a 500 instead of a misleading 200.
        if "error" in result:
            return JSONResponse(status_code=500, content=result)

        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
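
# A quick smoke test against the endpoint (assumes a local server on port 7860
# and a file named sample.wav; the multipart field name must be "file"):
#
#   curl -X POST http://localhost:7860/transcribe -F "file=@sample.wav"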


def gradio_transcribe(audio):
    if audio is None:
        return "Please upload an audio file"

    result = process_audio(audio)
    return result


demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Audio (MP3 or WAV)"),
    outputs=gr.JSON(label="Transcription with Timestamps"),
    title="CrisperWhisper Audio Transcription",
    description="Upload an audio file to get a transcription with word-level timestamps",
    allow_flagging="never",
)

# Serve the Gradio UI at the root path, alongside the JSON API above.
app = gr.mount_gradio_app(app, demo, path="/")
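
# The server can also be started with the uvicorn CLI (assuming this file is
# saved as app.py):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860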


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)