JacobLinCool committed on
Commit
db8b2d5
·
1 Parent(s): 812b01c

Add offset parameter to TJA writing functions and update inference methods for TC5, TC6, and TC7

Browse files
Files changed (4) hide show
  1. app.py +54 -16
  2. tc5/infer.py +2 -2
  3. tc6/infer.py +2 -2
  4. tc7/infer.py +2 -2
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  from tc5.config import SAMPLE_RATE, HOP_LENGTH
@@ -30,7 +31,7 @@ tc7.eval()
30
  synthesizer = Client("ryanlinjui/taiko-music-generator")
31
 
32
 
33
- def infer_tc5(audio, nps, bpm):
34
  audio_path = audio
35
  filename = audio_path.split("/")[-1]
36
  # Preprocess
@@ -58,7 +59,7 @@ def infer_tc5(audio, nps, bpm):
58
  output_frame_hop_sec,
59
  )
60
  # Generate TJA content
61
- tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename)
62
 
63
  # write TJA content to a temporary file
64
  with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
@@ -70,7 +71,7 @@ def infer_tc5(audio, nps, bpm):
70
  param_1=handle_file(audio_path),
71
  param_2="達人譜面 / Master",
72
  param_3=16,
73
- param_4=5,
74
  param_5=5,
75
  param_6=5,
76
  param_7=5,
@@ -90,7 +91,7 @@ def infer_tc5(audio, nps, bpm):
90
  return oni_audio, plot, tja_content
91
 
92
 
93
- def infer_tc6(audio, nps, bpm, difficulty, level):
94
  audio_path = audio
95
  filename = audio_path.split("/")[-1]
96
  # Preprocess
@@ -121,7 +122,7 @@ def infer_tc6(audio, nps, bpm, difficulty, level):
121
  output_frame_hop_sec,
122
  )
123
  # Generate TJA content
124
- tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename)
125
 
126
  # write TJA content to a temporary file
127
  with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
@@ -133,7 +134,7 @@ def infer_tc6(audio, nps, bpm, difficulty, level):
133
  param_1=handle_file(audio_path),
134
  param_2="達人譜面 / Master",
135
  param_3=16,
136
- param_4=5,
137
  param_5=5,
138
  param_6=5,
139
  param_7=5,
@@ -153,7 +154,7 @@ def infer_tc6(audio, nps, bpm, difficulty, level):
153
  return oni_audio, plot, tja_content
154
 
155
 
156
- def infer_tc7(audio, nps, bpm, difficulty, level):
157
  audio_path = audio
158
  filename = audio_path.split("/")[-1]
159
  # Preprocess
@@ -184,7 +185,7 @@ def infer_tc7(audio, nps, bpm, difficulty, level):
184
  output_frame_hop_sec,
185
  )
186
  # Generate TJA content
187
- tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename)
188
 
189
  # write TJA content to a temporary file
190
  with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
@@ -196,7 +197,7 @@ def infer_tc7(audio, nps, bpm, difficulty, level):
196
  param_1=handle_file(audio_path),
197
  param_2="達人譜面 / Master",
198
  param_3=16,
199
- param_4=5,
200
  param_5=5,
201
  param_6=5,
202
  param_7=5,
@@ -216,17 +217,38 @@ def infer_tc7(audio, nps, bpm, difficulty, level):
216
  return oni_audio, plot, tja_content
217
 
218
 
219
- def run_inference(audio, model_choice, nps, bpm, difficulty, level):
 
220
  if model_choice == "TC5":
221
- return infer_tc5(audio, nps, bpm)
222
  elif model_choice == "TC6":
223
- return infer_tc6(audio, nps, bpm, difficulty, level)
224
  else: # TC7
225
- return infer_tc7(audio, nps, bpm, difficulty, level)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
 
228
  with gr.Blocks() as demo:
229
- gr.Markdown("# Taiko Conformer 5/7 Demo")
230
  with gr.Row():
231
  audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")
232
 
@@ -253,6 +275,14 @@ with gr.Blocks() as demo:
253
  step=1,
254
  label="BPM (Used by TJA Quantization)",
255
  )
 
 
 
 
 
 
 
 
256
 
257
  with gr.Row():
258
  difficulty = gr.Slider(
@@ -274,10 +304,18 @@ with gr.Blocks() as demo:
274
  info="Difficulty level from 1 to 10",
275
  )
276
 
 
 
 
 
 
 
 
 
 
277
  audio_output = gr.Audio(label="Generated Audio", type="filepath")
