JacobLinCool committed on
Commit
b7070f2
·
1 Parent(s): db8b2d5

Refactor inference functions to accept DEVICE and MODEL parameters for TC5, TC6, and TC7; update model loading to use GPU if available.

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -11,34 +11,43 @@ from tc7 import infer as tc7infer
11
  from gradio_client import Client, handle_file
12
  import tempfile
13
 
14
- DEVICE = torch.device("cpu")
15
 
16
  # Load model once
17
  tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
18
- tc5.to(DEVICE)
19
  tc5.eval()
 
 
 
20
 
21
  # Load TC6 model
22
  tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
23
- tc6.to(DEVICE)
24
  tc6.eval()
 
 
 
25
 
26
  # Load TC7 model
27
  tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
28
- tc7.to(DEVICE)
29
  tc7.eval()
 
 
 
30
 
31
  synthesizer = Client("ryanlinjui/taiko-music-generator")
32
 
33
 
34
- def infer_tc5(audio, nps, bpm, offset):
35
  audio_path = audio
36
  filename = audio_path.split("/")[-1]
37
  # Preprocess
38
  mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
39
  # Inference
40
  don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
41
- tc5, mel_input, nps_input, DEVICE
42
  )
43
  output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
44
  onsets = tc5infer.decode_onsets(
@@ -91,7 +100,7 @@ def infer_tc5(audio, nps, bpm, offset):
91
  return oni_audio, plot, tja_content
92
 
93
 
94
- def infer_tc6(audio, nps, bpm, offset, difficulty, level):
95
  audio_path = audio
96
  filename = audio_path.split("/")[-1]
97
  # Preprocess
@@ -101,7 +110,7 @@ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
101
  level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
102
  # Inference
103
  don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
104
- tc6, mel_input, nps_input, difficulty_input, level_input, DEVICE
105
  )
106
  output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
107
  onsets = tc6infer.decode_onsets(
@@ -154,7 +163,7 @@ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
154
  return oni_audio, plot, tja_content
155
 
156
 
157
- def infer_tc7(audio, nps, bpm, offset, difficulty, level):
158
  audio_path = audio
159
  filename = audio_path.split("/")[-1]
160
  # Preprocess
@@ -164,7 +173,7 @@ def infer_tc7(audio, nps, bpm, offset, difficulty, level):
164
  level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
165
  # Inference
166
  don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
167
- tc7, mel_input, nps_input, difficulty_input, level_input, DEVICE
168
  )
169
  output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
170
  onsets = tc7infer.decode_onsets(
@@ -220,20 +229,21 @@ def infer_tc7(audio, nps, bpm, offset, difficulty, level):
220
  @spaces.GPU
221
  def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
222
  if model_choice == "TC5":
223
- return infer_tc5(audio, nps, bpm, offset)
224
  elif model_choice == "TC6":
225
- return infer_tc6(audio, nps, bpm, offset, difficulty, level)
226
  else: # TC7
227
- return infer_tc7(audio, nps, bpm, offset, difficulty, level)
228
 
229
 
230
  def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
 
231
  if model_choice == "TC5":
232
- return infer_tc5(audio, nps, bpm, offset)
233
  elif model_choice == "TC6":
234
- return infer_tc6(audio, nps, bpm, offset, difficulty, level)
235
  else: # TC7
236
- return infer_tc7(audio, nps, bpm, offset, difficulty, level)
237
 
238
 
239
  def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
@@ -330,7 +340,16 @@ with gr.Blocks() as demo:
330
 
331
  run_btn.click(
332
  run_inference,
333
- inputs=[audio_input, model_choice, nps, bpm, offset, difficulty, level],
 
 
 
 
 
 
 
 
 
334
  outputs=[audio_output, plot_output, tja_output],
335
  )
336
 
 
11
  from gradio_client import Client, handle_file
12
  import tempfile
13
 
14
+ GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # Load model once
17
  tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
18
+ tc5.to(GPU_DEVICE)
19
  tc5.eval()
20
+ tc5_cpu = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
21
+ tc5_cpu.to("cpu")
22
+ tc5_cpu.eval()
23
 
24
  # Load TC6 model
25
  tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
26
+ tc6.to(GPU_DEVICE)
27
  tc6.eval()
28
+ tc6_cpu = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
29
+ tc6_cpu.to("cpu")
30
+ tc6_cpu.eval()
31
 
32
  # Load TC7 model
33
  tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
34
+ tc7.to(GPU_DEVICE)
35
  tc7.eval()
36
+ tc7_cpu = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
37
+ tc7_cpu.to("cpu")
38
+ tc7_cpu.eval()
39
 
40
  synthesizer = Client("ryanlinjui/taiko-music-generator")
41
 
42
 
43
+ def infer_tc5(audio, nps, bpm, offset, DEVICE, MODEL):
44
  audio_path = audio
45
  filename = audio_path.split("/")[-1]
46
  # Preprocess
47
  mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
48
  # Inference
49
  don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
50
+ MODEL, mel_input, nps_input, DEVICE
51
  )
52
  output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
53
  onsets = tc5infer.decode_onsets(
 
100
  return oni_audio, plot, tja_content
101
 
102
 
103
+ def infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
104
  audio_path = audio
105
  filename = audio_path.split("/")[-1]
106
  # Preprocess
 
110
  level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
111
  # Inference
112
  don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
113
+ MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
114
  )
115
  output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
116
  onsets = tc6infer.decode_onsets(
 
163
  return oni_audio, plot, tja_content
164
 
165
 
166
+ def infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
167
  audio_path = audio
168
  filename = audio_path.split("/")[-1]
169
  # Preprocess
 
173
  level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
174
  # Inference
175
  don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
176
+ MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
177
  )
178
  output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
179
  onsets = tc7infer.decode_onsets(
 
229
  @spaces.GPU
230
  def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
231
  if model_choice == "TC5":
232
+ return infer_tc5(audio, nps, bpm, offset, GPU_DEVICE, tc5)
233
  elif model_choice == "TC6":
234
+ return infer_tc6(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc6)
235
  else: # TC7
236
+ return infer_tc7(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc7)
237
 
238
 
239
  def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
240
+ DEVICE = torch.device("cpu")
241
  if model_choice == "TC5":
242
+ return infer_tc5(audio, nps, bpm, offset, DEVICE, tc5_cpu)
243
  elif model_choice == "TC6":
244
+ return infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, tc6_cpu)
245
  else: # TC7
246
+ return infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, tc7_cpu)
247
 
248
 
249
  def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
 
340
 
341
  run_btn.click(
342
  run_inference,
343
+ inputs=[
344
+ with_gpu,
345
+ audio_input,
346
+ model_choice,
347
+ nps,
348
+ bpm,
349
+ offset,
350
+ difficulty,
351
+ level,
352
+ ],
353
  outputs=[audio_output, plot_output, tja_output],
354
  )
355