# app.py -- uploaded by baoyin2024
# Commit f5dc719 (verified): "Update app.py"
# File size: 5.76 kB
import gc
import hmac
import io
import os
import tempfile
from datetime import datetime
from threading import Semaphore

import ffmpeg
import torch
import torchaudio
import whisperx
from flask import Flask, request, jsonify
# Flask application object that the route decorators in this module attach to.
app = Flask(__name__)

# Read API_KEY from the environment; at import time a missing key is only
# reported on stdout (requests are checked separately per call).
api_key = os.getenv("API_KEY")
if api_key is None or api_key == "":
    print("Error: API_KEY environment variable not set!")

# Semaphore that caps the number of requests processed concurrently.
MAX_CONCURRENT_REQUESTS = 2
request_semaphore = Semaphore(value=MAX_CONCURRENT_REQUESTS)

# Device and compute precision handed to whisperx.load_model() in get_model().
device, compute_type = "cuda", "float16"
def validate_api_key(request):
    """Check the API key supplied with a request against the API_KEY env var.

    The key may be supplied in the ``X-API-Key`` header, the ``api_key`` query
    parameter, or the ``api_key`` form field; any one matching is sufficient.

    Args:
        request: Flask request object (anything exposing ``headers``, ``args``
            and ``form`` mappings with a ``get`` method).

    Returns:
        A ``(is_valid, message)`` tuple. ``message`` is ``None`` when the key
        is valid, otherwise a short reason string.
    """
    expected = os.environ.get("API_KEY")
    if not expected:
        return False, "API_KEY environment variable not set"

    # Compare as bytes: hmac.compare_digest rejects non-ASCII str operands,
    # and a client is free to send arbitrary text.
    expected_bytes = expected.encode("utf-8")
    candidates = (
        request.headers.get("X-API-Key"),
        request.args.get("api_key"),
        request.form.get("api_key"),
    )
    for candidate in candidates:
        if candidate is None:
            continue
        # Constant-time comparison so response timing does not reveal how
        # much of a guessed key matched.
        if hmac.compare_digest(candidate.encode("utf-8"), expected_bytes):
            return True, None
    return False, "Invalid API Key"
# Extensions accepted by /whisper_transcribe. Files with a video extension are
# first run through ffmpeg to extract a WAV audio track.
_VIDEO_EXTENSIONS = frozenset({'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp'})
_AUDIO_EXTENSIONS = frozenset({'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff'})
_ALLOWED_EXTENSIONS = _AUDIO_EXTENSIONS | _VIDEO_EXTENSIONS

# Sample rate the audio is resampled to before transcription.
_TARGET_SAMPLE_RATE = 16000
# Longest audio accepted, in seconds (10 minutes).
_MAX_DURATION_SECONDS = 10 * 60


def _file_extension(filename):
    """Return the lower-cased extension of *filename* (no dot), or '' if none."""
    _, dot, extension = filename.rpartition('.')
    return extension.lower() if dot else ''


def _new_temp_path(suffix):
    """Create an empty temporary file, close it, and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        return handle.name


def _remove_quietly(path):
    """Delete *path*; log (rather than raise) if the removal fails."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # already gone, nothing left to clean up
    except OSError as exc:
        app.logger.warning("Could not remove temporary file %s: %s", path, exc)


def _load_mono_16k(path):
    """Load *path* with torchaudio as a 1-D, mono, 16 kHz NumPy array."""
    waveform, sample_rate = torchaudio.load(path)
    # Down-mix multi-channel audio by averaging the channels.
    waveform = waveform.mean(dim=0) if waveform.shape[0] > 1 else waveform[0]
    if sample_rate != _TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, _TARGET_SAMPLE_RATE)
    # NOTE(review): the whisperx README passes a NumPy array (from
    # whisperx.load_audio) to transcribe(), so hand over a CPU array rather
    # than a CUDA tensor -- confirm against the installed whisperx version.
    return waveform.float().contiguous().numpy()


def _segments_to_text(result):
    """Concatenate the text of every segment in a transcription result.

    Accepts both result shapes: a dict with a "segments" list of dicts (the
    shape shown in the whisperx README) and a ``(segments, info)`` tuple whose
    segments expose a ``.text`` attribute (the shape the original code assumed).
    """
    segments = result["segments"] if isinstance(result, dict) else result[0]
    return "".join(
        segment["text"] if isinstance(segment, dict) else segment.text
        for segment in segments
    )


