Spaces:
Running
Running
Commit
·
8ed90ba
1
Parent(s):
3fb07cc
(wip)update file cache pipeline
Browse files
app.py
CHANGED
@@ -4,6 +4,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
4 |
from concurrent.futures import ThreadPoolExecutor
|
5 |
from datetime import datetime
|
6 |
import threading # Added for locking
|
|
|
|
|
7 |
from sqlalchemy import or_ # Added for vote counting query
|
8 |
import hashlib
|
9 |
|
@@ -42,7 +44,6 @@ from flask import (
|
|
42 |
redirect,
|
43 |
url_for,
|
44 |
session,
|
45 |
-
abort,
|
46 |
)
|
47 |
from flask_login import LoginManager, current_user
|
48 |
from models import *
|
@@ -61,8 +62,6 @@ import json
|
|
61 |
from datetime import datetime, timedelta
|
62 |
from flask_migrate import Migrate
|
63 |
import requests
|
64 |
-
import functools
|
65 |
-
import time # Added for potential retries
|
66 |
|
67 |
# Load environment variables
|
68 |
if not IS_SPACES:
|
@@ -118,6 +117,7 @@ TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
|
|
118 |
CACHE_AUDIO_SUBDIR = "cache"
|
119 |
tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
|
120 |
tts_cache_lock = threading.Lock()
|
|
|
121 |
SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
|
122 |
# Increased max_workers to 8 for concurrent generation/refill
|
123 |
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
|
@@ -371,6 +371,7 @@ with open("init_sentences.txt", "r") as f:
|
|
371 |
initial_sentences = random.sample(all_harvard_sentences,
|
372 |
min(len(all_harvard_sentences), 500)) # Limit initial pass for template
|
373 |
|
|
|
374 |
@app.route("/")
|
375 |
def arena():
|
376 |
# Pass a subset of sentences for the random button fallback
|
@@ -616,6 +617,11 @@ def generate_tts():
|
|
616 |
if not text or len(text) > 1000:
|
617 |
return jsonify({"error": "Invalid or too long text"}), 400
|
618 |
|
|
|
|
|
|
|
|
|
|
|
619 |
# --- Cache Check ---
|
620 |
cache_hit = False
|
621 |
session_data_from_cache = None
|
@@ -662,7 +668,31 @@ def generate_tts():
|
|
662 |
return jsonify({"error": "Not enough TTS models available"}), 500
|
663 |
|
664 |
selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
|
|
|
|
|
|
|
665 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
666 |
try:
|
667 |
audio_files = []
|
668 |
model_ids = []
|
@@ -716,15 +746,16 @@ def generate_tts():
|
|
716 |
|
717 |
# Check if text and prompt are in predefined libraries
|
718 |
if text in predefined_texts and prompt_md5 in predefined_prompts.values():
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
|
|
723 |
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
|
729 |
# Return audio file paths and session
|
730 |
return jsonify(
|
@@ -1120,98 +1151,105 @@ def setup_periodic_tasks():
|
|
1120 |
同步缓存音频到HF dataset并从HF下载更新的缓存音频。
|
1121 |
"""
|
1122 |
os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
# 获取带有 etag 的文件列表
|
1127 |
-
files_info = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset", expand=True)
|
1128 |
-
# 只处理cache_audios/下的wav文件
|
1129 |
-
wav_files = [f for f in files_info if
|
1130 |
-
f["rfilename"].startswith(CACHE_AUDIO_PATTERN) and f["rfilename"].endswith(".wav")]
|
1131 |
-
|
1132 |
-
# 获取本地已有文件名及hash集合
|
1133 |
-
local_hashes = {}
|
1134 |
-
for root, _, filenames in os.walk(PRELOAD_CACHE_DIR):
|
1135 |
-
for fname in filenames:
|
1136 |
-
if fname.endswith(".wav"):
|
1137 |
-
rel_path = os.path.relpath(os.path.join(root, fname), PRELOAD_CACHE_DIR)
|
1138 |
-
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
1139 |
-
local_file_path = os.path.join(root, fname)
|
1140 |
-
# 计算本地文件md5
|
1141 |
-
try:
|
1142 |
-
with open(local_file_path, 'rb') as f:
|
1143 |
-
md5 = hashlib.md5(f.read()).hexdigest()
|
1144 |
-
local_hashes[remote_path] = md5
|
1145 |
-
except Exception:
|
1146 |
-
continue
|
1147 |
-
|
1148 |
-
download_count = 0
|
1149 |
-
for f in wav_files:
|
1150 |
-
remote_path = f["rfilename"]
|
1151 |
-
etag = f.get("lfs", {}).get("oid") or f.get("etag") # 优先lfs oid, 其次etag
|
1152 |
-
local_md5 = local_hashes.get(remote_path)
|
1153 |
-
|
1154 |
-
# 如果远端etag为32位md5且与本地一致,跳过下载
|
1155 |
-
if etag and len(etag) == 32 and local_md5 == etag:
|
1156 |
-
continue
|
1157 |
-
|
1158 |
-
# 下载文件
|
1159 |
-
local_path = hf_hub_download(
|
1160 |
-
repo_id=REFERENCE_AUDIO_DATASET,
|
1161 |
-
filename=remote_path,
|
1162 |
-
repo_type="dataset",
|
1163 |
-
local_dir=PRELOAD_CACHE_DIR,
|
1164 |
-
token=os.getenv("HF_TOKEN"),
|
1165 |
-
force_download=True if local_md5 else False
|
1166 |
-
)
|
1167 |
-
print(f"Downloaded cache audio: {local_path}")
|
1168 |
-
download_count += 1
|
1169 |
|
1170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1171 |
|
1172 |
-
|
1173 |
-
for root, _, files in os.walk(PRELOAD_CACHE_DIR):
|
1174 |
-
for file in files:
|
1175 |
-
if file.endswith('.wav'):
|
1176 |
-
local_path = os.path.join(root, file)
|
1177 |
-
rel_path = os.path.relpath(local_path, PRELOAD_CACHE_DIR)
|
1178 |
-
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
1179 |
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
|
|
|
|
|
|
1184 |
|
1185 |
-
# 尝试获取远程文件信息
|
1186 |
try:
|
1187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1188 |
repo_id=REFERENCE_AUDIO_DATASET,
|
1189 |
repo_type="dataset",
|
1190 |
-
path
|
1191 |
-
|
1192 |
-
|
1193 |
-
|
1194 |
-
|
1195 |
-
continue
|
1196 |
-
except Exception:
|
1197 |
-
pass
|
1198 |
|
1199 |
-
|
1200 |
-
|
1201 |
-
|
1202 |
-
path_or_fileobj=local_path,
|
1203 |
-
path_in_repo=remote_path,
|
1204 |
-
repo_id=REFERENCE_AUDIO_DATASET,
|
1205 |
-
repo_type="dataset",
|
1206 |
-
commit_message=f"Upload preload cache file: {os.path.basename(file)}"
|
1207 |
-
)
|
1208 |
-
app.logger.info(f"Successfully uploaded {remote_path}")
|
1209 |
-
except Exception as e:
|
1210 |
-
app.logger.error(f"Error uploading {remote_path}: {str(e)}")
|
1211 |
-
|
1212 |
-
except Exception as e:
|
1213 |
-
print(f"Error syncing cache audios with HF: {e}")
|
1214 |
-
app.logger.error(f"Error syncing cache audios with HF: {e}")
|
1215 |
|
1216 |
# Schedule periodic tasks
|
1217 |
scheduler = BackgroundScheduler()
|
@@ -1354,9 +1392,6 @@ def get_tts_cache_key(model_name, text, prompt_audio_path):
|
|
1354 |
return hashlib.md5(key_str.encode('utf-8')).hexdigest()
|
1355 |
|
1356 |
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
if __name__ == "__main__":
|
1361 |
with app.app_context():
|
1362 |
# Ensure ./instance and ./votes directories exist
|
|
|
4 |
from concurrent.futures import ThreadPoolExecutor
|
5 |
from datetime import datetime
|
6 |
import threading # Added for locking
|
7 |
+
|
8 |
+
from huggingface_hub.hf_api import RepoFile
|
9 |
from sqlalchemy import or_ # Added for vote counting query
|
10 |
import hashlib
|
11 |
|
|
|
44 |
redirect,
|
45 |
url_for,
|
46 |
session,
|
|
|
47 |
)
|
48 |
from flask_login import LoginManager, current_user
|
49 |
from models import *
|
|
|
62 |
from datetime import datetime, timedelta
|
63 |
from flask_migrate import Migrate
|
64 |
import requests
|
|
|
|
|
65 |
|
66 |
# Load environment variables
|
67 |
if not IS_SPACES:
|
|
|
117 |
CACHE_AUDIO_SUBDIR = "cache"
|
118 |
tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
|
119 |
tts_cache_lock = threading.Lock()
|
120 |
+
preload_cache_lock = threading.Lock()
|
121 |
SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
|
122 |
# Increased max_workers to 8 for concurrent generation/refill
|
123 |
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
|
|
|
371 |
initial_sentences = random.sample(all_harvard_sentences,
|
372 |
min(len(all_harvard_sentences), 500)) # Limit initial pass for template
|
373 |
|
374 |
+
|
375 |
@app.route("/")
|
376 |
def arena():
|
377 |
# Pass a subset of sentences for the random button fallback
|
|
|
617 |
if not text or len(text) > 1000:
|
618 |
return jsonify({"error": "Invalid or too long text"}), 400
|
619 |
|
620 |
+
prompt_md5 = ''
|
621 |
+
if reference_audio_path and os.path.exists(reference_audio_path):
|
622 |
+
with open(reference_audio_path, 'rb') as f:
|
623 |
+
prompt_md5 = hashlib.md5(f.read()).hexdigest()
|
624 |
+
|
625 |
# --- Cache Check ---
|
626 |
cache_hit = False
|
627 |
session_data_from_cache = None
|
|
|
668 |
return jsonify({"error": "Not enough TTS models available"}), 500
|
669 |
|
670 |
selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
|
671 |
+
# 尝试从持久化缓存中查找两个模型的音频
|
672 |
+
audio_a_path = find_cached_audio(str(selected_models[0].id), text, reference_audio_path)
|
673 |
+
audio_b_path = find_cached_audio(str(selected_models[1].id), text, reference_audio_path)
|
674 |
|
675 |
+
if audio_a_path and audio_b_path:
|
676 |
+
app.logger.info(f"Persistent Cache HIT for: '{text[:50]}...'. Using files directly.")
|
677 |
+
session_id = str(uuid.uuid4())
|
678 |
+
app.tts_sessions[session_id] = {
|
679 |
+
"model_a": selected_models[0].id,
|
680 |
+
"model_b": selected_models[1].id,
|
681 |
+
"audio_a": audio_a_path,
|
682 |
+
"audio_b": audio_b_path,
|
683 |
+
"text": text,
|
684 |
+
"created_at": datetime.utcnow(),
|
685 |
+
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
686 |
+
"voted": False,
|
687 |
+
}
|
688 |
+
return jsonify({
|
689 |
+
"session_id": session_id,
|
690 |
+
"audio_a": f"/api/tts/audio/{session_id}/a",
|
691 |
+
"audio_b": f"/api/tts/audio/{session_id}/b",
|
692 |
+
"expires_in": 1800,
|
693 |
+
"cache_hit": True, # 可以认为这也是一种缓存命中
|
694 |
+
})
|
695 |
+
# --- 持久化缓存检查结束 ---
|
696 |
try:
|
697 |
audio_files = []
|
698 |
model_ids = []
|
|
|
746 |
|
747 |
# Check if text and prompt are in predefined libraries
|
748 |
if text in predefined_texts and prompt_md5 in predefined_prompts.values():
|
749 |
+
with preload_cache_lock:
|
750 |
+
preload_key = get_tts_cache_key(str(model_ids[0]), text, reference_audio_path)
|
751 |
+
preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
|
752 |
+
shutil.copy(audio_files[0], preload_path)
|
753 |
+
app.logger.info(f"Preloaded cache audio saved: {preload_path}")
|
754 |
|
755 |
+
preload_key = get_tts_cache_key(str(model_ids[1]), text, reference_audio_path)
|
756 |
+
preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
|
757 |
+
shutil.copy(audio_files[1], preload_path)
|
758 |
+
app.logger.info(f"Preloaded cache audio saved: {preload_path}")
|
759 |
|
760 |
# Return audio file paths and session
|
761 |
return jsonify(
|
|
|
1151 |
同步缓存音频到HF dataset并从HF下载更新的缓存音频。
|
1152 |
"""
|
1153 |
os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
|
1154 |
+
with preload_cache_lock:
|
1155 |
+
try:
|
1156 |
+
api = HfApi(token=os.getenv("HF_TOKEN"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1157 |
|
1158 |
+
# 获取带有 etag 的文件列表
|
1159 |
+
files_info = [
|
1160 |
+
f
|
1161 |
+
for f in api.list_repo_tree(
|
1162 |
+
repo_id=REFERENCE_AUDIO_DATASET, path_in_repo=CACHE_AUDIO_PATTERN.strip("/"), recursive=True,
|
1163 |
+
repo_type="dataset", expand=True
|
1164 |
+
)
|
1165 |
+
if isinstance(f, RepoFile)
|
1166 |
+
]
|
1167 |
+
# 只处理cache_audios/下的wav文件
|
1168 |
+
wav_files = [f for f in files_info if
|
1169 |
+
f.path.endswith(".wav")]
|
1170 |
+
|
1171 |
+
# 获取本地已有文件名及hash集合
|
1172 |
+
local_hashes = {}
|
1173 |
+
for root, _, filenames in os.walk(PRELOAD_CACHE_DIR):
|
1174 |
+
for fname in filenames:
|
1175 |
+
if fname.endswith(".wav"):
|
1176 |
+
rel_path = os.path.relpath(os.path.join(root, fname), PRELOAD_CACHE_DIR)
|
1177 |
+
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
1178 |
+
local_file_path = os.path.join(root, fname)
|
1179 |
+
# 计算本地文件sha256
|
1180 |
+
try:
|
1181 |
+
with open(local_file_path, 'rb') as f:
|
1182 |
+
file_hash = hashlib.sha256(f.read()).hexdigest()
|
1183 |
+
local_hashes[remote_path] = file_hash
|
1184 |
+
except Exception:
|
1185 |
+
continue
|
1186 |
+
|
1187 |
+
download_count = 0
|
1188 |
+
for f in wav_files:
|
1189 |
+
remote_path = f.path
|
1190 |
+
etag = f.lfs.sha256 if f.lfs else None
|
1191 |
+
local_hash = local_hashes.get(remote_path)
|
1192 |
+
|
1193 |
+
# 如果远端LFS的sha256与本地文件hash一致,跳过下载
|
1194 |
+
if local_hash == etag:
|
1195 |
+
continue
|
1196 |
+
|
1197 |
+
# 下载文件
|
1198 |
+
local_path = hf_hub_download(
|
1199 |
+
repo_id=REFERENCE_AUDIO_DATASET,
|
1200 |
+
filename=remote_path,
|
1201 |
+
repo_type="dataset",
|
1202 |
+
local_dir=PRELOAD_CACHE_DIR,
|
1203 |
+
token=os.getenv("HF_TOKEN"),
|
1204 |
+
force_download=True if local_hash else False
|
1205 |
+
)
|
1206 |
+
print(f"Downloaded cache audio: {local_path}")
|
1207 |
+
download_count += 1
|
1208 |
|
1209 |
+
print(f"Downloaded {download_count} new/updated cache audios from HF to {PRELOAD_CACHE_DIR}")
|
|
|
|
|
|
|
|
|
|
|
|
|
1210 |
|
1211 |
+
# 上传本地文件到HF dataset
|
1212 |
+
for root, _, files in os.walk(PRELOAD_CACHE_DIR):
|
1213 |
+
for file in files:
|
1214 |
+
if file.endswith('.wav'):
|
1215 |
+
local_path = os.path.join(root, file)
|
1216 |
+
rel_path = os.path.relpath(local_path, PRELOAD_CACHE_DIR)
|
1217 |
+
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
1218 |
|
|
|
1219 |
try:
|
1220 |
+
# 计算本地文件SHA-256,用于检查是否需要上传
|
1221 |
+
with open(local_path, 'rb') as f:
|
1222 |
+
file_hash = hashlib.sha256(f.read()).hexdigest()
|
1223 |
+
|
1224 |
+
# 尝试获取远程文件信息
|
1225 |
+
try:
|
1226 |
+
remote_info = api.get_paths_info(
|
1227 |
+
repo_id=REFERENCE_AUDIO_DATASET,
|
1228 |
+
repo_type="dataset",
|
1229 |
+
path=[remote_path],expand=True)
|
1230 |
+
remote_etag = remote_info[0].lfs.sha256 if remote_info and remote_info[0].lfs else None
|
1231 |
+
# 如果远程文件存在且hash相同,则跳过
|
1232 |
+
if remote_etag and remote_etag == file_hash:
|
1233 |
+
app.logger.debug(f"Skipping upload for {remote_path}: file unchanged")
|
1234 |
+
continue
|
1235 |
+
except Exception as e:
|
1236 |
+
app.logger.warning(f"Could not get remote info for {remote_path}: {str(e)}")
|
1237 |
+
# 上传文件
|
1238 |
+
app.logger.info(f"Uploading preload cache file: {remote_path}")
|
1239 |
+
api.upload_file(
|
1240 |
+
path_or_fileobj=local_path,
|
1241 |
+
path_in_repo=remote_path,
|
1242 |
repo_id=REFERENCE_AUDIO_DATASET,
|
1243 |
repo_type="dataset",
|
1244 |
+
commit_message=f"Upload preload cache file: {os.path.basename(file)}"
|
1245 |
+
)
|
1246 |
+
app.logger.info(f"Successfully uploaded {remote_path}")
|
1247 |
+
except Exception as e:
|
1248 |
+
app.logger.error(f"Error uploading {remote_path}: {str(e)}")
|
|
|
|
|
|
|
1249 |
|
1250 |
+
except Exception as e:
|
1251 |
+
print(f"Error syncing cache audios with HF: {e}")
|
1252 |
+
app.logger.error(f"Error syncing cache audios with HF: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1253 |
|
1254 |
# Schedule periodic tasks
|
1255 |
scheduler = BackgroundScheduler()
|
|
|
1392 |
return hashlib.md5(key_str.encode('utf-8')).hexdigest()
|
1393 |
|
1394 |
|
|
|
|
|
|
|
1395 |
if __name__ == "__main__":
|
1396 |
with app.app_context():
|
1397 |
# Ensure ./instance and ./votes directories exist
|