# NOTE: file-viewer capture residue preserved as a comment —
# author: kemuriririn, commit message: "(wip)debug", commit hash: f55b556
import os
from dotenv import load_dotenv
import random
from fal_client import stream
from gradio_client.exceptions import AppError
load_dotenv()
# Pool of ZeroGPU auth tokens, one picked at random per request to spread
# quota usage.  Blank entries (unset env var, stray commas, whitespace) are
# dropped so we never hand out an empty token.
ZEROGPU_TOKENS = [t.strip() for t in os.getenv("ZEROGPU_TOKENS", "").split(",") if t.strip()]

def get_zerogpu_token(tokens=None):
    """Return one randomly chosen ZeroGPU token.

    Args:
        tokens: optional explicit token pool; defaults to the module-level
            ``ZEROGPU_TOKENS`` parsed from the environment.

    Returns:
        A single token string.

    Raises:
        RuntimeError: if the pool is empty (previously this silently
            returned "" and produced confusing auth failures downstream).
    """
    pool = ZEROGPU_TOKENS if tokens is None else tokens
    if not pool:
        raise RuntimeError("No ZeroGPU tokens configured (set ZEROGPU_TOKENS)")
    return random.choice(pool)
# Registry of supported TTS backends: public model name -> provider/model id.
# NOTE(review): not consulted by predict_tts below, which hard-codes the same
# model names — presumably meant for the HTTP router request body; confirm
# before removing.
model_mapping = {
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2",
    },
}
# Hosted TTS router endpoint and default request headers (bearer auth from
# HF_TOKEN in the environment).
url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Placeholder request-body template for the router; values are filled per call.
data = {"text": "string", "provider": "string", "model": "string"}
def predict_index_tts(text, reference_audio_path=None):
    """Synthesize `text` with the IndexTTS Gradio space, cloning the voice
    from `reference_audio_path`.

    Args:
        text: text to synthesize.
        reference_audio_path: path/URL of the reference (prompt) audio;
            required by this model.

    Returns:
        Path of the generated audio as reported by the space.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    # Validate before constructing the Client so a missing reference fails
    # fast, without a network round-trip.
    if not reference_audio_path:
        raise ValueError("index-tts ιœ€θ¦ reference_audio_path")
    from gradio_client import Client, handle_file
    client = Client("kemuriririn/IndexTTS", verbose=True)
    prompt = handle_file(reference_audio_path)
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single"
    )
    # The space may return either a bare path string or a dict-like
    # component update carrying the path under "value".
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result
def predict_spark_tts(text, reference_audio_path=None):
    """Synthesize `text` via the SparkTTS Gradio space.

    Reference audio is optional: when given it is uploaded for voice
    cloning, otherwise the slots are left empty (None).
    """
    from gradio_client import Client, handle_file

    cloned_voice = handle_file(reference_audio_path) if reference_audio_path else None
    space = Client("kemuriririn/SparkTTS")
    # The space exposes one upload slot and one record slot; the same
    # reference is fed into both.
    audio_out = space.predict(
        text=text,
        prompt_text=text,
        prompt_wav_upload=cloned_voice,
        prompt_wav_record=cloned_voice,
        api_name="/voice_clone"
    )
    print("spark-tts result:", audio_out)
    return audio_out
def predict_cosyvoice_tts(text, reference_audio_path=None):
    """Synthesize `text` with the CosyVoice 2.0 space, cloning the voice
    from `reference_audio_path`.

    The space first transcribes the reference audio (its
    /prompt_wav_recognition endpoint) and that transcript is used as the
    prompt text for generation.

    Args:
        text: text to synthesize.
        reference_audio_path: path/URL of the reference audio; required.

    Returns:
        The generated audio as reported by the space.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    # Validate before constructing the Client so a missing reference fails
    # fast, without a network round-trip.
    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 ιœ€θ¦ reference_audio_path")
    from gradio_client import Client, handle_file
    client = Client("kemuriririn/CosyVoice2-0.5B")
    # `handle_file` supersedes the deprecated `file` helper; use it for both
    # calls instead of mixing the two.
    prompt_wav = handle_file(reference_audio_path)
    # Transcribe the reference audio so the model gets a matching prompt text.
    recog_result = client.predict(
        prompt_wav=prompt_wav,
        api_name="/prompt_wav_recognition"
    )
    print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
    prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
    result = client.predict(
        tts_text=text,
        prompt_text=prompt_text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        seed=0,
        stream=False,
        api_name="/generate_audio"
    )
    print("cosyvoice-2.0 result:", result)
    return result
def predict_maskgct(text, reference_audio_path=None):
    """Synthesize `text` with the Amphion MaskGCT space, cloning the voice
    from `reference_audio_path`.

    Args:
        text: text to synthesize.
        reference_audio_path: path/URL of the reference audio; required.

    Returns:
        The generated audio as reported by the space.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    # Validate before constructing the Client so a missing reference fails
    # fast, without a network round-trip.
    if not reference_audio_path:
        raise ValueError("maskgct ιœ€θ¦ reference_audio_path")
    from gradio_client import Client, handle_file
    client = Client("amphion/maskgct")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,    # -1: let the model choose output length — presumably the space default; confirm
        n_timesteps=25,   # diffusion steps — presumably the space default; confirm
        api_name="/predict"
    )
    print("maskgct result:", result)
    return result
def predict_gpt_sovits_v2(text, reference_audio_path=None):
    """Synthesize `text` with the GPT-SoVITS-v2 space, cloning the voice
    from `reference_audio_path`.

    Args:
        text: text to synthesize (treated as English by this wrapper).
        reference_audio_path: path/URL of the reference audio; required.

    Returns:
        The generated audio as reported by the space.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    # Validate before constructing the Client so a missing reference fails
    # fast, without a network round-trip.
    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 ιœ€θ¦ reference_audio_path")
    from gradio_client import Client, handle_file
    client = Client("kemuriririn/GPT-SoVITS-v2")
    result = client.predict(
        # `handle_file` supersedes the deprecated `file` helper.
        ref_wav_path=handle_file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,   # NOTE(review): prompt_text is empty yet ref_free stays off — confirm intended
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav"
    )
    print("gpt-sovits-v2 result:", result)
    return result
def predict_tts(text, model, reference_audio_path=None):
    """Dispatch a TTS request to the handler for `model`.

    Args:
        text: text to synthesize.
        model: one of the model names handled below.
        reference_audio_path: optional reference audio forwarded to the
            backend (required by most of them).

    Returns:
        Whatever the backend-specific ``predict_*`` helper returns.

    Raises:
        ValueError: if `model` is not recognized.
    """
    # Dropped the stray `global client` from the original: `client` is never
    # defined at module level nor assigned here, so the declaration was dead.
    print(f"Predicting TTS for {model}")
    # These models run on dedicated Gradio spaces and must not be passed to
    # the HTTP router.
    if model == "index-tts":
        return predict_index_tts(text, reference_audio_path)
    if model == "spark-tts":
        return predict_spark_tts(text, reference_audio_path)
    if model == "cosyvoice-2.0":
        return predict_cosyvoice_tts(text, reference_audio_path)
    if model == "maskgct":
        return predict_maskgct(text, reference_audio_path)
    if model == "gpt-sovits-v2":
        return predict_gpt_sovits_v2(text, reference_audio_path)
    raise ValueError(f"Model {model} not found")
if __name__ == "__main__":
    # Module is currently import-only; no CLI entry point yet.
    pass