Spaces:
Running
Running
Commit
·
1bd39bb
1
Parent(s):
a528625
update for random ref voice
Browse files- app.py +8 -254
- templates/arena.html +0 -0
app.py
CHANGED
@@ -850,260 +850,6 @@ def cleanup_session(session_id):
|
|
850 |
# Remove session
|
851 |
del app.tts_sessions[session_id]
|
852 |
|
853 |
-
|
854 |
-
@app.route("/api/conversational/generate", methods=["POST"])
|
855 |
-
@limiter.limit("5 per minute")
|
856 |
-
def generate_podcast():
|
857 |
-
# If verification not setup, handle it first
|
858 |
-
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
|
859 |
-
return jsonify({"error": "Turnstile verification required"}), 403
|
860 |
-
|
861 |
-
data = request.json
|
862 |
-
script = data.get("script")
|
863 |
-
|
864 |
-
if not script or not isinstance(script, list) or len(script) < 2:
|
865 |
-
return jsonify({"error": "Invalid script format or too short"}), 400
|
866 |
-
|
867 |
-
# Validate script format
|
868 |
-
for line in script:
|
869 |
-
if not isinstance(line, dict) or "text" not in line or "speaker_id" not in line:
|
870 |
-
return (
|
871 |
-
jsonify(
|
872 |
-
{
|
873 |
-
"error": "Invalid script line format. Each line must have text and speaker_id"
|
874 |
-
}
|
875 |
-
),
|
876 |
-
400,
|
877 |
-
)
|
878 |
-
if (
|
879 |
-
not line["text"]
|
880 |
-
or not isinstance(line["speaker_id"], int)
|
881 |
-
or line["speaker_id"] not in [0, 1]
|
882 |
-
):
|
883 |
-
return (
|
884 |
-
jsonify({"error": "Invalid script content. Speaker ID must be 0 or 1"}),
|
885 |
-
400,
|
886 |
-
)
|
887 |
-
|
888 |
-
# Get two conversational models (currently only CSM and PlayDialog)
|
889 |
-
available_models = Model.query.filter_by(
|
890 |
-
model_type=ModelType.CONVERSATIONAL, is_active=True
|
891 |
-
).all()
|
892 |
-
|
893 |
-
if len(available_models) < 2:
|
894 |
-
return jsonify({"error": "Not enough conversational models available"}), 500
|
895 |
-
|
896 |
-
selected_models = get_weighted_random_models(available_models, 2, ModelType.CONVERSATIONAL)
|
897 |
-
|
898 |
-
try:
|
899 |
-
# Generate audio for both models concurrently
|
900 |
-
audio_files = []
|
901 |
-
model_ids = []
|
902 |
-
|
903 |
-
# Function to process a single model
|
904 |
-
def process_model(model):
|
905 |
-
# Call conversational TTS service
|
906 |
-
audio_content = predict_tts(script, model.id)
|
907 |
-
|
908 |
-
# Save to temp file with unique name
|
909 |
-
file_uuid = str(uuid.uuid4())
|
910 |
-
dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
|
911 |
-
|
912 |
-
with open(dest_path, "wb") as f:
|
913 |
-
f.write(audio_content)
|
914 |
-
|
915 |
-
return {"model_id": model.id, "audio_path": dest_path}
|
916 |
-
|
917 |
-
# Use ThreadPoolExecutor to process models concurrently
|
918 |
-
with ThreadPoolExecutor(max_workers=2) as executor:
|
919 |
-
results = list(executor.map(process_model, selected_models))
|
920 |
-
|
921 |
-
# Extract results
|
922 |
-
for result in results:
|
923 |
-
model_ids.append(result["model_id"])
|
924 |
-
audio_files.append(result["audio_path"])
|
925 |
-
|
926 |
-
# Create session
|
927 |
-
session_id = str(uuid.uuid4())
|
928 |
-
script_text = " ".join([line["text"] for line in script])
|
929 |
-
app.conversational_sessions[session_id] = {
|
930 |
-
"model_a": model_ids[0],
|
931 |
-
"model_b": model_ids[1],
|
932 |
-
"audio_a": audio_files[0],
|
933 |
-
"audio_b": audio_files[1],
|
934 |
-
"text": script_text[:1000], # Limit text length
|
935 |
-
"created_at": datetime.utcnow(),
|
936 |
-
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
937 |
-
"voted": False,
|
938 |
-
"script": script,
|
939 |
-
}
|
940 |
-
|
941 |
-
# Return audio file paths and session
|
942 |
-
return jsonify(
|
943 |
-
{
|
944 |
-
"session_id": session_id,
|
945 |
-
"audio_a": f"/api/conversational/audio/{session_id}/a",
|
946 |
-
"audio_b": f"/api/conversational/audio/{session_id}/b",
|
947 |
-
"expires_in": 1800, # 30 minutes in seconds
|
948 |
-
}
|
949 |
-
)
|
950 |
-
|
951 |
-
except Exception as e:
|
952 |
-
app.logger.error(f"Conversational generation error: {str(e)}")
|
953 |
-
return jsonify({"error": f"Failed to generate podcast: {str(e)}"}), 500
|
954 |
-
|
955 |
-
|
956 |
-
@app.route("/api/conversational/audio/<session_id>/<model_key>")
|
957 |
-
def get_podcast_audio(session_id, model_key):
|
958 |
-
# If verification not setup, handle it first
|
959 |
-
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
|
960 |
-
return jsonify({"error": "Turnstile verification required"}), 403
|
961 |
-
|
962 |
-
if session_id not in app.conversational_sessions:
|
963 |
-
return jsonify({"error": "Invalid or expired session"}), 404
|
964 |
-
|
965 |
-
session_data = app.conversational_sessions[session_id]
|
966 |
-
|
967 |
-
# Check if session expired
|
968 |
-
if datetime.utcnow() > session_data["expires_at"]:
|
969 |
-
cleanup_conversational_session(session_id)
|
970 |
-
return jsonify({"error": "Session expired"}), 410
|
971 |
-
|
972 |
-
if model_key == "a":
|
973 |
-
audio_path = session_data["audio_a"]
|
974 |
-
elif model_key == "b":
|
975 |
-
audio_path = session_data["audio_b"]
|
976 |
-
else:
|
977 |
-
return jsonify({"error": "Invalid model key"}), 400
|
978 |
-
|
979 |
-
# Check if file exists
|
980 |
-
if not os.path.exists(audio_path):
|
981 |
-
return jsonify({"error": "Audio file not found"}), 404
|
982 |
-
|
983 |
-
return send_file(audio_path, mimetype="audio/wav")
|
984 |
-
|
985 |
-
|
986 |
-
@app.route("/api/conversational/vote", methods=["POST"])
|
987 |
-
@limiter.limit("30 per minute")
|
988 |
-
def submit_podcast_vote():
|
989 |
-
# If verification not setup, handle it first
|
990 |
-
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
|
991 |
-
return jsonify({"error": "Turnstile verification required"}), 403
|
992 |
-
|
993 |
-
data = request.json
|
994 |
-
session_id = data.get("session_id")
|
995 |
-
chosen_model_key = data.get("chosen_model") # "a" or "b"
|
996 |
-
|
997 |
-
if not session_id or session_id not in app.conversational_sessions:
|
998 |
-
return jsonify({"error": "Invalid or expired session"}), 404
|
999 |
-
|
1000 |
-
if not chosen_model_key or chosen_model_key not in ["a", "b"]:
|
1001 |
-
return jsonify({"error": "Invalid chosen model"}), 400
|
1002 |
-
|
1003 |
-
session_data = app.conversational_sessions[session_id]
|
1004 |
-
|
1005 |
-
# Check if session expired
|
1006 |
-
if datetime.utcnow() > session_data["expires_at"]:
|
1007 |
-
cleanup_conversational_session(session_id)
|
1008 |
-
return jsonify({"error": "Session expired"}), 410
|
1009 |
-
|
1010 |
-
# Check if already voted
|
1011 |
-
if session_data["voted"]:
|
1012 |
-
return jsonify({"error": "Vote already submitted for this session"}), 400
|
1013 |
-
|
1014 |
-
# Get model IDs and audio paths
|
1015 |
-
chosen_id = (
|
1016 |
-
session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
|
1017 |
-
)
|
1018 |
-
rejected_id = (
|
1019 |
-
session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
|
1020 |
-
)
|
1021 |
-
chosen_audio_path = (
|
1022 |
-
session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
|
1023 |
-
)
|
1024 |
-
rejected_audio_path = (
|
1025 |
-
session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
|
1026 |
-
)
|
1027 |
-
|
1028 |
-
# Record vote in database
|
1029 |
-
user_id = current_user.id if current_user.is_authenticated else None
|
1030 |
-
vote, error = record_vote(
|
1031 |
-
user_id, session_data["text"], chosen_id, rejected_id, ModelType.CONVERSATIONAL
|
1032 |
-
)
|
1033 |
-
|
1034 |
-
if error:
|
1035 |
-
return jsonify({"error": error}), 500
|
1036 |
-
|
1037 |
-
# --- Save preference data ---\
|
1038 |
-
try:
|
1039 |
-
vote_uuid = str(uuid.uuid4())
|
1040 |
-
vote_dir = os.path.join("./votes", vote_uuid)
|
1041 |
-
os.makedirs(vote_dir, exist_ok=True)
|
1042 |
-
|
1043 |
-
# Copy audio files
|
1044 |
-
shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
|
1045 |
-
shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
|
1046 |
-
|
1047 |
-
# Create metadata
|
1048 |
-
chosen_model_obj = Model.query.get(chosen_id)
|
1049 |
-
rejected_model_obj = Model.query.get(rejected_id)
|
1050 |
-
metadata = {
|
1051 |
-
"script": session_data["script"], # Save the full script
|
1052 |
-
"chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
|
1053 |
-
"chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
|
1054 |
-
"rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
|
1055 |
-
"rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
|
1056 |
-
"session_id": session_id,
|
1057 |
-
"timestamp": datetime.utcnow().isoformat(),
|
1058 |
-
"username": current_user.username if current_user.is_authenticated else None,
|
1059 |
-
"model_type": "CONVERSATIONAL"
|
1060 |
-
}
|
1061 |
-
with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
|
1062 |
-
json.dump(metadata, f, indent=2)
|
1063 |
-
|
1064 |
-
except Exception as e:
|
1065 |
-
app.logger.error(f"Error saving preference data for conversational vote {session_id}: {str(e)}")
|
1066 |
-
# Continue even if saving preference data fails, vote is already recorded
|
1067 |
-
|
1068 |
-
# Mark session as voted
|
1069 |
-
session_data["voted"] = True
|
1070 |
-
|
1071 |
-
# Return updated models (use previously fetched objects)
|
1072 |
-
return jsonify(
|
1073 |
-
{
|
1074 |
-
"success": True,
|
1075 |
-
"chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
|
1076 |
-
"rejected_model": {
|
1077 |
-
"id": rejected_id,
|
1078 |
-
"name": rejected_model_obj.name if rejected_model_obj else "Unknown",
|
1079 |
-
},
|
1080 |
-
"names": {
|
1081 |
-
"a": Model.query.get(session_data["model_a"]).name,
|
1082 |
-
"b": Model.query.get(session_data["model_b"]).name,
|
1083 |
-
},
|
1084 |
-
}
|
1085 |
-
)
|
1086 |
-
|
1087 |
-
|
1088 |
-
def cleanup_conversational_session(session_id):
|
1089 |
-
"""Remove conversational session and its audio files"""
|
1090 |
-
if session_id in app.conversational_sessions:
|
1091 |
-
session = app.conversational_sessions[session_id]
|
1092 |
-
|
1093 |
-
# Remove audio files
|
1094 |
-
for audio_file in [session["audio_a"], session["audio_b"]]:
|
1095 |
-
if os.path.exists(audio_file):
|
1096 |
-
try:
|
1097 |
-
os.remove(audio_file)
|
1098 |
-
except Exception as e:
|
1099 |
-
app.logger.error(
|
1100 |
-
f"Error removing conversational audio file: {str(e)}"
|
1101 |
-
)
|
1102 |
-
|
1103 |
-
# Remove session
|
1104 |
-
del app.conversational_sessions[session_id]
|
1105 |
-
|
1106 |
-
|
1107 |
# Schedule periodic cleanup
|
1108 |
def setup_cleanup():
|
1109 |
def cleanup_expired_sessions():
|
@@ -1375,6 +1121,14 @@ def get_reference_audio(filename):
|
|
1375 |
return jsonify({"error": "Reference audio not found"}), 404
|
1376 |
return send_file(file_path, mimetype="audio/wav")
|
1377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1378 |
|
1379 |
def get_weighted_random_models(
|
1380 |
applicable_models: list[Model], num_to_select: int, model_type: ModelType
|
|
|
850 |
# Remove session
|
851 |
del app.tts_sessions[session_id]
|
852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
853 |
# Schedule periodic cleanup
|
854 |
def setup_cleanup():
|
855 |
def cleanup_expired_sessions():
|
|
|
1121 |
return jsonify({"error": "Reference audio not found"}), 404
|
1122 |
return send_file(file_path, mimetype="audio/wav")
|
1123 |
|
1124 |
+
@app.route('/api/voice/random', methods=['GET'])
|
1125 |
+
def get_random_voice():
|
1126 |
+
# 随机选择一个音频文件
|
1127 |
+
random_voice = random.choice(reference_audio_files)
|
1128 |
+
voice_path = os.path.join(REFERENCE_AUDIO_DIR, random_voice)
|
1129 |
+
|
1130 |
+
# 返回音频文件
|
1131 |
+
return send_file(voice_path, mimetype='audio/' + voice_path.split('.')[-1])
|
1132 |
|
1133 |
def get_weighted_random_models(
|
1134 |
applicable_models: list[Model], num_to_select: int, model_type: ModelType
|
templates/arena.html
CHANGED
The diff for this file is too large to render.
See raw diff
|
|