"""Gradio demo for the TaikoConformer 5/6/7 chart-generation models.

Pipeline per request:
  audio -> model energies -> onset decoding -> plot + TJA chart -> remote
  synthesizer Space -> preview audio.
"""

# `spaces` is kept as the very first import: Hugging Face ZeroGPU
# (`@spaces.GPU`) needs to be set up before torch initializes CUDA.
import spaces

import contextlib
import os
import tempfile

import gradio as gr
import torch
from gradio_client import Client, handle_file

from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer as tc5infer
from tc6.model import TaikoConformer6
from tc6 import infer as tc6infer
from tc7.model import TaikoConformer7
from tc7 import infer as tc7infer

GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CPU_DEVICE = torch.device("cpu")


def _load_model(model_cls, repo_id, device):
    """Load a pretrained checkpoint from *repo_id*, move it to *device*, set eval mode."""
    model = model_cls.from_pretrained(repo_id)
    model.to(device)
    model.eval()
    return model


# Each model is loaded twice: one copy for the GPU path (used inside the
# @spaces.GPU-decorated dispatcher) and an independent copy kept on the CPU
# for the CPU path.
tc5 = _load_model(TaikoConformer5, "JacobLinCool/taiko-conformer-5", GPU_DEVICE)
tc5_cpu = _load_model(TaikoConformer5, "JacobLinCool/taiko-conformer-5", "cpu")

tc6 = _load_model(TaikoConformer6, "JacobLinCool/taiko-conformer-6", GPU_DEVICE)
tc6_cpu = _load_model(TaikoConformer6, "JacobLinCool/taiko-conformer-6", "cpu")

tc7 = _load_model(TaikoConformer7, "JacobLinCool/taiko-conformer-7", GPU_DEVICE)
tc7_cpu = _load_model(TaikoConformer7, "JacobLinCool/taiko-conformer-7", "cpu")

# Remote Space that renders a TJA chart over the input audio.
synthesizer = Client("ryanlinjui/taiko-music-generator")

# Duration of one model output frame, in seconds.
# NOTE(review): taken from tc5.config but also used for TC6/TC7 -- this
# assumes all three models share SAMPLE_RATE / HOP_LENGTH; confirm.
OUTPUT_FRAME_HOP_SEC = HOP_LENGTH / SAMPLE_RATE
# Onset-decoding parameters shared by every model.
ONSET_THRESHOLD = 0.3
ONSET_MIN_DISTANCE_FRAMES = 3

MODEL_CHOICES = ["TC5", "TC6", "TC7"]
DEFAULT_MODEL = "TC7"
# Models whose inference takes difficulty/level conditioning inputs.
CONDITIONED_MODELS = ("TC6", "TC7")


