kemuriririn commited on
Commit
1bd39bb
·
1 Parent(s): a528625

update for random ref voice

Browse files
Files changed (2) hide show
  1. app.py +8 -254
  2. templates/arena.html +0 -0
app.py CHANGED
@@ -850,260 +850,6 @@ def cleanup_session(session_id):
850
  # Remove session
851
  del app.tts_sessions[session_id]
852
 
853
-
854
- @app.route("/api/conversational/generate", methods=["POST"])
855
- @limiter.limit("5 per minute")
856
- def generate_podcast():
857
- # If verification not setup, handle it first
858
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
859
- return jsonify({"error": "Turnstile verification required"}), 403
860
-
861
- data = request.json
862
- script = data.get("script")
863
-
864
- if not script or not isinstance(script, list) or len(script) < 2:
865
- return jsonify({"error": "Invalid script format or too short"}), 400
866
-
867
- # Validate script format
868
- for line in script:
869
- if not isinstance(line, dict) or "text" not in line or "speaker_id" not in line:
870
- return (
871
- jsonify(
872
- {
873
- "error": "Invalid script line format. Each line must have text and speaker_id"
874
- }
875
- ),
876
- 400,
877
- )
878
- if (
879
- not line["text"]
880
- or not isinstance(line["speaker_id"], int)
881
- or line["speaker_id"] not in [0, 1]
882
- ):
883
- return (
884
- jsonify({"error": "Invalid script content. Speaker ID must be 0 or 1"}),
885
- 400,
886
- )
887
-
888
- # Get two conversational models (currently only CSM and PlayDialog)
889
- available_models = Model.query.filter_by(
890
- model_type=ModelType.CONVERSATIONAL, is_active=True
891
- ).all()
892
-
893
- if len(available_models) < 2:
894
- return jsonify({"error": "Not enough conversational models available"}), 500
895
-
896
- selected_models = get_weighted_random_models(available_models, 2, ModelType.CONVERSATIONAL)
897
-
898
- try:
899
- # Generate audio for both models concurrently
900
- audio_files = []
901
- model_ids = []
902
-
903
- # Function to process a single model
904
- def process_model(model):
905
- # Call conversational TTS service
906
- audio_content = predict_tts(script, model.id)
907
-
908
- # Save to temp file with unique name
909
- file_uuid = str(uuid.uuid4())
910
- dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
911
-
912
- with open(dest_path, "wb") as f:
913
- f.write(audio_content)
914
-
915
- return {"model_id": model.id, "audio_path": dest_path}
916
-
917
- # Use ThreadPoolExecutor to process models concurrently
918
- with ThreadPoolExecutor(max_workers=2) as executor:
919
- results = list(executor.map(process_model, selected_models))
920
-
921
- # Extract results
922
- for result in results:
923
- model_ids.append(result["model_id"])
924
- audio_files.append(result["audio_path"])
925
-
926
- # Create session
927
- session_id = str(uuid.uuid4())
928
- script_text = " ".join([line["text"] for line in script])
929
- app.conversational_sessions[session_id] = {
930
- "model_a": model_ids[0],
931
- "model_b": model_ids[1],
932
- "audio_a": audio_files[0],
933
- "audio_b": audio_files[1],
934
- "text": script_text[:1000], # Limit text length
935
- "created_at": datetime.utcnow(),
936
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
937
- "voted": False,
938
- "script": script,
939
- }
940
-
941
- # Return audio file paths and session
942
- return jsonify(
943
- {
944
- "session_id": session_id,
945
- "audio_a": f"/api/conversational/audio/{session_id}/a",
946
- "audio_b": f"/api/conversational/audio/{session_id}/b",
947
- "expires_in": 1800, # 30 minutes in seconds
948
- }
949
- )
950
-
951
- except Exception as e:
952
- app.logger.error(f"Conversational generation error: {str(e)}")
953
- return jsonify({"error": f"Failed to generate podcast: {str(e)}"}), 500
954
-
955
-
956
- @app.route("/api/conversational/audio/<session_id>/<model_key>")
957
- def get_podcast_audio(session_id, model_key):
958
- # If verification not setup, handle it first
959
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
960
- return jsonify({"error": "Turnstile verification required"}), 403
961
-
962
- if session_id not in app.conversational_sessions:
963
- return jsonify({"error": "Invalid or expired session"}), 404
964
-
965
- session_data = app.conversational_sessions[session_id]
966
-
967
- # Check if session expired
968
- if datetime.utcnow() > session_data["expires_at"]:
969
- cleanup_conversational_session(session_id)
970
- return jsonify({"error": "Session expired"}), 410
971
-
972
- if model_key == "a":
973
- audio_path = session_data["audio_a"]
974
- elif model_key == "b":
975
- audio_path = session_data["audio_b"]
976
- else:
977
- return jsonify({"error": "Invalid model key"}), 400
978
-
979
- # Check if file exists
980
- if not os.path.exists(audio_path):
981
- return jsonify({"error": "Audio file not found"}), 404
982
-
983
- return send_file(audio_path, mimetype="audio/wav")
984
-
985
-
986
- @app.route("/api/conversational/vote", methods=["POST"])
987
- @limiter.limit("30 per minute")
988
- def submit_podcast_vote():
989
- # If verification not setup, handle it first
990
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
991
- return jsonify({"error": "Turnstile verification required"}), 403
992
-
993
- data = request.json
994
- session_id = data.get("session_id")
995
- chosen_model_key = data.get("chosen_model") # "a" or "b"
996
-
997
- if not session_id or session_id not in app.conversational_sessions:
998
- return jsonify({"error": "Invalid or expired session"}), 404
999
-
1000
- if not chosen_model_key or chosen_model_key not in ["a", "b"]:
1001
- return jsonify({"error": "Invalid chosen model"}), 400
1002
-
1003
- session_data = app.conversational_sessions[session_id]
1004
-
1005
- # Check if session expired
1006
- if datetime.utcnow() > session_data["expires_at"]:
1007
- cleanup_conversational_session(session_id)
1008
- return jsonify({"error": "Session expired"}), 410
1009
-
1010
- # Check if already voted
1011
- if session_data["voted"]:
1012
- return jsonify({"error": "Vote already submitted for this session"}), 400
1013
-
1014
- # Get model IDs and audio paths
1015
- chosen_id = (
1016
- session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
1017
- )
1018
- rejected_id = (
1019
- session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
1020
- )
1021
- chosen_audio_path = (
1022
- session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
1023
- )
1024
- rejected_audio_path = (
1025
- session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
1026
- )
1027
-
1028
- # Record vote in database
1029
- user_id = current_user.id if current_user.is_authenticated else None
1030
- vote, error = record_vote(
1031
- user_id, session_data["text"], chosen_id, rejected_id, ModelType.CONVERSATIONAL
1032
- )
1033
-
1034
- if error:
1035
- return jsonify({"error": error}), 500
1036
-
1037
- # --- Save preference data ---\
1038
- try:
1039
- vote_uuid = str(uuid.uuid4())
1040
- vote_dir = os.path.join("./votes", vote_uuid)
1041
- os.makedirs(vote_dir, exist_ok=True)
1042
-
1043
- # Copy audio files
1044
- shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
1045
- shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
1046
-
1047
- # Create metadata
1048
- chosen_model_obj = Model.query.get(chosen_id)
1049
- rejected_model_obj = Model.query.get(rejected_id)
1050
- metadata = {
1051
- "script": session_data["script"], # Save the full script
1052
- "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
1053
- "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
1054
- "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
1055
- "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
1056
- "session_id": session_id,
1057
- "timestamp": datetime.utcnow().isoformat(),
1058
- "username": current_user.username if current_user.is_authenticated else None,
1059
- "model_type": "CONVERSATIONAL"
1060
- }
1061
- with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
1062
- json.dump(metadata, f, indent=2)
1063
-
1064
- except Exception as e:
1065
- app.logger.error(f"Error saving preference data for conversational vote {session_id}: {str(e)}")
1066
- # Continue even if saving preference data fails, vote is already recorded
1067
-
1068
- # Mark session as voted
1069
- session_data["voted"] = True
1070
-
1071
- # Return updated models (use previously fetched objects)
1072
- return jsonify(
1073
- {
1074
- "success": True,
1075
- "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
1076
- "rejected_model": {
1077
- "id": rejected_id,
1078
- "name": rejected_model_obj.name if rejected_model_obj else "Unknown",
1079
- },
1080
- "names": {
1081
- "a": Model.query.get(session_data["model_a"]).name,
1082
- "b": Model.query.get(session_data["model_b"]).name,
1083
- },
1084
- }
1085
- )
1086
-
1087
-
1088
- def cleanup_conversational_session(session_id):
1089
- """Remove conversational session and its audio files"""
1090
- if session_id in app.conversational_sessions:
1091
- session = app.conversational_sessions[session_id]
1092
-
1093
- # Remove audio files
1094
- for audio_file in [session["audio_a"], session["audio_b"]]:
1095
- if os.path.exists(audio_file):
1096
- try:
1097
- os.remove(audio_file)
1098
- except Exception as e:
1099
- app.logger.error(
1100
- f"Error removing conversational audio file: {str(e)}"
1101
- )
1102
-
1103
- # Remove session
1104
- del app.conversational_sessions[session_id]
1105
-
1106
-
1107
  # Schedule periodic cleanup
