masszhou committed on
Commit
3b06e9c
·
1 Parent(s): 67a6eaa

add mdxnet model

Browse files
Files changed (1) hide show
  1. mdxnet_model.py +313 -0
mdxnet_model.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # reference: https://huggingface.co/spaces/r3gm/Audio_separator
2
+ import torch
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ import hashlib
6
+ import queue
7
+ import threading
8
+ from tqdm import tqdm
9
+
10
+
11
class MDXModel:
    """Spectrogram geometry plus STFT / inverse-STFT helpers for one model.

    ``stft`` turns waveform chunks into a 4-channel real-valued spectrogram
    cropped to ``dim_f`` frequency bins; ``istft`` pads the cropped bins
    back up to ``n_bins`` and returns waveform chunks of ``chunk_size``
    samples.
    """

    def __init__(
        self,
        device,
        dim_f,
        dim_t,
        n_fft,
        hop=1024,
        stem_name=None,
        compensation=1.000,
    ):
        """
        Store the model geometry and pre-build the tensors reused per call.

        Args:
            device: torch device the window and padding tensors are moved to
            dim_f: (int) number of frequency bins kept by ``stft``
            dim_t: (int) number of STFT frames per chunk
            n_fft: (int) FFT size
            hop: (int) hop length between STFT frames
            stem_name: stored as-is; not used inside this class
            compensation: stored as-is; not used inside this class
        """
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.dim_c = 4
        self.n_fft = n_fft
        self.hop = hop
        self.stem_name = stem_name
        self.compensation = compensation

        # A one-sided spectrum of an n_fft-point FFT has n_fft // 2 + 1 bins.
        self.n_bins = n_fft // 2 + 1
        # With center=True, a chunk of this many samples yields dim_t frames.
        self.chunk_size = hop * (dim_t - 1)

        hann = torch.hann_window(window_length=n_fft, periodic=True)
        self.window = hann.to(device)

        # Zero block appended by istft() for the bins that stft() cropped.
        missing_bins = self.n_bins - dim_f
        zero_block = torch.zeros([1, self.dim_c, missing_bins, dim_t])
        self.freq_pad = zero_block.to(device)

    def stft(self, x):
        """
        Compute the short-time Fourier transform of the input chunks.

        The input is flattened to rows of ``chunk_size`` samples. Each pair
        of consecutive rows is folded, together with its real and imaginary
        parts, into 4 channels, and only the first ``dim_f`` frequency bins
        are returned.
        """
        rows = x.reshape([-1, self.chunk_size])
        spectrum = torch.stft(
            rows,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        # (rows, bins, frames, re/im) -> (rows, re/im, bins, frames)
        parts = torch.view_as_real(spectrum).permute([0, 3, 1, 2])
        paired = parts.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        stacked = paired.reshape([-1, 4, self.n_bins, self.dim_t])
        return stacked[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        """
        Compute the inverse short-time Fourier transform of a spectrogram.

        ``x`` is concatenated with ``freq_pad`` along the frequency axis
        (the stored zero block, repeated per batch item, when ``freq_pad``
        is None), unfolded back into complex rows and inverted. The result
        is reshaped to ``[-1, 2, chunk_size]``.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])

        padded = torch.cat([x, freq_pad], -2)
        paired = padded.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        rows = paired.reshape([-1, 2, self.n_bins, self.dim_t])
        # (rows, re/im, bins, frames) -> (rows, bins, frames, re/im);
        # view_as_complex needs the trailing re/im axis to be contiguous.
        spectrum = torch.view_as_complex(rows.permute([0, 2, 3, 1]).contiguous())
        waveform = torch.istft(
            spectrum,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
        )
        return waveform.reshape([-1, 2, self.chunk_size])
87
+
88
+
89
class MDX:
    """Run an ONNX separation model over a waveform, chunk by chunk.

    Wraps an ONNX Runtime inference session together with an ``MDXModel``
    that supplies the STFT / inverse-STFT and the chunk geometry.
    """

    DEFAULT_SR = 44100  # unit: Hz
    # Both sizes are in samples (seconds * sample rate); a chunk size of 0
    # means "do not split" (see segment()).
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR

    def __init__(self, model_path: str, params: "MDXModel", processor=0):
        """
        Load the ONNX model and warm it up with one random input.

        Args:
            model_path: (str) Path of the ONNX model file
            params: (MDXModel) STFT helper describing the model geometry
            processor: (int) CUDA device index; a negative value selects CPU
        """
        use_cuda = processor >= 0
        # Set the device and the provider (CPU or CUDA)
        self.device = (
            torch.device(f"cuda:{processor}") if use_cuda else torch.device("cpu")
        )
        self.provider = (
            ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"]
        )

        self.model = params

        # Load the ONNX model using ONNX Runtime
        self.ort = ort.InferenceSession(model_path, providers=self.provider)
        # Run once on random data so the first real call is not the slow one
        self.ort.run(
            None,
            {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
        )
        # Spectrogram tensor in, first model output (numpy array) out
        self.process = lambda spec: self.ort.run(
            None, {"input": spec.cpu().numpy()}
        )[0]

        # Progress bar; created by process_wave()
        self.prog = None

    @staticmethod
    def get_hash(model_path: str) -> str:
        """
        Return the MD5 hex digest identifying a model file.

        Only the last 10000 KiB of the file are hashed; when the file is
        smaller than that (the backwards seek fails) the whole file is
        hashed instead.

        Args:
            model_path: (str) Path of the model file

        Returns:
            str: MD5 hex digest
        """
        with open(model_path, "rb") as f:
            try:
                f.seek(-10000 * 1024, 2)
            except OSError:
                # File shorter than the tail window: hash all of it.
                f.seek(0)
            return hashlib.md5(f.read()).hexdigest()

    @staticmethod
    def segment(
        wave,
        combine=True,
        chunk_size=DEFAULT_CHUNK_SIZE,
        margin_size=DEFAULT_MARGIN_SIZE,
    ):
        """
        Segment or join segmented wave array

        Args:
            wave: (np.array) Wave array to be segmented, or sequence of
                segments to be joined
            combine: (bool) If True, combines segmented wave array.
                If False, segments wave array.
            chunk_size: (int) Size of each segment (in samples)
            margin_size: (int) Size of margin between segments (in samples)

        Returns:
            When combining: the joined numpy array (None for an empty
            sequence). When segmenting: a list of numpy arrays, each
            overlapping its neighbours by ``margin_size`` samples.
        """
        if combine:
            # Splitting clamps the margin to the chunk size; apply the same
            # clamp here so a split/join round trip with identical
            # arguments is lossless even when margin_size > chunk_size.
            if 0 < chunk_size < margin_size:
                margin_size = chunk_size

            last_index = len(wave) - 1
            pieces = []
            for index, piece in enumerate(wave):
                start = 0 if index == 0 else margin_size
                if index == last_index or margin_size == 0:
                    end = None
                else:
                    end = -margin_size
                pieces.append(piece[:, start:end])

            if not pieces:
                return None
            if len(pieces) == 1:
                return pieces[0]
            # One concatenation instead of re-copying the result per piece
            return np.concatenate(pieces, axis=-1)

        sample_count = wave.shape[-1]
        if sample_count == 0:
            # Nothing to split; avoids range() with a zero step below.
            return [wave.copy()]

        if chunk_size <= 0 or chunk_size > sample_count:
            chunk_size = sample_count
        if margin_size > chunk_size:
            margin_size = chunk_size

        segments = []
        for index, skip in enumerate(range(0, sample_count, chunk_size)):
            # The first segment has no left margin
            left_margin = 0 if index == 0 else margin_size
            start = skip - left_margin
            end = min(skip + chunk_size + margin_size, sample_count)

            segments.append(wave[:, start:end].copy())

            # The right margin already reached the end of the wave
            if end == sample_count:
                break

        return segments

    def pad_wave(self, wave):
        """
        Pad the wave array and cut it into overlapping model-sized windows

        Args:
            wave: (np.array) Wave array to be padded, samples on axis 1

        Returns:
            tuple: (mix_waves, pad, trim)
                - mix_waves: float32 tensor on ``self.device`` stacking
                  windows of ``self.model.chunk_size`` samples
                - pad: Number of zero samples appended after the wave
                - trim: Number of zero samples added on each side
                  (``n_fft // 2``); _process_wave() cuts them off again
        """
        n_sample = wave.shape[1]
        chunk_size = self.model.chunk_size
        trim = self.model.n_fft // 2
        # Samples of each window that survive trimming on both sides
        gen_size = chunk_size - 2 * trim
        # Always in 1..gen_size, so "[:, :-pad]" later is never "[:, :-0]"
        pad = gen_size - n_sample % gen_size

        # trim zeros in front; pad + trim zeros behind
        wave_p = np.pad(wave, ((0, 0), (trim, pad + trim)))

        # Windows advance by gen_size, so neighbours overlap by 2 * trim.
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays is slow.
        windows = np.stack(
            [wave_p[:, i:i + chunk_size] for i in range(0, n_sample + pad, gen_size)]
        )
        mix_waves = torch.tensor(windows, dtype=torch.float32).to(self.device)

        return mix_waves, pad, trim

    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
        """
        Process each wave segment in a multi-threaded environment

        Args:
            mix_waves: (torch.Tensor) Wave segments to be processed
            trim: (int) Number of samples to cut from both ends of each
                processed segment
            pad: (int) Number of padded samples to drop from the end
            q: (queue.Queue) Queue to hold the processed wave segments
            _id: (int) Identifier of the processed wave segment

        Returns:
            numpy array: Processed wave segment (also put on ``q`` as
            ``{_id: segment}``)
        """
        processed = []
        with torch.no_grad():
            for mix_wave in mix_waves.split(1):
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(processed_spec.to(self.device))
                processed_wav = (
                    processed_wav[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .cpu()
                    .numpy()
                )
                processed.append(processed_wav)
        processed_signal = np.concatenate(processed, axis=-1)[:, :-pad]
        q.put({_id: processed_signal})
        return processed_signal

    def process_wave(self, wave: np.array, mt_threads=1):
        """
        Process the wave array in a multi-threaded environment

        Args:
            wave: (np.array) Wave array to be processed
            mt_threads: (int) Number of threads to be used for processing

        Returns:
            numpy array: Processed wave array

        Raises:
            AssertionError: if a worker thread produced no result
        """
        self.prog = tqdm(total=0)
        chunk = wave.shape[-1] // mt_threads
        waves = self.segment(wave, False, chunk)

        # Create a queue to hold the processed wave segments
        q = queue.Queue()
        threads = []
        for c, batch in enumerate(waves):
            mix_waves, pad, trim = self.pad_wave(batch)
            self.prog.total = len(mix_waves) * mt_threads
            thread = threading.Thread(
                target=self._process_wave, args=(mix_waves, trim, pad, q, c)
            )
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        self.prog.close()

        # Each queue entry is {segment_id: processed_segment}
        results = {}
        while not q.empty():
            results.update(q.get())

        # A worker that raised leaves no entry. Raised explicitly (same
        # exception type as before) so the check survives "python -O".
        if len(results) != len(waves):
            raise AssertionError(
                "Incomplete processed batches, please reduce batch size!"
            )

        processed_batches = [results[key] for key in sorted(results)]
        return self.segment(processed_batches, True, chunk)