File size: 5,766 Bytes
f5dc719
7323fd3
f5dc719
 
 
 
7323fd3
f5dc719
 
 
0117db0
 
 
f5dc719
 
 
 
 
 
 
922901f
 
f5dc719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7323fd3
8a199a7
f5dc719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a199a7
f5dc719
27eb3e4
 
 
f5dc719
27eb3e4
 
f5dc719
 
 
 
 
e9ed0f8
f5dc719
 
 
 
27eb3e4
0117db0
 
f5dc719
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gc
import hmac
import io
import os
import sys
import tempfile
from datetime import datetime
from threading import Semaphore

import ffmpeg
import torch
import torchaudio
import whisperx
from flask import Flask, request, jsonify

app = Flask(__name__)

# Read the API key from the environment once at startup, purely to surface a
# misconfiguration early; validate_api_key re-reads the variable per request.
api_key = os.environ.get("API_KEY")
if not api_key:
    # Diagnostics belong on stderr so they are not mixed into normal output.
    print("Error: API_KEY environment variable not set!", file=sys.stderr)

# Semaphore limiting how many transcription requests are processed at once.
MAX_CONCURRENT_REQUESTS = 2
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)

# Device and precision passed to whisperx.load_model.
device = "cuda"
compute_type = "float16"

def validate_api_key(request):
    """
    Check whether the request carries the API key configured in ``API_KEY``.

    The key is accepted from any of three places: the ``X-API-Key`` header,
    the ``api_key`` query-string parameter, or the ``api_key`` form field.

    Args:
        request: Flask request object (anything exposing ``headers``,
            ``args`` and ``form`` mappings with a ``get`` method).

    Returns:
        A ``(is_valid, error_message)`` tuple: ``(True, None)`` when a
        supplied key matches, otherwise ``(False, <reason>)``.
    """
    api_key_env = os.environ.get("API_KEY")
    if not api_key_env:
        return False, "API_KEY environment variable not set"

    # NOTE(review): keys passed in the query string tend to end up in access
    # logs and browser history; the header is the safer channel.
    candidates = (
        request.headers.get("X-API-Key"),
        request.args.get("api_key"),
        request.form.get("api_key"),
    )

    expected = api_key_env.encode("utf-8")
    # compare_digest runs in time independent of where the strings differ,
    # so response timing does not leak the key. Comparing UTF-8 bytes (not
    # str) avoids the TypeError compare_digest raises on non-ASCII str input.
    if any(
        candidate is not None
        and hmac.compare_digest(candidate.encode("utf-8"), expected)
        for candidate in candidates
    ):
        return True, None
    return False, "Invalid API Key"


# Extensions treated as video: their audio track is extracted with ffmpeg.
_VIDEO_EXTENSIONS = frozenset({
    'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp',
})
# Every extension the endpoint accepts (audio formats plus all video formats).
_ALLOWED_EXTENSIONS = frozenset({
    'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff',
}) | _VIDEO_EXTENSIONS
# Sample rate the audio is resampled to before transcription.
_TARGET_SAMPLE_RATE = 16000
# Longest audio accepted, in seconds (10 minutes).
_MAX_DURATION_SECONDS = 10 * 60


def _remove_quietly(path):
    """Best-effort delete of a temp file; a failure is logged, never raised."""
    if path is None:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        # Cleanup must not turn an otherwise successful response into a 500.
        app.logger.warning("Could not remove temp file %s: %s", path, exc)