1108
  def setup_cleanup():
1109
  def cleanup_expired_sessions():
@@ -1375,6 +1121,14 @@ def get_reference_audio(filename):
1375
  return jsonify({"error": "Reference audio not found"}), 404
1376
  return send_file(file_path, mimetype="audio/wav")
1377
 
 
 
 
 
 
 
 
 
1378
 
1379
  def get_weighted_random_models(
1380
  applicable_models: list[Model], num_to_select: int, model_type: ModelType
 
850
  # Remove session
851
  del app.tts_sessions[session_id]
852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853
  # Schedule periodic cleanup
854
  def setup_cleanup():
855
  def cleanup_expired_sessions():
 
1121
  return jsonify({"error": "Reference audio not found"}), 404
1122
  return send_file(file_path, mimetype="audio/wav")
1123
 
1124
+ @app.route('/api/voice/random', methods=['GET'])
1125
+ def get_random_voice():
1126
+ # 随机选择一个音频文件
1127
+ random_voice = random.choice(reference_audio_files)
1128
+ voice_path = os.path.join(REFERENCE_AUDIO_DIR, random_voice)
1129
+
1130
+ # 返回音频文件
1131
+ return send_file(voice_path, mimetype='audio/' + voice_path.split('.')[-1])
1132
 
1133
  def get_weighted_random_models(
1134
  applicable_models: list[Model], num_to_select: int, model_type: ModelType
templates/arena.html CHANGED
The diff for this file is too large to render. See raw diff