# Spaces: Build error  (status text captured from the hosting page; not part of the program)
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The ASGI application; title/description/version appear in the generated OpenAPI docs.
app = FastAPI(
    title="Speech-to-Text API",
    description="API for speech-to-text transcription using CrisperWhisper model",
    version="1.0.0",
)
# Add CORS middleware so browser clients on other origins can call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let any site
# make credentialed requests. Nothing in this file uses cookies or auth, so consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize model and processor once, when the server starts.
# NOTE(review): recent FastAPI versions deprecate on_event in favour of lifespan
# handlers; on_event still works.
@app.on_event("startup")
async def load_model():
    """Load the CrisperWhisper processor and model into module-level globals.

    Sets ``processor``, ``model`` and ``device``, which the transcription endpoint
    reads. Uses CUDA when ``torch.cuda.is_available()``; otherwise uses CPU.
    """
    global processor, model, device
    logger.info("Loading CrisperWhisper model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained("nyrahealth/CrisperWhisper")
    model = AutoModelForSpeechSeq2Seq.from_pretrained("nyrahealth/CrisperWhisper").to(device)
    # Switch to inference mode (e.g. dropout disabled).
    model.eval()
    logger.info(f"Model loaded successfully on {device}")
# Scratch directory for uploaded files, created once at import time.
TEMP_DIR = Path(tempfile.mkdtemp())
# Lower-case file extensions (without the dot) accepted for upload.
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'flac', 'ogg', 'm4a', 'mp4'}


def is_valid_audio_file(filename: str) -> bool:
    """Return True if *filename*'s last extension is in ALLOWED_EXTENSIONS.

    The comparison is case-insensitive. A missing name (``None`` or ``""``) or a
    name without a dot is rejected instead of raising.
    """
    if not filename:
        return False
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
# Whisper's feature extractor expects 16 kHz input.
_TARGET_SAMPLE_RATE = 16000


def _load_mono_16k(path):
    """Load *path* with torchaudio, down-mix to mono, resample to 16 kHz; return a 1-D tensor."""
    waveform, sample_rate = torchaudio.load(path)
    # torchaudio returns (channels, frames); average the channels to get mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != _TARGET_SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sample_rate, _TARGET_SAMPLE_RATE)(waveform)
    # Index the channel axis rather than squeeze(), so a 1-frame clip stays 1-D.
    return waveform[0]


def _format_chunks(offsets):
    """Convert decoder offsets into the API's chunk list, dropping whitespace-only segments."""
    chunks = []
    for segment in offsets:
        text = segment["text"].strip()
        if text:
            chunks.append({
                "timestamp": [segment["timestamp"][0], segment["timestamp"][1]],
                "text": text,
            })
    return chunks


# NOTE(review): route path reconstructed; confirm it matches what clients call.
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file and return timestamped text segments.

    - **file**: Audio file to transcribe (MP3, WAV, FLAC, OGG, M4A, MP4)

    Returns JSON ``{"text": ..., "chunks": [{"timestamp": [start, end], "text": ...}]}``
    with timestamps in seconds. Segments are delimited by Whisper's timestamp tokens.
    Responds 400 for a missing or unsupported file. Responds 500, with the error
    message as detail, if decoding or inference fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")
    if not is_valid_audio_file(file.filename):
        # Sorted so the message is stable across runs (set order is not).
        raise HTTPException(
            status_code=400,
            detail=f"File type not allowed. Supported formats: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
        )

    # Sanitised copy of the client-supplied name; used only for logging.
    safe_filename = ''.join(c if c.isalnum() or c in '._- ' else '_' for c in file.filename)
    # Already validated against ALLOWED_EXTENSIONS, so it is a short alphanumeric string.
    extension = file.filename.rsplit('.', 1)[1].lower()
    file_path = None
    try:
        # Each request gets its own temp file, so concurrent uploads that share a
        # filename cannot overwrite each other. The original extension is kept in
        # case the audio backend relies on it.
        fd, tmp_name = tempfile.mkstemp(suffix=f".{extension}", dir=TEMP_DIR)
        file_path = Path(tmp_name)
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await file.read())
        logger.info(f"Processing file: {safe_filename}")

        waveform = _load_mono_16k(file_path)
        input_features = processor(
            waveform.numpy(), sampling_rate=_TARGET_SAMPLE_RATE, return_tensors="pt"
        ).to(device)

        # NOTE(review): generate() runs synchronously inside this async handler, so it
        # blocks the event loop (other requests included) for the duration of inference.
        with torch.no_grad():
            generated_tokens = model.generate(
                **input_features,
                return_timestamps=True,
                task="transcribe",
            )

        # WhisperProcessor.decode forwards to the tokenizer. output_offsets=True returns
        # {"text": ..., "offsets": [{"text": ..., "timestamp": (start, end)}, ...]}.
        result = processor.decode(
            generated_tokens[0].cpu(), skip_special_tokens=True, output_offsets=True
        )
        return {"text": result["text"], "chunks": _format_chunks(result["offsets"])}
    except Exception as e:
        logger.exception(f"Error during transcription: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the upload on success and on failure alike.
        if file_path is not None:
            file_path.unlink(missing_ok=True)
# NOTE(review): route path reconstructed; confirm it matches the health-probe configuration.
@app.get("/health")
async def health_check():
    """Health check endpoint for Cloud Run.

    Always returns ``{"status": "healthy"}`` while the server can answer requests.
    It does not check whether the model has loaded.
    """
    return {"status": "healthy"}
if __name__ == "__main__":
    # Listening port comes from $PORT (as set by Cloud Run); defaults to 8080 locally.
    port = int(os.environ.get("PORT", 8080))
    # Pass the app object rather than the "app:app" import string. This avoids
    # importing this file a second time as another module, and it works regardless
    # of the file's name.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)