cheesecz committed · verified
Commit 8bc640e · 1 Parent(s): 88ffb5b

Create app.py

Files changed (1)
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
import os
import torch
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import soundfile as sf
import numpy as np
import tempfile

# Initialize FastAPI app
app = FastAPI()

# Initialize model and processor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "nyrahealth/CrisperWhisper"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps="word",
    torch_dtype=torch_dtype,
    device=device,
)

def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly between adjacent words.
    """
    adjusted_chunks = pipeline_output["chunks"].copy()

    for i in range(len(adjusted_chunks) - 1):
        current_chunk = adjusted_chunks[i]
        next_chunk = adjusted_chunks[i + 1]

        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]
        pause_duration = next_start - current_end

        if pause_duration > 0:
            # Split at most `split_threshold` seconds of the pause evenly
            # between the word that ends it and the word that starts it
            if pause_duration > split_threshold:
                distribute = split_threshold / 2
            else:
                distribute = pause_duration / 2

            adjusted_chunks[i]["timestamp"] = (current_start, current_end + distribute)
            adjusted_chunks[i + 1]["timestamp"] = (next_start - distribute, next_end)

    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output

def process_audio(audio_path):
    """Process audio file and return transcription with word-level timestamps"""
    try:
        # Read audio file
        audio_data, sample_rate = sf.read(audio_path)

        # Convert to mono if stereo
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)

        # Process with pipeline
        result = pipe({"array": audio_data, "sampling_rate": sample_rate})

        # Adjust pauses
        adjusted_result = adjust_pauses_for_hf_pipeline_output(result)

        return adjusted_result
    except Exception as e:
        return {"error": str(e)}

# FastAPI endpoint
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Save uploaded file temporarily, preserving the original extension
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
            content = await file.read()
            temp_file.write(content)
            temp_file_path = temp_file.name

        try:
            # Process the audio
            result = process_audio(temp_file_path)
        finally:
            # Clean up the temporary file even if processing raises
            os.unlink(temp_file_path)

        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

# Gradio interface
def gradio_transcribe(audio):
    if audio is None:
        return "Please upload an audio file"

    result = process_audio(audio)
    return result

# Create Gradio interface
demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Audio (MP3 or WAV)"),
    outputs=gr.JSON(label="Transcription with Timestamps"),
    title="CrisperWhisper Audio Transcription",
    description="Upload an audio file to get transcription with word-level timestamps",
    examples=[],
    allow_flagging="never"
)

# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
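
For reference, a minimal client sketch for the /transcribe endpoint (not part of the commit): it posts an audio file as multipart form data and prints the JSON result. The host and port match the uvicorn.run call above; the requests library and the sample.wav path are placeholder assumptions.

import requests  # third-party HTTP client; assumed available

# "sample.wav" is a placeholder path; host/port match uvicorn.run(...) in app.py
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/transcribe",
        files={"file": ("sample.wav", f, "audio/wav")},
    )

# On success the body is the pause-adjusted pipeline output, e.g.
# {"text": "...", "chunks": [{"text": "...", "timestamp": [0.0, 0.42]}, ...]}
print(resp.json())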