def _new_temp_path(suffix):
    """Create an empty temp file with *suffix* and return its path (fd closed)."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)  # only the name is needed; writers reopen the path
    return path


def _save_upload(file, extension):
    """Write the uploaded file to a new temp file and return its path."""
    path = _new_temp_path(f'.{extension}')
    try:
        file.save(path)
    except Exception:
        _remove_quietly(path)
        raise
    return path


def _extract_audio(video_path):
    """Extract the audio of *video_path* into a new PCM s16le WAV temp file; return its path."""
    wav_path = _new_temp_path('.wav')
    try:
        (
            ffmpeg.input(video_path)
            .output(wav_path, format='wav', acodec='pcm_s16le')
            .run(quiet=True, overwrite_output=True)
        )
    except Exception:
        _remove_quietly(wav_path)  # don't leak the half-written output
        raise
    return wav_path


def _load_mono_16k(path):
    """Load *path*, downmix to one channel and resample to 16 kHz; return a 1-D tensor."""
    waveform, sample_rate = torchaudio.load(path)
    if waveform.shape[0] > 1:
        # Average the channels (dim 0) into a single mono channel.
        waveform = waveform.mean(dim=0, keepdim=True)
    # squeeze(0) (not squeeze()) so a 1-frame file still yields a 1-D tensor.
    waveform = waveform.squeeze(0)
    if sample_rate != _TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, _TARGET_SAMPLE_RATE)
    return waveform


@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """
    Transcribe an uploaded audio or video file.

    Expects a multipart upload in the ``file`` field, authenticated through
    validate_api_key. Video uploads have their audio extracted with ffmpeg
    first. The audio is downmixed to mono and resampled to 16 kHz. Audio
    longer than 10 minutes is rejected. At most MAX_CONCURRENT_REQUESTS
    uploads are processed at once; further requests wait on the semaphore.

    Returns:
        ``{"transcription": str, "file_type": "audio" | "video"}`` with 200,
        or ``{"error": str}`` with 400 (bad input), 401 (bad API key) or
        500 (processing failure). Temp files are removed on every path.
    """
    is_valid, message = validate_api_key(request)
    if not is_valid:
        return jsonify({"error": message}), 401

    with request_semaphore:
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400

        file = request.files['file']
        if not file.filename:  # covers both '' and a missing filename
            return jsonify({'error': 'No file selected'}), 400

        # rpartition never raises: a name without '.' yields an empty extension,
        # which is then rejected below instead of crashing with IndexError.
        _, dot, extension = file.filename.rpartition('.')
        file_extension = extension.lower() if dot else ''
        if file_extension not in _ALLOWED_EXTENSIONS:
            supported = ", ".join(sorted(_ALLOWED_EXTENSIONS))
            return jsonify({'error': f'Invalid file format. Supported: {supported}'}), 400

        upload_path = None
        audio_path = None
        try:
            upload_path = _save_upload(file, file_extension)

            if file_extension in _VIDEO_EXTENSIONS:
                file_type = "video"
                try:
                    audio_path = _extract_audio(upload_path)
                except Exception as e:
                    return jsonify({'error': f'Failed to extract audio from video: {str(e)}'}), 500
                # The video itself is no longer needed once audio is extracted.
                _remove_quietly(upload_path)
                upload_path = None
            else:
                file_type = "audio"
                audio_path = upload_path
                upload_path = None  # tracked as audio_path from here on

            try:
                audio = _load_mono_16k(audio_path)
            except Exception as e:
                return jsonify({'error': f'Failed to load audio file: {str(e)}'}), 500

            if audio.shape[-1] / _TARGET_SAMPLE_RATE > _MAX_DURATION_SECONDS:
                return jsonify({'error': 'Audio duration exceeds the maximum allowed duration of 10 minutes'}), 400

            try:
                wmodel, model_options = get_model()
                # NOTE(review): whisperx's pipeline transcribe() takes a NumPy
                # float32 array and returns a dict whose "segments" entry is a
                # list of dicts with a "text" key (unlike faster-whisper's
                # (segments, info) tuple). Confirm against the pinned version.
                result = wmodel.transcribe(
                    audio.numpy(),
                    batch_size=model_options.get("batch_size", None),
                )
                transcription = "".join(segment["text"] for segment in result["segments"])
            except Exception as e:
                return jsonify({'error': f'Transcription failed: {str(e)}'}), 500

            return jsonify({'transcription': transcription, 'file_type': file_type}), 200

        except Exception as e:
            return jsonify({'error': str(e)}), 500
        finally:
            # Runs on every return path above, so no temp file is leaked.
            _remove_quietly(upload_path)
            _remove_quietly(audio_path)
            gc.collect()
            torch.cuda.empty_cache()

@app.route("/health", methods=["GET"])
def health_check():
    """Liveness probe: always answers 200 with a fixed JSON body."""
    payload = {"status": "healthy"}
    return jsonify(payload), 200

@app.route("/status/busy", methods=["GET"])
def status_busy():
    """
    Report whether every transcription slot is currently taken.

    Probes ``request_semaphore`` with a non-blocking acquire instead of
    reading the private ``Semaphore._value`` attribute. A free slot is
    released again immediately, so the probe holds it only for an instant.

    Returns:
        ``{"busy": bool}`` with status 200.
    """
    acquired = request_semaphore.acquire(blocking=False)
    if acquired:
        request_semaphore.release()
    return jsonify({"busy": not acquired}), 200

def get_model():
    """
    Load the faster-whisper large-v2 model through whisperx.

    Uses the module-level ``device`` and ``compute_type`` settings.

    Returns:
        A ``(model, options)`` pair; ``options`` holds extra transcription
        settings (currently only ``beam_size``).

    NOTE(review): the model is loaded from scratch on every call, i.e. once
    per request, and the visible caller only reads ``batch_size`` from
    ``options``, so ``beam_size`` is currently unused. Consider caching the
    model if it is safe to share between concurrent requests.
    """
    options = dict(beam_size=5)
    model = whisperx.load_model(
        "guillaumekln/faster-whisper-large-v2",
        device,
        compute_type=compute_type,
    )
    return model, options


if __name__ == "__main__":
    # Flask's debug mode enables the Werkzeug interactive debugger, which can
    # execute arbitrary code; it is therefore opt-in via FLASK_DEBUG rather
    # than hard-coded on. HOST defaults to Flask's own default (127.0.0.1).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(
        host=os.environ.get("HOST", "127.0.0.1"),
        port=int(os.environ.get("PORT", 7860)),
        debug=debug_enabled,
    )