import os
import torch
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import soundfile as sf
import numpy as np
import tempfile
# Initialize FastAPI app
app = FastAPI()
# Initialize model and processor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "nyrahealth/CrisperWhisper"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps='word',
    torch_dtype=torch_dtype,
    device=device,
)
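# Note: return_timestamps='word' is what produces per-word entries in the pipeline's
# "chunks" output, and chunk_length_s=30 enables chunked long-form decoding.
# batch_size=16 is a rough default; it may need lowering on GPUs with limited memory.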
def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly between adjacent words.
    """
    adjusted_chunks = pipeline_output["chunks"].copy()
    for i in range(len(adjusted_chunks) - 1):
        current_chunk = adjusted_chunks[i]
        next_chunk = adjusted_chunks[i + 1]
        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]
        pause_duration = next_start - current_end
        if pause_duration > 0:
            if pause_duration > split_threshold:
                distribute = split_threshold / 2
            else:
                distribute = pause_duration / 2
            adjusted_chunks[i]["timestamp"] = (current_start, current_end + distribute)
            adjusted_chunks[i + 1]["timestamp"] = (next_start - distribute, next_end)
    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output
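# A quick worked example of the adjustment above (illustrative values, not from a real run):
# with split_threshold=0.12, a 0.30 s pause between a word ending at 1.00 s and one starting
# at 1.30 s is trimmed symmetrically to (..., 1.06) and (1.24, ...), while a pause shorter
# than the threshold (e.g. 0.08 s) is split entirely between the two neighbouring words.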
def process_audio(audio_path):
    """Process audio file and return transcription with timestamps"""
    try:
        # Read audio file
        audio_data, sample_rate = sf.read(audio_path)
        # Convert to mono if stereo
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        # Process with pipeline
        result = pipe({"array": audio_data, "sampling_rate": sample_rate})
        # Adjust pauses
        adjusted_result = adjust_pauses_for_hf_pipeline_output(result)
        return adjusted_result
    except Exception as e:
        return {"error": str(e)}
# FastAPI endpoint
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Save uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
            content = await file.read()
            temp_file.write(content)
            temp_file_path = temp_file.name
        # Process the audio
        result = process_audio(temp_file_path)
        # Clean up temporary file
        os.unlink(temp_file_path)
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
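# On success the endpoint returns the pause-adjusted Hugging Face pipeline output, roughly:
#   {"text": "...", "chunks": [{"text": " word", "timestamp": [start, end]}, ...]}
# (exact keys depend on the installed transformers version); on failure, {"error": "..."}.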
# Gradio interface
def gradio_transcribe(audio):
    if audio is None:
        return "Please upload an audio file"
    result = process_audio(audio)
    return result
# Create Gradio interface
demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Audio (MP3 or WAV)"),
    outputs=gr.JSON(label="Transcription with Timestamps"),
    title="CrisperWhisper Audio Transcription",
    description="Upload an audio file to get a transcription with word-level timestamps",
    examples=[],
    allow_flagging="never"
)
# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
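# Example client call (a sketch; assumes the server is reachable on localhost:7860 and
# "sample.wav" is a placeholder for a real audio file):
#   curl -X POST -F "file=@sample.wav" http://localhost:7860/transcribe
# The Gradio UI is served at http://localhost:7860/ once the app is running.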