kemuriririn commited on
Commit
6deea8f
·
1 Parent(s): 050b9af

update check cache

Browse files
Files changed (1) hide show
  1. app.py +56 -111
app.py CHANGED
@@ -6,8 +6,11 @@ from datetime import datetime
6
  import threading # Added for locking
7
 
8
  from huggingface_hub.hf_api import RepoFile
 
9
  from sqlalchemy import or_ # Added for vote counting query
10
  import hashlib
 
 
11
 
12
  year = datetime.now().year
13
  month = datetime.now().month
@@ -667,117 +670,50 @@ def generate_tts():
667
  if len(available_models) < 2:
668
  return jsonify({"error": "Not enough TTS models available"}), 500
669
 
670
- selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
671
- # 尝试从持久化缓存中查找两个模型的音频
672
- audio_a_path = find_cached_audio(str(selected_models[0].id), text, reference_audio_path)
673
- audio_b_path = find_cached_audio(str(selected_models[1].id), text, reference_audio_path)
674
-
675
- if audio_a_path and audio_b_path:
676
- app.logger.warning(f"Persistent Cache HIT for: '{text[:50]}...'. Using files directly.")
677
- session_id = str(uuid.uuid4())
678
- app.tts_sessions[session_id] = {
679
- "model_a": selected_models[0].id,
680
- "model_b": selected_models[1].id,
681
- "audio_a": audio_a_path,
682
- "audio_b": audio_b_path,
683
- "text": text,
684
- "created_at": datetime.utcnow(),
685
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
686
- "voted": False,
687
- }
688
- return jsonify({
689
- "session_id": session_id,
690
- "audio_a": f"/api/tts/audio/{session_id}/a",
691
- "audio_b": f"/api/tts/audio/{session_id}/b",
692
- "expires_in": 1800,
693
- "cache_hit": True, # 可以认为这也是一种缓存命中
694
- })
695
- # --- 持久化缓存检查结束 ---
696
- try:
697
- audio_files = []
698
- model_ids = []
699
-
700
- # Function to process a single model (generate directly to TEMP_AUDIO_DIR, not cache subdir)
701
- def process_model_on_the_fly(model):
702
- app.logger.warning(f"Processing model {model.id} for text: '{text[:30]}...', prompt_md5: {prompt_md5}")
703
- app.logger.warning(f"Expected key: {get_tts_cache_key(str(model.id), text, reference_audio_path)}")
704
- # 传递 reference_audio_path 给 predict_tts
705
- temp_audio_path = predict_tts(text, model.id, reference_audio_path=reference_audio_path,
706
- user_token=user_token)
707
- if not temp_audio_path or not os.path.exists(temp_audio_path):
708
- raise ValueError(f"predict_tts failed for model {model.id}")
709
-
710
- # Create a unique name in the main TEMP_AUDIO_DIR for the session
711
- file_uuid = str(uuid.uuid4())
712
- dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
713
- shutil.move(temp_audio_path, dest_path) # Move from predict_tts's temp location
714
-
715
- return {"model_id": model.id, "audio_path": dest_path}
716
-
717
- # Use ThreadPoolExecutor to process models concurrently
718
- with ThreadPoolExecutor(max_workers=2) as executor:
719
- results = list(executor.map(process_model_on_the_fly, selected_models))
720
-
721
- # Extract results
722
- for result in results:
723
- model_ids.append(result["model_id"])
724
- audio_files.append(result["audio_path"])
725
-
726
- # Create session
727
- session_id = str(uuid.uuid4())
728
- app.tts_sessions[session_id] = {
729
- "model_a": model_ids[0],
730
- "model_b": model_ids[1],
731
- "audio_a": audio_files[0], # Paths are now from TEMP_AUDIO_DIR directly
732
- "audio_b": audio_files[1],
733
- "text": text,
734
- "created_at": datetime.utcnow(),
735
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
736
- "voted": False,
737
- }
738
-
739
- # 清理临时参考音频文件
740
- if reference_audio_path and os.path.exists(reference_audio_path):
741
- os.remove(reference_audio_path)
742
-
743
- # Check if text and prompt are in predefined libraries
744
- if text in predefined_texts and prompt_md5 in predefined_prompts.values():
745
- with preload_cache_lock:
746
- preload_key = get_tts_cache_key(str(model_ids[0]), text, reference_audio_path)
747
- preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
748
- shutil.copy(audio_files[0], preload_path)
749
- app.logger.info(f"Preloaded cache audio saved: {preload_path}")
750
-
751
- preload_key = get_tts_cache_key(str(model_ids[1]), text, reference_audio_path)
752
- preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
753
- shutil.copy(audio_files[1], preload_path)
754
- app.logger.info(f"Preloaded cache audio saved: {preload_path}")
755
-
756
- # Return audio file paths and session
757
- return jsonify(
758
- {
759
- "session_id": session_id,
760
- "audio_a": f"/api/tts/audio/{session_id}/a",
761
- "audio_b": f"/api/tts/audio/{session_id}/b",
762
- "expires_in": 1800,
763
- "cache_hit": False,
764
- }
765
- )
766
-
767
- except Exception as e:
768
- app.logger.error(f"TTS on-the-fly generation error: {str(e)}", exc_info=True)
769
- # Cleanup any files potentially created during the failed attempt
770
- if 'results' in locals():
771
- for res in results:
772
- if 'audio_path' in res and os.path.exists(res['audio_path']):
773
- try:
774
- os.remove(res['audio_path'])
775
- except OSError:
776
- pass
777
- # 清理临时参考音频文件
778
- if reference_audio_path and os.path.exists(reference_audio_path):
779
- os.remove(reference_audio_path)
780
- return jsonify({"error": f"Failed to generate TTS:{str(e)}"}), 500
781
  # --- End Cache Miss ---
