import spaces
import gradio as gr
import torch
from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer as tc5infer
from tc6.model import TaikoConformer6
from tc6 import infer as tc6infer
from tc7.model import TaikoConformer7
from tc7 import infer as tc7infer
from gradio_client import Client, handle_file
import tempfile

GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load each model once, keeping a GPU-resident and a CPU-resident copy
tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5.to(GPU_DEVICE)
tc5.eval()
tc5_cpu = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5_cpu.to("cpu")
tc5_cpu.eval()

# Load TC6 model
tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6.to(GPU_DEVICE)
tc6.eval()
tc6_cpu = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6_cpu.to("cpu")
tc6_cpu.eval()

# Load TC7 model
tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7.to(GPU_DEVICE)
tc7.eval()
tc7_cpu = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7_cpu.to("cpu")
tc7_cpu.eval()

# Remote Space that renders a TJA chart plus its audio into playable Taiko audio
synthesizer = Client("ryanlinjui/taiko-music-generator")

def infer_tc5(audio, nps, bpm, offset, DEVICE, MODEL):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
        MODEL, mel_input, nps_input, DEVICE
    )
    # Seconds per output frame
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc5infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc5infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)
    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name
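    # Call the remote /handle endpoint. param_2 selects the course
    # ("達人譜面 / Master"); the remaining positional parameters are the remote
    # Space's synthesis settings, left at the values this demo uses. (Their
    # exact meanings are assumptions about the remote API, which is not
    # documented in this file.)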
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=7,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content

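
# Unlike TC5, the TC6 and TC7 models also condition on difficulty and level.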
def infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input = tc6infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc6infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc6infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)
    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name
    # Same remote /handle call as in infer_tc5
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=7,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content

def infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input = tc7infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc7infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc7infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)
    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name
    # Same remote /handle call as in infer_tc5
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=7,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content

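
# Device dispatch: route each request to the GPU- or CPU-resident copy of the
# selected model. The @spaces.GPU decorator below is assumed from the Space
# running on ZeroGPU (hence the `spaces` import above); it requests a GPU only
# for the duration of the decorated call.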
@spaces.GPU
def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm, offset, GPU_DEVICE, tc5)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc6)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc7)


def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
    DEVICE = torch.device("cpu")
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm, offset, DEVICE, tc5_cpu)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, tc6_cpu)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, tc7_cpu)


def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
    if with_gpu:
        return run_inference_gpu(
            audio, model_choice, nps, bpm, offset, difficulty, level
        )
    else:
        return run_inference_cpu(
            audio, model_choice, nps, bpm, offset, difficulty, level
        )

with gr.Blocks() as demo:
    gr.Markdown("# Taiko Conformer 5/6/7 Demo")
    with gr.Row():
        audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["TC5", "TC6", "TC7"],
            value="TC7",
            label="Model Selection",
            info="Choose between TaikoConformer 5, 6 or 7",
        )
    with gr.Row():
        nps = gr.Slider(
            value=5.0,
            minimum=0.5,
            maximum=11.0,
            step=0.5,
            label="NPS (Notes Per Second)",
        )
        bpm = gr.Slider(
            value=240,
            minimum=160,
            maximum=640,
            step=1,
            label="BPM (Used by TJA Quantization)",
        )
        offset = gr.Slider(
            value=0.0,
            minimum=-5.0,
            maximum=5.0,
            step=0.01,
            label="Offset (in seconds)",
            info="Adjust the offset for TJA",
        )
    with gr.Row():
        difficulty = gr.Slider(
            value=3.0,
            minimum=1.0,
            maximum=3.0,
            step=1.0,
            label="Difficulty",
            visible=True,  # shown on load because TC7 is the default model
            info="1=Normal, 2=Hard, 3=Oni",
        )
        level = gr.Slider(
            value=8.0,
            minimum=1.0,
            maximum=10.0,
            step=1.0,
            label="Level",
            visible=True,  # shown on load because TC7 is the default model
            info="Difficulty level from 1 to 10",
        )
    with gr.Row():
        with_gpu = gr.Checkbox(
            value=True,
            label="Use GPU for Inference",
            info="Enable this to use GPU for faster inference (if available)",
        )
    run_btn = gr.Button("Run Inference", variant="primary")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    plot_output = gr.Plot(label="Onset/Energy Plot")
    tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)

    # Show the difficulty/level controls only for models that use them (TC6/TC7)
    def update_visibility(model_choice):
        if model_choice in ("TC6", "TC7"):
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    model_choice.change(
        update_visibility, inputs=[model_choice], outputs=[difficulty, level]
    )

    run_btn.click(
        run_inference,
        inputs=[
            with_gpu,
            audio_input,
            model_choice,
            nps,
            bpm,
            offset,
            difficulty,
            level,
        ],
        outputs=[audio_output, plot_output, tja_output],
    )

if __name__ == "__main__":
    demo.launch()