def _transcribe(audio):
    """Load the model via get_model(), transcribe *audio*, return the text."""
    # Keeping the model local to this helper drops the reference on return,
    # so the caller's gc.collect()/empty_cache() can actually reclaim it.
    wmodel, model_options = get_model()
    result = wmodel.transcribe(audio, batch_size=model_options.get("batch_size"))
    return _segments_to_text(result)


@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio or video file.

    Expects a multipart upload under the ``file`` field and a valid API key
    (see validate_api_key).

    Returns:
        A ``(json_response, status)`` pair:
        200 with ``transcription`` and ``file_type`` on success;
        401 for a bad or missing API key;
        400 for a missing or unsupported file, or audio longer than 10 minutes;
        500 when extraction, loading or transcription fails.
    """
    is_valid, message = validate_api_key(request)  # verify the API key
    if not is_valid:
        return jsonify({"error": message}), 401

    # Validate the upload before taking a concurrency slot so that malformed
    # requests never wait behind (or hold up) real transcription work.
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400
    file_extension = _file_extension(file.filename)
    if file_extension not in _ALLOWED_EXTENSIONS:
        # Sorted so the message is stable across runs (set order is not).
        supported = ", ".join(sorted(_ALLOWED_EXTENSIONS))
        return jsonify({'error': f'Invalid file format. Supported: {supported}'}), 400

    temp_paths = []  # every temp file created below; removed in `finally`
    with request_semaphore:
        try:
            # Save the uploaded file to a temporary file.
            upload_path = _new_temp_path(f'.{file_extension}')
            temp_paths.append(upload_path)
            file.save(upload_path)

            if file_extension in _VIDEO_EXTENSIONS:
                file_type = "video"
                audio_path = _new_temp_path(".wav")
                temp_paths.append(audio_path)
                try:
                    # Extract the audio track from the video using ffmpeg.
                    ffmpeg.input(upload_path).output(
                        audio_path, format='wav', acodec='pcm_s16le'
                    ).run(quiet=True, overwrite_output=True)
                except Exception as e:
                    return jsonify({'error': f'Failed to extract audio from video: {str(e)}'}), 500
            else:
                file_type = "audio"
                audio_path = upload_path

            try:
                audio = _load_mono_16k(audio_path)
            except Exception as e:
                return jsonify({'error': f'Failed to load audio file: {str(e)}'}), 500

            # Ensure the audio duration does not exceed 10 minutes.
            if audio.shape[-1] / _TARGET_SAMPLE_RATE > _MAX_DURATION_SECONDS:
                return jsonify({'error': 'Audio duration exceeds the maximum allowed duration of 10 minutes'}), 400

            try:
                transcription = _transcribe(audio)
            except Exception as e:
                return jsonify({'error': f'Transcription failed: {str(e)}'}), 500

            return jsonify({'transcription': transcription, 'file_type': file_type}), 200
        except Exception as e:
            return jsonify({'error': str(e)}), 500
        finally:
            # Runs on every exit path, including the early error returns,
            # so temporary files are never left behind.
            for path in temp_paths:
                _remove_quietly(path)
            gc.collect()
            torch.cuda.empty_cache()
@app.route("/health", methods=["GET"])
def health_check():
    """Health probe: always answers HTTP 200 with a fixed "healthy" status."""
    payload = {"status": "healthy"}
    return jsonify(payload), 200
@app.route("/status/busy", methods=["GET"])
def status_busy():
    """Report whether request_semaphore currently has no free slot."""
    # `_value` is a private attribute of the semaphore; it is read here
    # without acquiring anything, so the answer is only a snapshot.
    free_slots = request_semaphore._value
    is_busy = free_slots == 0
    return jsonify({"busy": is_busy}), 200
def get_model():
    """Load the whisperx model and return it together with its options.

    Returns:
        A ``(model, options)`` tuple: the object returned by
        ``whisperx.load_model`` for the module-level ``device`` and
        ``compute_type``, and the options dict ``{"beam_size": 5}``.
    """
    options = dict(beam_size=5)
    model = whisperx.load_model(
        "guillaumekln/faster-whisper-small",
        device,
        compute_type=compute_type,
    )
    return model, options
if __name__ == "__main__":
    # Flask's debug mode turns on the interactive debugger, which lets anyone
    # who can reach the server execute code. Keep it off unless explicitly
    # requested through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    # The listening port comes from PORT, defaulting to 7860.
    app.run(debug=debug_enabled, port=int(os.environ.get("PORT", 7860)))