782
 
783
 
@@ -1360,6 +1296,15 @@ def get_tts_cache_key(model_name, text, prompt_audio_path):
1360
  return hashlib.md5(key_str.encode('utf-8')).hexdigest()
1361
 
1362
 
 
 
 
 
 
 
 
 
 
1363
  if __name__ == "__main__":
1364
  with app.app_context():
1365
  # Ensure ./instance and ./votes directories exist
 
6
  import threading # Added for locking
7
 
8
  from huggingface_hub.hf_api import RepoFile
9
+ from pydub import AudioSegment, silence
10
  from sqlalchemy import or_ # Added for vote counting query
11
  import hashlib
12
+ import numpy as np
13
+ import wave
14
 
15
  year = datetime.now().year
16
  month = datetime.now().month
 
670
  if len(available_models) < 2:
671
  return jsonify({"error": "Not enough TTS models available"}), 500
672
 
673
+ # 新增:a和b模型都需通过缓存和静音检测
674
+ candidate_models = available_models.copy()
675
+ random.shuffle(candidate_models)
676
+ valid_pairs = []
677
+ # 枚举所有模型对,找到第一个都通过的组合
678
+ for i in range(len(candidate_models)):
679
+ for j in range(len(candidate_models)):
680
+ if i == j:
681
+ continue
682
+ model_a = candidate_models[i]
683
+ model_b = candidate_models[j]
684
+ audio_a_path = find_cached_audio(str(model_a.id), text, reference_audio_path)
685
+ audio_b_path = find_cached_audio(str(model_b.id), text, reference_audio_path)
686
+ if (audio_a_path and os.path.exists(audio_a_path)
687
+ and not has_long_silence(audio_a_path)
688
+ and audio_b_path and os.path.exists(audio_b_path)
689
+ and not has_long_silence(audio_b_path)):
690
+ valid_pairs.append((model_a, audio_a_path, model_b, audio_b_path))
691
+ if not valid_pairs:
692
+ return jsonify({"error": "所有模型均未通过持久化缓存和静音检测,无法生成音频"}), 500
693
+
694
+ # 随机选一个合格组合
695
+ model_a, audio_a_path, model_b, audio_b_path = random.choice(valid_pairs)
696
+ session_id = str(uuid.uuid4())
697
+ app.tts_sessions[session_id] = {
698
+ "model_a": model_a.id,
699
+ "model_b": model_b.id,
700
+ "audio_a": audio_a_path,
701
+ "audio_b": audio_b_path,
702
+ "text": text,
703
+ "created_at": datetime.utcnow(),
704
+ "expires_at": datetime.utcnow() + timedelta(minutes=30),
705
+ "voted": False,
706
+ }
707
+ # 清理临时参考音频文件
708
+ if reference_audio_path and os.path.exists(reference_audio_path):
709
+ os.remove(reference_audio_path)
710
+ return jsonify({
711
+ "session_id": session_id,
712
+ "audio_a": f"/api/tts/audio/{session_id}/a",
713
+ "audio_b": f"/api/tts/audio/{session_id}/b",
714
+ "expires_in": 1800,
715
+ "cache_hit": True,
716
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717
  # --- End Cache Miss ---
718
 
719
 
 
1296
  return hashlib.md5(key_str.encode('utf-8')).hexdigest()
1297
 
1298
 
1299
+ def has_long_silence(audio_path, min_silence_len_ms=10000, silence_thresh_db=-40):
1300
+ try:
1301
+ audio = AudioSegment.from_file(audio_path)
1302
+ silent_ranges = silence.detect_silence(audio, min_silence_len=min_silence_len_ms, silence_thresh=silence_thresh_db)
1303
+ return len(silent_ranges) > 0
1304
+ except Exception as e:
1305
+ print(f"无法分析音频文件 {audio_path}: {e}")
1306
+ return False
1307
+
1308
  if __name__ == "__main__":
1309
  with app.app_context():
1310
  # Ensure ./instance and ./votes directories exist