def _synthesize(tja_content, audio_path):
    """Send a TJA chart plus its audio to the remote synthesizer Space.

    The chart is written to a temporary ``.tja`` file because the remote API
    takes a file upload; that file is removed once the request has finished.

    Returns:
        Element 1 of the remote ``/handle`` result, which the UI shows as the
        generated audio.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as tja_file:
        tja_file.write(tja_content.encode("utf-8"))
        tja_path = tja_file.name
    try:
        result = synthesizer.predict(
            param_0=handle_file(tja_path),
            param_1=handle_file(audio_path),
            param_2="達人譜面 / Master",
            param_3=16,
            param_4=7,
            param_5=5,
            param_6=5,
            param_7=5,
            param_8=5,
            param_9=5,
            param_10=5,
            param_11=5,
            param_12=5,
            param_13=5,
            param_14=5,
            param_15=5,
            api_name="/handle",
        )
    finally:
        # predict() blocks until the job completes, so the upload is done;
        # cleanup is best-effort and must not mask the real result/error.
        with contextlib.suppress(OSError):
            os.remove(tja_path)
    return result[1]


def _decode_render_synthesize(
    infer_module, audio_path, mel_input, don_energy, ka_energy, drumroll_energy, bpm, offset
):
    """Shared back half of every pipeline.

    Decodes onsets from the model energies, plots them, writes the TJA chart
    and synthesizes preview audio, all via *infer_module*'s helpers.

    Returns:
        ``(oni_audio, plot, tja_content)`` -- the triple the UI outputs expect.
    """
    onsets = infer_module.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        OUTPUT_FRAME_HOP_SEC,
        threshold=ONSET_THRESHOLD,
        min_distance_frames=ONSET_MIN_DISTANCE_FRAMES,
    )
    plot = infer_module.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        OUTPUT_FRAME_HOP_SEC,
    )
    # os.path.basename is separator-aware, unlike splitting on "/".
    tja_content = infer_module.write_tja(
        onsets, bpm=bpm, audio=os.path.basename(audio_path), offset=offset
    )
    oni_audio = _synthesize(tja_content, audio_path)
    return oni_audio, plot, tja_content


def infer_tc5(audio, nps, bpm, offset, DEVICE, MODEL):
    """Run the TC5 pipeline on one audio file.

    Args:
        audio: Path to the input audio file.
        nps: Target notes-per-second, passed to ``tc5infer.preprocess_audio``.
        bpm: BPM used for TJA quantization.
        offset: TJA offset in seconds.
        DEVICE: torch device that MODEL lives on.
        MODEL: A loaded ``TaikoConformer5``.

    Returns:
        ``(oni_audio, plot, tja_content)``.
    """
    mel_input, nps_input = tc5infer.preprocess_audio(audio, nps)
    don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
        MODEL, mel_input, nps_input, DEVICE
    )
    return _decode_render_synthesize(
        tc5infer, audio, mel_input, don_energy, ka_energy, drumroll_energy, bpm, offset
    )


def _infer_conditioned(infer_module, audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    """Run a difficulty/level-conditioned pipeline (TC6 or TC7).

    The scalar conditioning values are wrapped as float32 tensors on DEVICE
    before being handed to ``infer_module.run_inference``.
    """
    mel_input = infer_module.preprocess_audio(audio)
    nps_input = torch.tensor(nps, dtype=torch.float32, device=DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32, device=DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32, device=DEVICE)
    don_energy, ka_energy, drumroll_energy = infer_module.run_inference(
        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    return _decode_render_synthesize(
        infer_module, audio, mel_input, don_energy, ka_energy, drumroll_energy, bpm, offset
    )


def infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    """Run the TC6 pipeline; returns ``(oni_audio, plot, tja_content)``."""
    return _infer_conditioned(
        tc6infer, audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL
    )


def infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    """Run the TC7 pipeline; returns ``(oni_audio, plot, tja_content)``."""
    return _infer_conditioned(
        tc7infer, audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL
    )


@spaces.GPU()
def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
    """Dispatch to the GPU-resident model; any choice other than TC5/TC6 runs TC7."""
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm, offset, GPU_DEVICE, tc5)
    if model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc6)
    return infer_tc7(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc7)


def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
    """Dispatch to the CPU model copy; any choice other than TC5/TC6 runs TC7."""
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm, offset, CPU_DEVICE, tc5_cpu)
    if model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, offset, difficulty, level, CPU_DEVICE, tc6_cpu)
    return infer_tc7(audio, nps, bpm, offset, difficulty, level, CPU_DEVICE, tc7_cpu)


def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
    """UI entry point: validate the input, then run on GPU or CPU.

    Raises:
        gr.Error: if no audio file was uploaded. This is checked before the
            GPU dispatcher is called so an empty request does not request a
            GPU slot.
    """
    if not audio:
        raise gr.Error("Please upload an audio file first.")
    runner = run_inference_gpu if with_gpu else run_inference_cpu
    return runner(audio, model_choice, nps, bpm, offset, difficulty, level)


with gr.Blocks() as demo:
    gr.Markdown("# Taiko Conformer 5/6/7 Demo")
    with gr.Row():
        audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=MODEL_CHOICES,
            value=DEFAULT_MODEL,
            label="Model Selection",
            info="Choose between TaikoConformer 5, 6 or 7",
        )
    with gr.Row():
        nps = gr.Slider(
            value=5.0,
            minimum=0.5,
            maximum=11.0,
            step=0.5,
            label="NPS (Notes Per Second)",
        )
        bpm = gr.Slider(
            value=240,
            minimum=160,
            maximum=640,
            step=1,
            label="BPM (Used by TJA Quantization)",
        )
        offset = gr.Slider(
            value=0.0,
            minimum=-5.0,
            maximum=5.0,
            step=0.01,
            label="Offset (in seconds)",
            info="Adjust the offset for TJA",
        )
    # Difficulty/level only affect TC6/TC7. Their initial visibility must
    # follow the default model, otherwise they stay hidden for the default
    # TC7 until the user touches the dropdown.
    show_conditioning = DEFAULT_MODEL in CONDITIONED_MODELS
    with gr.Row():
        difficulty = gr.Slider(
            value=3.0,
            minimum=1.0,
            maximum=3.0,
            step=1.0,
            label="Difficulty",
            visible=show_conditioning,
            info="1=Normal, 2=Hard, 3=Oni",
        )
        level = gr.Slider(
            value=8.0,
            minimum=1.0,
            maximum=10.0,
            step=1.0,
            label="Level",
            visible=show_conditioning,
            info="Difficulty level from 1 to 10",
        )
    with gr.Row():
        with_gpu = gr.Checkbox(
            value=True,
            label="Use GPU for Inference",
            info="Enable this to use GPU for faster inference (if available)",
        )

    run_btn = gr.Button("Run Inference", variant="primary")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    plot_output = gr.Plot(label="Onset/Energy Plot")
    tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)

    def update_visibility(model_choice):
        """Show the difficulty/level sliders only for models that use them (TC6/TC7)."""
        visible = model_choice in CONDITIONED_MODELS
        return gr.update(visible=visible), gr.update(visible=visible)

    model_choice.change(
        update_visibility, inputs=[model_choice], outputs=[difficulty, level]
    )

    run_btn.click(
        run_inference,
        inputs=[
            with_gpu,
            audio_input,
            model_choice,
            nps,
            bpm,
            offset,
            difficulty,
            level,
        ],
        outputs=[audio_output, plot_output, tja_output],
    )

if __name__ == "__main__":
    demo.launch()