# filler-trans / app.py
import logging
import os
import tempfile
from pathlib import Path

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Speech-to-Text API",
    description="API for speech-to-text transcription using the CrisperWhisper model",
    version="1.0.0",
)

# Allow cross-origin requests; allow_origins=["*"] is permissive and should be
# narrowed for production deployments
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model, processor, and device are initialized once at startup and shared
# across requests
processor = None
model = None
device = None


@app.on_event("startup")
async def load_model():
    """Load the CrisperWhisper model and processor once at startup."""
    global processor, model, device
    logger.info("Loading CrisperWhisper model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained("nyrahealth/CrisperWhisper")
    model = AutoModelForSpeechSeq2Seq.from_pretrained("nyrahealth/CrisperWhisper").to(device)
    model.eval()
    logger.info(f"Model loaded successfully on {device}")

# Temporary directory for storing uploaded files
TEMP_DIR = Path(tempfile.mkdtemp())
ALLOWED_EXTENSIONS = {"mp3", "wav", "flac", "ogg", "m4a", "mp4"}


def is_valid_audio_file(filename: str) -> bool:
    """Return True if the filename has an allowed audio extension (e.g. "clip.wav")."""
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
"""
Transcribe an audio file and return word-level timestamps.
- **file**: Audio file to transcribe (MP3, WAV, FLAC, OGG, M4A, MP4)
Returns a JSON with transcription and timestamps.
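
    Example response (illustrative values):

        {
          "text": "so um hello there",
          "chunks": [
            {"timestamp": [0.0, 1.2], "text": "so um"},
            {"timestamp": [1.2, 2.0], "text": "hello there"}
          ]
        }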
"""
    # Reject requests with no file selected
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")

    # Reject unsupported file types
    if not is_valid_audio_file(file.filename):
        raise HTTPException(
            status_code=400,
            detail=f"File type not allowed. Supported formats: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
        )

    file_path = None
    try:
        # Sanitize the filename before writing into the shared temp directory
        safe_filename = "".join(c if c.isalnum() or c in "._- " else "_" for c in file.filename)
        file_path = TEMP_DIR / safe_filename

        # Save the uploaded file to disk
        with open(file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)

        logger.info(f"Processing file: {safe_filename}")
        # Load the audio; Whisper expects mono audio sampled at 16 kHz
        waveform, sample_rate = torchaudio.load(file_path)

        # Downmix stereo to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        # Extract log-mel input features for the model
        inputs = processor(
            waveform.squeeze().numpy(),
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).to(device)

        # Generate token ids, asking Whisper to emit timestamp tokens
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                return_timestamps=True,
                task="transcribe",
            )

        # Decode the tokens; output_offsets=True makes the Whisper tokenizer
        # return per-segment (start, end) timestamps alongside the text
        result = processor.decode(
            generated_tokens[0], skip_special_tokens=True, output_offsets=True
        )
        # Build the response from the decoded text and the segment offsets
        full_text = result["text"].strip()

        chunks = []
        for chunk in result["offsets"]:
            # Keep only non-empty segments
            if chunk["text"].strip():
                chunks.append({
                    "timestamp": [chunk["timestamp"][0], chunk["timestamp"][1]],
                    "text": chunk["text"].strip(),
                })

        return {
            "text": full_text,
            "chunks": chunks,
        }
    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Remove the uploaded file whether or not transcription succeeded
        if file_path is not None and file_path.exists():
            os.remove(file_path)
@app.get("/health")
async def health_check():
"""Health check endpoint for Cloud Run"""
return {"status": "healthy"}

if __name__ == "__main__":
    # Cloud Run injects the PORT environment variable; default to 8080 locally
    port = int(os.environ.get("PORT", 8080))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
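
# Example transcription request (a minimal sketch; localhost:8080 and sample.wav
# are placeholders for your own deployment and audio file):
#   curl -X POST "http://localhost:8080/transcribe" \
#        -F "file=@sample.wav"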