kemuriririn commited on
Commit
8ed90ba
·
1 Parent(s): 3fb07cc

(wip)update file cache pipeline

Browse files
Files changed (1) hide show
  1. app.py +133 -98
app.py CHANGED
@@ -4,6 +4,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
4
  from concurrent.futures import ThreadPoolExecutor
5
  from datetime import datetime
6
  import threading # Added for locking
 
 
7
  from sqlalchemy import or_ # Added for vote counting query
8
  import hashlib
9
 
@@ -42,7 +44,6 @@ from flask import (
42
  redirect,
43
  url_for,
44
  session,
45
- abort,
46
  )
47
  from flask_login import LoginManager, current_user
48
  from models import *
@@ -61,8 +62,6 @@ import json
61
  from datetime import datetime, timedelta
62
  from flask_migrate import Migrate
63
  import requests
64
- import functools
65
- import time # Added for potential retries
66
 
67
  # Load environment variables
68
  if not IS_SPACES:
@@ -118,6 +117,7 @@ TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
118
  CACHE_AUDIO_SUBDIR = "cache"
119
  tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
120
  tts_cache_lock = threading.Lock()
 
121
  SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
122
  # Increased max_workers to 8 for concurrent generation/refill
123
  cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
@@ -371,6 +371,7 @@ with open("init_sentences.txt", "r") as f:
371
  initial_sentences = random.sample(all_harvard_sentences,
372
  min(len(all_harvard_sentences), 500)) # Limit initial pass for template
373
 
 
374
  @app.route("/")
375
  def arena():
376
  # Pass a subset of sentences for the random button fallback
@@ -616,6 +617,11 @@ def generate_tts():
616
  if not text or len(text) > 1000:
617
  return jsonify({"error": "Invalid or too long text"}), 400
618
 
 
 
 
 
 
619
  # --- Cache Check ---
620
  cache_hit = False
621
  session_data_from_cache = None
@@ -662,7 +668,31 @@ def generate_tts():
662
  return jsonify({"error": "Not enough TTS models available"}), 500
663
 
664
  selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
 
 
 
665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
  try:
667
  audio_files = []
668
  model_ids = []
@@ -716,15 +746,16 @@ def generate_tts():
716
 
717
  # Check if text and prompt are in predefined libraries
718
  if text in predefined_texts and prompt_md5 in predefined_prompts.values():
719
- preload_key = get_tts_cache_key(str(model_ids[0]), text, reference_audio_path)
720
- preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
721
- shutil.copy(audio_files[0], preload_path)
722
- app.logger.info(f"Preloaded cache audio saved: {preload_path}")
 
723
 
724
- preload_key = get_tts_cache_key(str(model_ids[1]), text, reference_audio_path)
725
- preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
726
- shutil.copy(audio_files[1], preload_path)
727
- app.logger.info(f"Preloaded cache audio saved: {preload_path}")
728
 
729
  # Return audio file paths and session
