import os
import tempfile
from pathlib import Path

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
import logging
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Speech-to-Text API",
    description="API for speech-to-text transcription using CrisperWhisper model",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
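# Note: wildcard origins with credentials enabled is permissive; fine for
# development, but restrict allow_origins before exposing this service publicly.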

# Initialize model, processor, and ASR pipeline
@app.on_event("startup")
async def load_model():
    logger.info("Loading CrisperWhisper model...")
    global processor, model, asr_pipeline, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained("nyrahealth/CrisperWhisper")
    model = AutoModelForSpeechSeq2Seq.from_pretrained("nyrahealth/CrisperWhisper").to(device)
    model.eval()
    # Wrap the model in an ASR pipeline; return_timestamps="word" yields the
    # {"text": ..., "chunks": [...]} structure the /transcribe handler expects
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        return_timestamps="word",
        device=device,
    )
    logger.info(f"Model loaded successfully on {device}")

# Create a temporary directory to store files
TEMP_DIR = Path(tempfile.mkdtemp())
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'flac', 'ogg', 'm4a', 'mp4'}

def is_valid_audio_file(filename: str) -> bool:
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
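
# Note: this checks only the file extension, not the contents. Whether a given
# format actually decodes depends on the installed torchaudio backend (m4a/mp4
# typically require ffmpeg); undecodable files surface as a 500 below.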

@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """

    Transcribe an audio file and return word-level timestamps.

    

    - **file**: Audio file to transcribe (MP3, WAV, FLAC, OGG, M4A, MP4)

    

    Returns a JSON with transcription and timestamps.

    """
    # Check if file is selected
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")
        
    # Check if file type is allowed
    if not is_valid_audio_file(file.filename):
        raise HTTPException(status_code=400, 
                           detail=f"File type not allowed. Supported formats: {', '.join(ALLOWED_EXTENSIONS)}")
    
    # Create a safe filename and resolve the temp path before entering the
    # try block so the cleanup in the finally clause can always reference it
    safe_filename = ''.join(c if c.isalnum() or c in '._- ' else '_' for c in file.filename)
    file_path = TEMP_DIR / safe_filename

    try:
        
        # Save the uploaded file
        with open(file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        
        logger.info(f"Processing file: {safe_filename}")
        
        # Load audio file
        waveform, sample_rate = torchaudio.load(file_path)
        
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000
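
        # Whisper-family feature extractors expect 16 kHz mono input, which is
        # why the downmix and resampling steps above are applied first.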
        
        # Run the ASR pipeline on the prepared audio; it returns a dict of the
        # form {"text": ..., "chunks": [{"text": ..., "timestamp": (start, end)}, ...]}
        result = asr_pipeline({"raw": waveform.squeeze().numpy(), "sampling_rate": sample_rate})
        
        # Format the output
        full_text = result['text']
        
        # Process chunks with timestamps
        chunks = []
        for chunk in result['chunks']:
            # Only include non-empty chunks
            if chunk['text'].strip():
                chunks.append({
                    "timestamp": [chunk['timestamp'][0], chunk['timestamp'][1]],
                    "text": chunk['text'].strip()
                })
        
        # Create output JSON
        output = {
            "text": full_text,
            "chunks": chunks
        }
        
        # Return the JSON payload directly
        return output

    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the uploaded file whether or not transcription succeeded
        if file_path.exists():
            os.remove(file_path)

@app.get("/health")
async def health_check():
    """Health check endpoint for Cloud Run"""
    return {"status": "healthy"}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8080))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)