kemuriririn commited on
Commit
58071a6
·
1 Parent(s): 253ae60

(wip)debug

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. templates/arena.html +0 -7
  3. tts.py +17 -31
app.py CHANGED
@@ -692,7 +692,7 @@ def generate_tts():
692
  # 清理临时参考音频文件
693
  if reference_audio_path and os.path.exists(reference_audio_path):
694
  os.remove(reference_audio_path)
695
- return jsonify({"error": "Failed to generate TTS"}), 500
696
  # --- End Cache Miss ---
697
 
698
 
 
692
  # 清理临时参考音频文件
693
  if reference_audio_path and os.path.exists(reference_audio_path):
694
  os.remove(reference_audio_path)
695
+ return jsonify({"error": f"Failed to generate TTS:{str(e)}"}), 500
696
  # --- End Cache Miss ---
697
 
698
 
templates/arena.html CHANGED
@@ -869,13 +869,6 @@
869
  return response.json();
870
  })
871
  .then(data => {
872
- if (data.error) {
873
- // 显示错误信息并重置界面
874
- loadingContainer.style.display = 'none';
875
- initialKeyboardHint.style.display = 'block';
876
- openToast(data.error, "error");
877
- return;
878
- }
879
  currentSessionId = data.session_id;
880
 
881
  // Load audio in waveplayers
 
869
  return response.json();
870
  })
871
  .then(data => {
 
 
 
 
 
 
 
872
  currentSessionId = data.session_id;
873
 
874
  // Load audio in waveplayers
tts.py CHANGED
@@ -47,7 +47,7 @@ data = {"text": "string", "provider": "string", "model": "string"}
47
 
48
  def predict_index_tts(text, reference_audio_path=None):
49
  from gradio_client import Client, handle_file
50
- client = Client("kemuriririn/IndexTTS",hf_token=os.getenv("HF_TOKEN"))
51
  if reference_audio_path:
52
  prompt = handle_file(reference_audio_path)
53
  else:
@@ -65,7 +65,7 @@ def predict_index_tts(text, reference_audio_path=None):
65
 
66
  def predict_spark_tts(text, reference_audio_path=None):
67
  from gradio_client import Client, handle_file
68
- client = Client("kemuriririn/SparkTTS",hf_token=os.getenv("HF_TOKEN"))
69
  prompt_wav = None
70
  if reference_audio_path:
71
  prompt_wav = handle_file(reference_audio_path)
@@ -82,7 +82,7 @@ def predict_spark_tts(text, reference_audio_path=None):
82
 
83
  def predict_cosyvoice_tts(text, reference_audio_path=None):
84
  from gradio_client import Client, file, handle_file
85
- client = Client("kemuriririn/CosyVoice2-0.5B",hf_token=os.getenv("HF_TOKEN"))
86
  if not reference_audio_path:
87
  raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
88
  prompt_wav = handle_file(reference_audio_path)
@@ -125,7 +125,7 @@ def predict_maskgct(text, reference_audio_path=None):
125
 
126
  def predict_gpt_sovits_v2(text, reference_audio_path=None):
127
  from gradio_client import Client, file
128
- client = Client("kemuriririn/GPT-SoVITS-v2",hf_token=os.getenv("HF_TOKEN"))
129
  if not reference_audio_path:
130
  raise ValueError("GPT-SoVITS-v2 需要 reference_audio_path")
131
  result = client.predict(
@@ -152,33 +152,19 @@ def predict_tts(text, model, reference_audio_path=None):
152
  global client
153
  print(f"Predicting TTS for {model}")
154
  # Exceptions: special models that shouldn't be passed to the router
155
- try:
156
- if model == "index-tts":
157
- result = predict_index_tts(text, reference_audio_path)
158
- elif model == "spark-tts":
159
- result = predict_spark_tts(text, reference_audio_path)
160
- elif model == "cosyvoice-2.0":
161
- result = predict_cosyvoice_tts(text, reference_audio_path)
162
- elif model == "maskgct":
163
- result = predict_maskgct(text, reference_audio_path)
164
- elif model == "gpt-sovits-v2":
165
- result = predict_gpt_sovits_v2(text, reference_audio_path)
166
- else:
167
- raise ValueError(f"Model {model} not found")
168
-
169
- if isinstance(result, dict) and "error" in result:
170
- return result
171
-
172
- return result
173
- except AppError as e:
174
- error_message = str(e)
175
- print(f"Gradio客户端错误: {error_message}")
176
- return {"error": error_message}
177
- except Exception as e:
178
- error_message = str(e)
179
- print(f"生成失败: {error_message}")
180
- return {"error": error_message}
181
-
182
 
183
  if __name__ == "__main__":
184
  pass
 
47
 
48
  def predict_index_tts(text, reference_audio_path=None):
49
  from gradio_client import Client, handle_file
50
+ client = Client("kemuriririn/IndexTTS", hf_token=os.getenv("HF_TOKEN"))
51
  if reference_audio_path:
52
  prompt = handle_file(reference_audio_path)
53
  else:
 
65
 
66
  def predict_spark_tts(text, reference_audio_path=None):
67
  from gradio_client import Client, handle_file
68
+ client = Client("kemuriririn/SparkTTS", hf_token=os.getenv("HF_TOKEN"))
69
  prompt_wav = None
70
  if reference_audio_path:
71
  prompt_wav = handle_file(reference_audio_path)
 
82
 
83
  def predict_cosyvoice_tts(text, reference_audio_path=None):
84
  from gradio_client import Client, file, handle_file
85
+ client = Client("kemuriririn/CosyVoice2-0.5B", hf_token=os.getenv("HF_TOKEN"))
86
  if not reference_audio_path:
87
  raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
88
  prompt_wav = handle_file(reference_audio_path)
 
125
 
126
  def predict_gpt_sovits_v2(text, reference_audio_path=None):
127
  from gradio_client import Client, file
128
+ client = Client("kemuriririn/GPT-SoVITS-v2", hf_token=os.getenv("HF_TOKEN"))
129
  if not reference_audio_path:
130
  raise ValueError("GPT-SoVITS-v2 需要 reference_audio_path")
131
  result = client.predict(
 
152
  global client
153
  print(f"Predicting TTS for {model}")
154
  # Exceptions: special models that shouldn't be passed to the router
155
+ if model == "index-tts":
156
+ result = predict_index_tts(text, reference_audio_path)
157
+ elif model == "spark-tts":
158
+ result = predict_spark_tts(text, reference_audio_path)
159
+ elif model == "cosyvoice-2.0":
160
+ result = predict_cosyvoice_tts(text, reference_audio_path)
161
+ elif model == "maskgct":
162
+ result = predict_maskgct(text, reference_audio_path)
163
+ elif model == "gpt-sovits-v2":
164
+ result = predict_gpt_sovits_v2(text, reference_audio_path)
165
+ else:
166
+ raise ValueError(f"Model {model} not found")
167
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  if __name__ == "__main__":
170
  pass