satyamr196 commited on
Commit
16b9706
·
1 Parent(s): 62021f3

1) Modified the code to accommodate Microsoft ASR models as well, which require their own custom code to run — hence set trust_remote_code=True in the pipeline. 2) Added extra error handling. 3) The pipeline is now loaded in the main thread instead of a background thread.

Browse files
Files changed (1) hide show
  1. ASR_Server.py +48 -11
ASR_Server.py CHANGED
@@ -43,7 +43,7 @@ dataset = load_dataset("satyamr196/asr_fairness_audio", split="train")
43
  # dataset = dataset.with_format("python", decode_audio=False)
44
  dataset = dataset.cast_column("audio", Audio(decode=False))
45
 
46
- def generateTranscript(ASR_model):
47
  import os
48
  import time
49
  import tqdm
@@ -94,8 +94,10 @@ def generateTranscript(ASR_model):
94
  total = len(df)
95
  job_status["total"] = total
96
 
 
 
97
  # Initialize ASR pipeline
98
- pipe = pipeline("automatic-speech-recognition", model=ASR_model)
99
 
100
  # Column with filenames in the CSV
101
  filename_column = df.columns[0]
@@ -213,6 +215,8 @@ def get_status():
213
 
214
  @app.route('/api', methods=['GET'])
215
  def api():
 
 
216
  model = request.args.get('ASR_model', default="", type=str)
217
  # model = re.sub(r"\s+", "", model)
218
  model = re.sub(r"[^a-zA-Z0-9/_\-.]", "", model) # sanitize the model ID
@@ -246,17 +250,50 @@ def api():
246
  'status': job_status
247
  })
248
 
249
- response = jsonify({
250
- 'message': f'Given Model {model} is being Evaluated, Please come back after a few hours and run the query again. Usually, it completes within an hour'
251
- })
252
-
253
- # Run `generateTranscript(model)` in a separate thread
254
- # Start the transcript generation in a separate thread
255
- # thread = threading.Thread(target=generateTranscript, args=(model,), daemon=True)
256
- thread = threading.Thread(target=generateTranscript, args=(model,))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  thread.start()
 
 
 
 
 
 
 
 
258
 
259
- return response
 
 
 
260
 
261
  @app.route("/insert", methods=["POST"])
262
  def insert_document():
 
43
  # dataset = dataset.with_format("python", decode_audio=False)
44
  dataset = dataset.cast_column("audio", Audio(decode=False))
45
 
46
+ def generateTranscript(ASR_model, pipe=None):
47
  import os
48
  import time
49
  import tqdm
 
94
  total = len(df)
95
  job_status["total"] = total
96
 
97
+ if pipe is None:
98
+ pipe = pipeline("automatic-speech-recognition", model=ASR_model, trust_remote_code=True)
99
  # Initialize ASR pipeline
100
+ # pipe = pipeline("automatic-speech-recognition", model=ASR_model, trust_remote_code=True)
101
 
102
  # Column with filenames in the CSV
103
  filename_column = df.columns[0]
 
215
 
216
  @app.route('/api', methods=['GET'])
217
  def api():
218
+ from transformers import pipeline
219
+
220
  model = request.args.get('ASR_model', default="", type=str)
221
  # model = re.sub(r"\s+", "", model)
222
  model = re.sub(r"[^a-zA-Z0-9/_\-.]", "", model) # sanitize the model ID
 
250
  'status': job_status
251
  })
252
 
253
+ try:
254
+ print(f"⏳ Loading model {model} in main thread...")
255
+ pipe = pipeline("automatic-speech-recognition", model=model, trust_remote_code=True)
256
+ except Exception as e:
257
+ return jsonify({
258
+ "error": f"Model load failed: {str(e)}",
259
+ "message": f"Model load failed: {str(e)}"
260
+ }), 500
261
+
262
+ def thread_wrapper(model, pipe):
263
+ try:
264
+ job_status["running"] = True
265
+ job_status["error"] = None
266
+ job_status["model"] = model
267
+ generateTranscript(model, pipe)
268
+ job_status["running"] = False
269
+ # return jsonify({
270
+ # 'message': f'Given Model {model} is being Evaluated, Please come back after a few hours and run the query again. Usually, it completes within an hour'
271
+ # }),200
272
+ except Exception as e:
273
+ print(f"❌ Background transcription for {model} failed:", e)
274
+ job_status["running"] = False
275
+ job_status["error"] = str(e)
276
+ # return jsonify({
277
+ # "error": f"Background transcription failed: {str(e)}",
278
+ # "message": f"Background transcription failed: {str(e)}"
279
+ # }), 500
280
+
281
+ # Then use:
282
+ thread = threading.Thread(target=thread_wrapper, args=(model, pipe), daemon=True)
283
  thread.start()
284
+ # thread = threading.Thread(target=generateTranscript, args=(model,), daemon=True)
285
+ # thread = threading.Thread(target=generateTranscript, args=(model,pipe))
286
+ # thread.start()
287
+ if job_status.get("error"):
288
+ return jsonify({
289
+ 'message': f'❌transcription for model "{job_status.get("model")}" failed.',
290
+ 'error': job_status["error"]
291
+ }), 500
292
 
293
+ return jsonify({
294
+ 'message': f'Given Model {model} is being Evaluated. Please come back after a few hours and run the query again.',
295
+ 'status': job_status
296
+ }), 202
297
 
298
  @app.route("/insert", methods=["POST"])
299
  def insert_document():