278
  plot_output = gr.Plot(label="Onset/Energy Plot")
279
  tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
280
- run_btn = gr.Button("Run Inference")
281
 
282
  # Update visibility of TC7-specific controls based on model selection
283
  def update_visibility(model_choice):
@@ -292,7 +330,7 @@ with gr.Blocks() as demo:
292
 
293
  run_btn.click(
294
  run_inference,
295
- inputs=[audio_input, model_choice, nps, bpm, difficulty, level],
296
  outputs=[audio_output, plot_output, tja_output],
297
  )
298
 
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  from tc5.config import SAMPLE_RATE, HOP_LENGTH
 
31
  synthesizer = Client("ryanlinjui/taiko-music-generator")
32
 
33
 
34
+ def infer_tc5(audio, nps, bpm, offset):
35
  audio_path = audio
36
  filename = audio_path.split("/")[-1]
37
  # Preprocess
 
59
  output_frame_hop_sec,
60
  )
61
  # Generate TJA content
62
+ tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)
63
 
64
  # write TJA content to a temporary file
65
  with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
 
71
  param_1=handle_file(audio_path),
72
  param_2="達人譜面 / Master",
73
  param_3=16,
74
+ param_4=7,
75
  param_5=5,
76
  param_6=5,
77
  param_7=5,
 
91
  return oni_audio, plot, tja_content
92
 
93
 
94
+ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
95
  audio_path = audio
96
  filename = audio_path.split("/")[-1]
97
  # Preprocess
 
122
  output_frame_hop_sec,
123
  )
124
  # Generate TJA content
125
+ tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)
126
 
127
  # write TJA content to a temporary file
128
  with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
 
134
  param_1=handle_file(audio_path),
135
  param_2="達人譜面 / Master",
136
  param_3=16,
137
+ param_4=7,
138
  param_5=5,
139
  param_6=5,
140
  param_7=5,
 
154
  return oni_audio, plot, tja_content
155
 
156
 
157
+ def infer_tc7(audio, nps, bpm, offset, difficulty, level):
158
  audio_path = audio
159
  filename = audio_path.split("/")[-1]
160
  # Preprocess
 
185
  output_frame_hop_sec,
186
  )
187
  # Generate TJA content
188
+ tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)
189
 
190
  # write TJA content to a temporary file
191
  with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
 
197
  param_1=handle_file(audio_path),
198
  param_2="達人譜面 / Master",
199
  param_3=16,
200
+ param_4=7,
201
  param_5=5,
202
  param_6=5,
203
  param_7=5,
 
217
  return oni_audio, plot, tja_content
218
 
219
 
220
+ @spaces.GPU
221
+ def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
222
  if model_choice == "TC5":
223
+ return infer_tc5(audio, nps, bpm, offset)
224
  elif model_choice == "TC6":
225
+ return infer_tc6(audio, nps, bpm, offset, difficulty, level)
226
  else: # TC7
227
+ return infer_tc7(audio, nps, bpm, offset, difficulty, level)
228
+
229
+
230
+ def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
231
+ if model_choice == "TC5":
232
+ return infer_tc5(audio, nps, bpm, offset)
233
+ elif model_choice == "TC6":
234
+ return infer_tc6(audio, nps, bpm, offset, difficulty, level)
235
+ else: # TC7
236
+ return infer_tc7(audio, nps, bpm, offset, difficulty, level)
237
+
238
+
239
+ def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
240
+ if with_gpu:
241
+ return run_inference_gpu(
242
+ audio, model_choice, nps, bpm, offset, difficulty, level
243
+ )
244
+ else:
245
+ return run_inference_cpu(
246
+ audio, model_choice, nps, bpm, offset, difficulty, level
247
+ )
248
 
249
 
250
  with gr.Blocks() as demo:
251
+ gr.Markdown("# Taiko Conformer 5/6/7 Demo")
252
  with gr.Row():
253
  audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")
254
 
 
275
  step=1,
276
  label="BPM (Used by TJA Quantization)",
277
  )
278
+ offset = gr.Slider(
279
+ value=0.0,
280
+ minimum=-5.0,
281
+ maximum=5.0,
282
+ step=0.01,
283
+ label="Offset (in seconds)",
284
+ info="Adjust the offset for TJA",
285
+ )
286
 
287
  with gr.Row():
288
  difficulty = gr.Slider(
 
304
  info="Difficulty level from 1 to 10",
305
  )