730
  return jsonify(
@@ -1120,98 +1151,105 @@ def setup_periodic_tasks():
1120
  同步缓存音频到HF dataset并从HF下载更新的缓存音频。
1121
  """
1122
  os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
1123
- try:
1124
- api = HfApi(token=os.getenv("HF_TOKEN"))
1125
-
1126
- # 获取带有 etag 的文件列表
1127
- files_info = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset", expand=True)
1128
- # 只处理cache_audios/下的wav文件
1129
- wav_files = [f for f in files_info if
1130
- f["rfilename"].startswith(CACHE_AUDIO_PATTERN) and f["rfilename"].endswith(".wav")]
1131
-
1132
- # 获取本地已有文件名及hash集合
1133
- local_hashes = {}
1134
- for root, _, filenames in os.walk(PRELOAD_CACHE_DIR):
1135
- for fname in filenames:
1136
- if fname.endswith(".wav"):
1137
- rel_path = os.path.relpath(os.path.join(root, fname), PRELOAD_CACHE_DIR)
1138
- remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
1139
- local_file_path = os.path.join(root, fname)
1140
- # 计算本地文件md5
1141
- try:
1142
- with open(local_file_path, 'rb') as f:
1143
- md5 = hashlib.md5(f.read()).hexdigest()
1144
- local_hashes[remote_path] = md5
1145
- except Exception:
1146
- continue
1147
-
1148
- download_count = 0
1149
- for f in wav_files:
1150
- remote_path = f["rfilename"]
1151
- etag = f.get("lfs", {}).get("oid") or f.get("etag") # 优先lfs oid, 其次etag
1152
- local_md5 = local_hashes.get(remote_path)
1153
-
1154
- # 如果远端etag为32位md5且与本地一致,跳过下载
1155
- if etag and len(etag) == 32 and local_md5 == etag:
1156
- continue
1157
-
1158
- # 下载文件
1159
- local_path = hf_hub_download(
1160
- repo_id=REFERENCE_AUDIO_DATASET,
1161
- filename=remote_path,
1162
- repo_type="dataset",
1163
- local_dir=PRELOAD_CACHE_DIR,
1164
- token=os.getenv("HF_TOKEN"),
1165
- force_download=True if local_md5 else False
1166
- )
1167
- print(f"Downloaded cache audio: {local_path}")
1168
- download_count += 1
1169
 
1170
- print(f"Downloaded {download_count} new/updated cache audios from HF to {PRELOAD_CACHE_DIR}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171
 
1172
- # 上传本地文件到HF dataset
1173
- for root, _, files in os.walk(PRELOAD_CACHE_DIR):
1174
- for file in files:
1175
- if file.endswith('.wav'):
1176
- local_path = os.path.join(root, file)
1177
- rel_path = os.path.relpath(local_path, PRELOAD_CACHE_DIR)
1178
- remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
1179
 
1180
- try:
1181
- # 计算本地文件MD5,用于检查是否需要上传
1182
- with open(local_path, 'rb') as f:
1183
- file_md5 = hashlib.md5(f.read()).hexdigest()
 
 
 
1184
 
1185
- # 尝试获取远程文件信息
1186
  try:
1187
- remote_info = api.get_file_info(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1188
  repo_id=REFERENCE_AUDIO_DATASET,
1189
  repo_type="dataset",
1190
- path=remote_path)
1191
- remote_etag = remote_info.etag or remote_info.lfs.get("oid", "")
1192
- # 如果远程文件存在且hash相同,则跳过
1193
- if remote_etag and remote_etag == file_md5:
1194
- app.logger.debug(f"Skipping upload for {remote_path}: file unchanged")
1195
- continue
1196
- except Exception:
1197
- pass
1198
 
1199
- # 上传文件
1200
- app.logger.info(f"Uploading preload cache file: {remote_path}")
1201
- api.upload_file(
1202
- path_or_fileobj=local_path,
1203
- path_in_repo=remote_path,
1204
- repo_id=REFERENCE_AUDIO_DATASET,
1205
- repo_type="dataset",
1206
- commit_message=f"Upload preload cache file: {os.path.basename(file)}"
1207
- )
1208
- app.logger.info(f"Successfully uploaded {remote_path}")
1209
- except Exception as e:
1210
- app.logger.error(f"Error uploading {remote_path}: {str(e)}")
1211
-
1212
- except Exception as e:
1213
- print(f"Error syncing cache audios with HF: {e}")
1214
- app.logger.error(f"Error syncing cache audios with HF: {e}")
1215
 
1216
  # Schedule periodic tasks
1217
  scheduler = BackgroundScheduler()
@@ -1354,9 +1392,6 @@ def get_tts_cache_key(model_name, text, prompt_audio_path):
1354
  return hashlib.md5(key_str.encode('utf-8')).hexdigest()
1355
 
1356
 
1357
-
1358
-
1359
-
1360
  if __name__ == "__main__":
1361
  with app.app_context():
1362
  # Ensure ./instance and ./votes directories exist
 
4
  from concurrent.futures import ThreadPoolExecutor
5
  from datetime import datetime
6
  import threading # Added for locking
7
+
8
+ from huggingface_hub.hf_api import RepoFile
9
  from sqlalchemy import or_ # Added for vote counting query
10
  import hashlib
11
 
 
44
  redirect,
45
  url_for,
46
  session,
 
47
  )
48
  from flask_login import LoginManager, current_user
49
  from models import *
 
62
  from datetime import datetime, timedelta
63
  from flask_migrate import Migrate
64
  import requests
 
 
65
 
66
  # Load environment variables
67
  if not IS_SPACES:
 
117
  CACHE_AUDIO_SUBDIR = "cache"
118
  tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
119
  tts_cache_lock = threading.Lock()
120
+ preload_cache_lock = threading.Lock()
121
  SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
122
  # Increased max_workers to 8 for concurrent generation/refill
123
  cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
 
371
  initial_sentences = random.sample(all_harvard_sentences,
372
  min(len(all_harvard_sentences), 500)) # Limit initial pass for template
373
 
374
+
375
  @app.route("/")
376
  def arena():
377
  # Pass a subset of sentences for the random button fallback
 
617
  if not text or len(text) > 1000:
618
  return jsonify({"error": "Invalid or too long text"}), 400
619
 
620
+ prompt_md5 = ''
621
+ if reference_audio_path and os.path.exists(reference_audio_path):
622
+ with open(reference_audio_path, 'rb') as f:
623
+ prompt_md5 = hashlib.md5(f.read()).hexdigest()
624
+
625
  # --- Cache Check ---
626
  cache_hit = False
627
  session_data_from_cache = None
 
668
  return jsonify({"error": "Not enough TTS models available"}), 500
669
 
670
  selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
671
+ # 尝试从持久化缓存中查找两个模型的音频
672
+ audio_a_path = find_cached_audio(str(selected_models[0].id), text, reference_audio_path)
673
+ audio_b_path = find_cached_audio(str(selected_models[1].id), text, reference_audio_path)
674
 
675
+ if audio_a_path and audio_b_path:
676
+ app.logger.info(f"Persistent Cache HIT for: '{text[:50]}...'. Using files directly.")
677
+ session_id = str(uuid.uuid4())
678
+ app.tts_sessions[session_id] = {
679
+ "model_a": selected_models[0].id,
680
+ "model_b": selected_models[1].id,
681
+ "audio_a": audio_a_path,
682
+ "audio_b": audio_b_path,
683
+ "text": text,
684
+ "created_at": datetime.utcnow(),
685
+ "expires_at": datetime.utcnow() + timedelta(minutes=30),
686
+ "voted": False,
687
+ }
688
+ return jsonify({
689
+ "session_id": session_id,
690
+ "audio_a": f"/api/tts/audio/{session_id}/a",
691
+ "audio_b": f"/api/tts/audio/{session_id}/b",
692
+ "expires_in": 1800,
693
+ "cache_hit": True, # 可以认为这也是一种缓存命中
694
+ })
695
+ # --- 持久化缓存检查结束 ---
696
  try:
697
  audio_files = []
698
  model_ids = []
 
746
 
747
  # Check if text and prompt are in predefined libraries
748
  if text in predefined_texts and prompt_md5 in predefined_prompts.values():
749
+ with preload_cache_lock:
750
+ preload_key = get_tts_cache_key(str(model_ids[0]), text, reference_audio_path)
751
+ preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
752
+ shutil.copy(audio_files[0], preload_path)
753
+ app.logger.info(f"Preloaded cache audio saved: {preload_path}")
754
 
755
+ preload_key = get_tts_cache_key(str(model_ids[1]), text, reference_audio_path)
756
+ preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
757
+ shutil.copy(audio_files[1], preload_path)
758
+ app.logger.info(f"Preloaded cache audio saved: {preload_path}")
759
 
760
  # Return audio file paths and session
761
  return jsonify(
 
1151
  同步缓存音频到HF dataset并从HF下载更新的缓存音频。
1152
  """
1153
  os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
1154
+ with preload_cache_lock:
1155
+ try:
1156
+ api = HfApi(token=os.getenv("HF_TOKEN"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1157
 
1158
+ # 获取带有 etag 的文件列表
1159
+ files_info = [
1160
+ f
1161
+ for f in api.list_repo_tree(
1162
+ repo_id=REFERENCE_AUDIO_DATASET, path_in_repo=CACHE_AUDIO_PATTERN.strip("/"), recursive=True,
1163
+ repo_type="dataset", expand=True
1164
+ )
1165
+ if isinstance(f, RepoFile)
1166
+ ]
1167
+ # 只处理cache_audios/下的wav文件
1168
+ wav_files = [f for f in files_info if
1169
+ f.path.endswith(".wav")]
1170
+
1171
+ # 获取本地已有文件名及hash集合
1172
+ local_hashes = {}
1173
+ for root, _, filenames in os.walk(PRELOAD_CACHE_DIR):
1174
+ for fname in filenames:
1175
+ if fname.endswith(".wav"):
1176
+ rel_path = os.path.relpath(os.path.join(root, fname), PRELOAD_CACHE_DIR)
1177
+ remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
1178
+ local_file_path = os.path.join(root, fname)
1179
+ # 计算本地文件md5
1180
+ try:
1181
+ with open(local_file_path, 'rb') as f:
1182
+ file_hash = hashlib.sha256(f.read()).hexdigest()
1183
+ local_hashes[remote_path] = file_hash
1184
+ except Exception:
1185
+ continue
1186
+
1187
+ download_count = 0
1188
+ for f in wav_files:
1189
+ remote_path = f.path
1190
+ etag = f.lfs.sha256 if f.lfs else None
1191
+ local_hash = local_hashes.get(remote_path)
1192
+
1193
+ # 如果远端etag为32位md5且与本地一致,跳过下载
1194
+ if local_hash == etag:
1195
+ continue
1196
+
1197
+ # 下载文件
1198
+ local_path = hf_hub_download(
1199
+ repo_id=REFERENCE_AUDIO_DATASET,
1200
+ filename=remote_path,
1201
+ repo_type="dataset",
1202
+ local_dir=PRELOAD_CACHE_DIR,
1203
+ token=os.getenv("HF_TOKEN"),
1204
+ force_download=True if local_hash else False
1205
+ )
1206
+ print(f"Downloaded cache audio: {local_path}")
1207
+ download_count += 1
1208
 
1209
+ print(f"Downloaded {download_count} new/updated cache audios from HF to {PRELOAD_CACHE_DIR}")
 
 
 
 
 
 
1210
 
1211
+ # 上传本地文件到HF dataset
1212
+ for root, _, files in os.walk(PRELOAD_CACHE_DIR):
1213
+ for file in files:
1214
+ if file.endswith('.wav'):
1215
+ local_path = os.path.join(root, file)
1216
+ rel_path = os.path.relpath(local_path, PRELOAD_CACHE_DIR)
1217
+ remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
1218
 
 
1219
  try:
1220
+ # 计算本地文件SHA256,用于检查是否需要上传
1221
+ with open(local_path, 'rb') as f:
1222
+ file_hash = hashlib.sha256(f.read()).hexdigest()
1223
+
1224
+ # 尝试获取远程文件信息
1225
+ try:
1226
+ remote_info = api.get_paths_info(
1227
+ repo_id=REFERENCE_AUDIO_DATASET,
1228
+ repo_type="dataset",
1229
+ path=[remote_path],expand=True)
1230
+ remote_etag = remote_info[0].lfs.sha256 if remote_info and remote_info[0].lfs else None
1231
+ # 如果远程文件存在且hash相同,则跳过
1232
+ if remote_etag and remote_etag == file_hash:
1233
+ app.logger.debug(f"Skipping upload for {remote_path}: file unchanged")
1234
+ continue
1235
+ except Exception as e:
1236
+ app.logger.warning(f"Could not get remote info for {remote_path}: {str(e)}")
1237
+ # 上传文件
1238
+ app.logger.info(f"Uploading preload cache file: {remote_path}")
1239
+ api.upload_file(
1240
+ path_or_fileobj=local_path,
1241
+ path_in_repo=remote_path,
1242
  repo_id=REFERENCE_AUDIO_DATASET,
1243
  repo_type="dataset",
1244
+ commit_message=f"Upload preload cache file: {os.path.basename(file)}"
1245
+ )
1246
+ app.logger.info(f"Successfully uploaded {remote_path}")
1247
+ except Exception as e:
1248
+ app.logger.error(f"Error uploading {remote_path}: {str(e)}")
 
 
 
1249
 
1250
+ except Exception as e:
1251
+ print(f"Error syncing cache audios with HF: {e}")
1252
+ app.logger.error(f"Error syncing cache audios with HF: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
1253
 
1254
  # Schedule periodic tasks
1255
  scheduler = BackgroundScheduler()
 
1392
  return hashlib.md5(key_str.encode('utf-8')).hexdigest()
1393
 
1394
 
 
 
 
1395
  if __name__ == "__main__":
1396
  with app.app_context():
1397
  # Ensure ./instance and ./votes directories exist