import os
import tempfile

import gradio as gr
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Initialize FastAPI app
app = FastAPI()

# Initialize model and processor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "nyrahealth/CrisperWhisper"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps="word",
    torch_dtype=torch_dtype,
    device=device,
)


def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly
    between adjacent words.

    A pause shorter than `split_threshold` is split in half and absorbed by
    the neighboring words; a longer pause only donates `split_threshold / 2`
    to each neighbor, so the remainder is kept as silence.
    """
    adjusted_chunks = pipeline_output["chunks"].copy()

    for i in range(len(adjusted_chunks) - 1):
        current_chunk = adjusted_chunks[i]
        next_chunk = adjusted_chunks[i + 1]

        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]
        pause_duration = next_start - current_end

        if pause_duration > 0:
            if pause_duration > split_threshold:
                distribute = split_threshold / 2
            else:
                distribute = pause_duration / 2

            adjusted_chunks[i]["timestamp"] = (current_start, current_end + distribute)
            adjusted_chunks[i + 1]["timestamp"] = (next_start - distribute, next_end)

    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output


def process_audio(audio_path):
    """Process an audio file and return the transcription with word-level timestamps."""
    try:
        # Read audio file
        audio_data, sample_rate = sf.read(audio_path)

        # Convert to mono if stereo
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Process with the pipeline (it resamples to the model's expected
        # 16 kHz rate if needed, which requires torchaudio to be installed)
        result = pipe({"array": audio_data, "sampling_rate": sample_rate})

        # Redistribute pauses between adjacent word timestamps
        return adjust_pauses_for_hf_pipeline_output(result)
    except Exception as e:
        return {"error": str(e)}


# FastAPI endpoint
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Save the uploaded file temporarily
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=os.path.splitext(file.filename)[1]
        ) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name

        try:
            result = process_audio(temp_file_path)
        finally:
            # Clean up the temporary file even if processing fails
            os.unlink(temp_file_path)

        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


# Gradio interface
def gradio_transcribe(audio):
    if audio is None:
        # Return a dict so the error renders consistently in the JSON output
        return {"error": "Please upload an audio file"}
    return process_audio(audio)


# Create Gradio interface
demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Audio (MP3 or WAV)"),
    outputs=gr.JSON(label="Transcription with Timestamps"),
    title="CrisperWhisper Audio Transcription",
    description="Upload an audio file to get a transcription with word-level timestamps",
    examples=[],
    allow_flagging="never",
)

# Mount the Gradio UI on the FastAPI app at the root path;
# the /transcribe API endpoint remains available alongside it
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
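
# Example client usage (a sketch; assumes the server is running locally on
# port 7860 and that "sample.wav" is a hypothetical local audio file):
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/transcribe",
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json())
#
# Or from the shell:
#
#   curl -F "file=@sample.wav" http://localhost:7860/transcribe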