306
 
307
+ with gr.Row():
308
+ with_gpu = gr.Checkbox(
309
+ value=True,
310
+ label="Use GPU for Inference",
311
+ info="Enable this to use GPU for faster inference (if available)",
312
+ )
313
+
314
+ run_btn = gr.Button("Run Inference", variant="primary")
315
+
316
  audio_output = gr.Audio(label="Generated Audio", type="filepath")
317
  plot_output = gr.Plot(label="Onset/Energy Plot")
318
  tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
 
319
 
320
  # Update visibility of TC7-specific controls based on model selection
321
  def update_visibility(model_choice):
 
330
 
331
  run_btn.click(
332
  run_inference,
333
+ inputs=[audio_input, model_choice, nps, bpm, offset, difficulty, level],
334
  outputs=[audio_output, plot_output, tja_output],
335
  )
336
 
tc5/infer.py CHANGED
@@ -258,7 +258,7 @@ def plot_results(
258
  return fig
259
 
260
 
261
- def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
262
  # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
263
  # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
264
  sec_per_beat = 60 / bpm
@@ -336,7 +336,7 @@ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
336
  tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
337
  tja_content.append(f"BPM:{bpm}")
338
  tja_content.append(f"WAVE:{audio}")
339
- tja_content.append("OFFSET:0")
340
  tja_content.append("COURSE:Oni\nLEVEL:9\n")
341
  tja_content.append("#START")
342
  for i in range(max_measure_idx + 1):
 
258
  return fig
259
 
260
 
261
+ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
262
  # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
263
  # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
264
  sec_per_beat = 60 / bpm
 
336
  tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
337
  tja_content.append(f"BPM:{bpm}")
338
  tja_content.append(f"WAVE:{audio}")
339
+ tja_content.append(f"OFFSET:{offset}")
340
  tja_content.append("COURSE:Oni\nLEVEL:9\n")
341
  tja_content.append("#START")
342
  for i in range(max_measure_idx + 1):
tc6/infer.py CHANGED
@@ -257,7 +257,7 @@ def plot_results(
257
  return fig
258
 
259
 
260
- def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
261
  # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
262
  # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
263
  sec_per_beat = 60 / bpm
@@ -334,7 +334,7 @@ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
334
  tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
335
  tja_content.append(f"BPM:{bpm}")
336
  tja_content.append(f"WAVE:{audio}")
337
- tja_content.append("OFFSET:0")
338
  tja_content.append("COURSE:Oni\nLEVEL:9\n")
339
  tja_content.append("#START")
340
  for i in range(max_measure_idx + 1):
 
257
  return fig
258
 
259
 
260
+ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
261
  # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
262
  # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
263
  sec_per_beat = 60 / bpm
 
334
  tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
335
  tja_content.append(f"BPM:{bpm}")
336
  tja_content.append(f"WAVE:{audio}")
337
+ tja_content.append(f"OFFSET:{offset}")
338
  tja_content.append("COURSE:Oni\nLEVEL:9\n")
339
  tja_content.append("#START")
340
  for i in range(max_measure_idx + 1):
tc7/infer.py CHANGED
@@ -257,7 +257,7 @@ def plot_results(
257
  return fig
258
 
259
 
260
- def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
261
  # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
262
  # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
263
  sec_per_beat = 60 / bpm
@@ -334,7 +334,7 @@ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
334
  tja_content.append(f"TITLE:{audio} (TC7, {time.strftime('%Y-%m-%d %H:%M:%S')})")
335
  tja_content.append(f"BPM:{bpm}")
336
  tja_content.append(f"WAVE:{audio}")
337
- tja_content.append("OFFSET:0")
338
  tja_content.append("COURSE:Oni\nLEVEL:9\n")
339
  tja_content.append("#START")
340
  for i in range(max_measure_idx + 1):
 
257
  return fig
258
 
259
 
260
+ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
261
  # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
262
  # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
263
  sec_per_beat = 60 / bpm
 
334
  tja_content.append(f"TITLE:{audio} (TC7, {time.strftime('%Y-%m-%d %H:%M:%S')})")
335
  tja_content.append(f"BPM:{bpm}")
336
  tja_content.append(f"WAVE:{audio}")
337
+ tja_content.append(f"OFFSET:{offset}")
338
  tja_content.append("COURSE:Oni\nLEVEL:9\n")
339
  tja_content.append("#START")
340
  for i in range(max_measure_idx + 1):