# Spaces: Running / Running  (Hugging Face Spaces status header captured by the
# scrape; kept as a comment because bare `Spaces:` is not valid Python)
import os
import random

from dotenv import load_dotenv
from fal_client import stream  # noqa: F401 -- unused here; presumably re-exported, confirm before removing
from gradio_client.exceptions import AppError  # noqa: F401 -- unused here; presumably re-exported, confirm before removing

# Pull variables from a local .env file into the process environment.
load_dotenv()

# Pool of ZeroGPU access tokens, comma-separated in the ZEROGPU_TOKENS env var.
# NOTE: if the variable is unset this is [""] (a single empty token).
ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")
def get_zerogpu_token():
    """Return a randomly chosen ZeroGPU token from the configured pool.

    Tokens come from the comma-separated ``ZEROGPU_TOKENS`` env var parsed at
    import time.

    Returns:
        A non-empty token string.

    Raises:
        RuntimeError: If no non-empty token is configured. (Previously an
            unset variable made this silently return "", which only fails
            later at the API call.)
    """
    pool = [tok.strip() for tok in ZEROGPU_TOKENS if tok.strip()]
    if not pool:
        raise RuntimeError("ZEROGPU_TOKENS is not configured")
    return random.choice(pool)
# Registry of directly-callable TTS backends: UI model key -> router metadata
# (provider name and the provider's internal model identifier).
model_mapping = {
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2",
    },
}
# Hosted TTS router endpoint and the headers/body template for requests to it.
url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    # HF_TOKEN must be present in the environment; "Bearer None" is sent if not.
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Placeholder payload documenting the schema the router expects.
data = {"text": "string", "provider": "string", "model": "string"}
def predict_index_tts(text, reference_audio_path=None):
    """Synthesize *text* with the IndexTTS Space, cloning the reference voice.

    Args:
        text: Text to synthesize.
        reference_audio_path: Local path to the voice-prompt audio. Required.

    Returns:
        The generated audio reference returned by the Space (a path/URL
        string, or the "value" field when the Space wraps it in a dict).

    Raises:
        ValueError: If no reference audio is provided.
    """
    from gradio_client import Client, handle_file

    # Validate before opening a connection to the Space.
    if not reference_audio_path:
        raise ValueError("index-tts requires reference_audio_path")
    client = Client("kemuriririn/IndexTTS", verbose=True)
    prompt = handle_file(reference_audio_path)
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single",
    )
    # /gen_single sometimes returns a bare string, sometimes {"value": ...}.
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result
def predict_spark_tts(text, reference_audio_path=None):
    """Synthesize *text* with the SparkTTS Space, optionally cloning a voice.

    Args:
        text: Text to synthesize.
        reference_audio_path: Optional local path to a reference recording;
            when omitted, no voice prompt is sent.

    Returns:
        The Space's result (typically a path/URL to the generated audio).
    """
    from gradio_client import Client, handle_file

    client = Client("kemuriririn/SparkTTS")
    prompt_wav = handle_file(reference_audio_path) if reference_audio_path else None
    result = client.predict(
        text=text,
        # NOTE(review): prompt_text is set to the target text rather than a
        # transcript of the reference audio -- confirm against the Space API.
        prompt_text=text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone",
    )
    print("spark-tts result:", result)
    return result
def predict_cosyvoice_tts(text, reference_audio_path=None):
    """Synthesize *text* with the CosyVoice 2.0 Space using a reference voice.

    First asks the Space to transcribe the reference audio, then feeds that
    transcript as the prompt text for generation.

    Args:
        text: Text to synthesize.
        reference_audio_path: Local path to the reference recording. Required.

    Returns:
        The Space's result (typically a path/URL to the generated audio).

    Raises:
        ValueError: If no reference audio is provided.
    """
    from gradio_client import Client, file, handle_file

    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")
    client = Client("kemuriririn/CosyVoice2-0.5B")
    prompt_wav = handle_file(reference_audio_path)
    # Transcribe the reference audio so it can serve as the prompt text.
    recog_result = client.predict(
        prompt_wav=file(reference_audio_path),
        api_name="/prompt_wav_recognition",
    )
    print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
    prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
    result = client.predict(
        tts_text=text,
        prompt_text=prompt_text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        seed=0,  # fixed seed for reproducible output
        stream=False,
        api_name="/generate_audio",
    )
    print("cosyvoice-2.0 result:", result)
    return result
def predict_maskgct(text, reference_audio_path=None):
    """Synthesize *text* with the Amphion MaskGCT Space using a reference voice.

    Args:
        text: Text to synthesize.
        reference_audio_path: Local path to the reference recording. Required.

    Returns:
        The Space's result (typically a path/URL to the generated audio).

    Raises:
        ValueError: If no reference audio is provided.
    """
    from gradio_client import Client, handle_file

    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    client = Client("amphion/maskgct")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,  # -1 lets the Space infer the output length
        n_timesteps=25,
        api_name="/predict",
    )
    print("maskgct result:", result)
    return result
def predict_gpt_sovits_v2(text, reference_audio_path=None):
    """Synthesize *text* with the GPT-SoVITS-v2 Space using a reference voice.

    Args:
        text: Text to synthesize (treated as English).
        reference_audio_path: Local path to the reference recording. Required.

    Returns:
        The Space's result (typically a path/URL to the generated audio).

    Raises:
        ValueError: If no reference audio is provided.
    """
    from gradio_client import Client, file

    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
    client = Client("kemuriririn/GPT-SoVITS-v2")
    result = client.predict(
        ref_wav_path=file(reference_audio_path),
        prompt_text="",  # empty prompt transcript; ref_free stays False though -- confirm intended
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav",
    )
    print("gpt-sovits-v2 result:", result)
    return result
def predict_tts(text, model, reference_audio_path=None):
    """Dispatch a TTS request to the Space-specific predictor for *model*.

    Args:
        text: Text to synthesize.
        model: Backend key (see ``model_mapping``).
        reference_audio_path: Optional reference audio for voice cloning;
            most backends raise ValueError without it.

    Returns:
        The backend's result (typically a path/URL to the generated audio).

    Raises:
        ValueError: If *model* is not a recognized backend.
    """
    print(f"Predicting TTS for {model}")
    # These models are called on their own Spaces rather than via the router.
    if model == "index-tts":
        return predict_index_tts(text, reference_audio_path)
    if model == "spark-tts":
        return predict_spark_tts(text, reference_audio_path)
    if model == "cosyvoice-2.0":
        return predict_cosyvoice_tts(text, reference_audio_path)
    if model == "maskgct":
        return predict_maskgct(text, reference_audio_path)
    if model == "gpt-sovits-v2":
        return predict_gpt_sovits_v2(text, reference_audio_path)
    raise ValueError(f"Model {model} not found")
if __name__ == "__main__":
    # No CLI behavior; this module is meant to be imported.
    pass