import gc
import os
import re
import hashlib
import queue
import threading
import json
import shlex
import sys
import subprocess
import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm
import random
import spaces
import onnxruntime as ort
import warnings
import gradio as gr
import logging
import time
import traceback
from pathlib import Path
from huggingface_hub import hf_hub_download
from typing import Dict, Tuple

MODEL_ID = "masszhou/mdxnet"
MODELS_PATH = {
    "bgm": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR-MDX-NET-Inst_HQ_3.onnx")),
    "basic_vocal": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR-MDX-NET-Voc_FT.onnx")),
    "main_vocal": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR_MDXNET_KARA_2.onnx")),
}
STEM_NAMING = {
    "Vocals": "Instrumental",
    "Other": "Instruments",
    "Instrumental": "Vocals",
    "Drums": "Drumless",
    "Bass": "Bassless",
}


def convert_to_stereo_and_wav(audio_path: Path) -> Path:
    # loading takes time since resampling at 44100 Hz
    wave, sr = librosa.load(str(audio_path), mono=False, sr=44100)

    # convert if the file is mono or not already a wav
    if type(wave[0]) != np.ndarray or audio_path.suffix != ".wav":  # noqa
        stereo_path = audio_path.with_name(audio_path.stem + "_stereo.wav")

        command = shlex.split(
            f'ffmpeg -y -loglevel error -i "{str(audio_path)}" -ac 2 -f wav "{str(stereo_path)}"'
        )
        sub_params = {
            "stdout": subprocess.PIPE,
            "stderr": subprocess.PIPE,
            "creationflags": subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
        }
        process_wav = subprocess.Popen(command, **sub_params)
        output, errors = process_wav.communicate()
        if process_wav.returncode != 0 or not stereo_path.exists():
            raise Exception("Error processing audio to stereo wav")
        return stereo_path
    else:
        return Path(audio_path)


class MDXModel:
    def __init__(self,
                 device: torch.device,
                 dim_f: int,
                 dim_t: int,
                 n_fft: int,
                 hop: int = 1024,
                 stem_name: str = "Vocals",
                 compensation: float = 1.000,
                 ):
        self.dim_f = dim_f  # frequency bins
        self.dim_t = dim_t
        self.dim_c = 4
        self.n_fft = n_fft
        self.hop = hop
        self.stem_name = stem_name
        self.compensation = compensation

        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(
            window_length=self.n_fft, periodic=True
        ).to(device)

        out_c = self.dim_c

        self.freq_pad = torch.zeros(
            [1, out_c, self.n_bins - self.dim_f, self.dim_t]
        ).to(device)

    def stft(self, x):
        """
        Computes the Fourier transform of short overlapping windows of the input
        """
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 4, self.n_bins, self.dim_t]
        )
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        """
        Computes the inverse Fourier transform of short overlapping windows of the input
        """
        freq_pad = (
            self.freq_pad.repeat([x.shape[0], 1, 1, 1])
            if freq_pad is None
            else freq_pad
        )
        x = torch.cat([x, freq_pad], -2)
        # c = 4*2 if self.target_name=='*' else 2
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
        )
        return x.reshape([-1, 2, self.chunk_size])
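# For orientation, the tensor shapes flowing through MDXModel are roughly as follows
# (the concrete dim_f / dim_t / n_fft values are read from model_data.json at runtime):
#   waveform chunk : (batch, 2, chunk_size)     stereo PCM, chunk_size = hop * (dim_t - 1)
#   stft(...)      : (batch, 4, dim_f, dim_t)   4 = 2 channels x (real, imag)
#   ONNX model     : (batch, 4, dim_f, dim_t)   same layout, estimated stem spectrogram
#   istft(...)     : (batch, 2, chunk_size)     back to stereo PCM (zero-padded up to n_bins first)
# The warm-up call in MDX.__init__ below uses torch.rand(1, 4, dim_f, dim_t),
# which matches this spectrogram layout.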
class MDX:
    DEFAULT_SR = 44100  # unit: Hz
    # chunk and margin sizes in samples (seconds * sample rate)
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR

    def __init__(self, model_path: Path, params: MDXModel, processor: int = 0):
        # Set the device and the provider (CPU or CUDA)
        self.device = (
            torch.device(f"cuda:{processor}") if processor >= 0 else torch.device("cpu")
        )
        self.provider = (
            ["CUDAExecutionProvider"] if processor >= 0 else ["CPUExecutionProvider"]
        )

        self.model = params

        # Load the ONNX model using ONNX Runtime
        self.ort = ort.InferenceSession(str(model_path), providers=self.provider)
        # Preload the model for faster performance
        self.ort.run(
            None,
            {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
        )
        self.process = lambda spec: self.ort.run(
            None, {"input": spec.cpu().numpy()}
        )[0]

        self.prog = None

    @staticmethod
    def get_hash(model_path: Path) -> str:
        try:
            with open(model_path, "rb") as f:
                f.seek(-10000 * 1024, 2)
                model_hash = hashlib.md5(f.read()).hexdigest()
        except:  # noqa
            model_hash = hashlib.md5(open(model_path, "rb").read()).hexdigest()
        return model_hash

    @staticmethod
    def segment(wave: np.array,
                combine: bool = True,
                chunk_size: int = DEFAULT_CHUNK_SIZE,
                margin_size: int = DEFAULT_MARGIN_SIZE,
                ) -> np.array:
        """
        Segment or join segmented wave array

        Args:
            wave: (np.array) Wave array to be segmented or joined
            combine: (bool) If True, combines segmented wave array. If False, segments wave array.
            chunk_size: (int) Size of each segment (in samples)
            margin_size: (int) Size of margin between segments (in samples)

        Returns:
            numpy array: Segmented or joined wave array
        """
        if combine:
            # Initializing as None instead of [] for later numpy array concatenation
            processed_wave = None
            for segment_count, segment in enumerate(wave):
                start = 0 if segment_count == 0 else margin_size
                end = None if segment_count == len(wave) - 1 else -margin_size
                if margin_size == 0:
                    end = None
                if processed_wave is None:  # Create array for first segment
                    processed_wave = segment[:, start:end]
                else:  # Concatenate to existing array for subsequent segments
                    processed_wave = np.concatenate(
                        (processed_wave, segment[:, start:end]), axis=-1
                    )
        else:
            processed_wave = []
            sample_count = wave.shape[-1]
            if chunk_size <= 0 or chunk_size > sample_count:
                chunk_size = sample_count
            if margin_size > chunk_size:
                margin_size = chunk_size

            for segment_count, skip in enumerate(
                range(0, sample_count, chunk_size)
            ):
                margin = 0 if segment_count == 0 else margin_size
                end = min(skip + chunk_size + margin_size, sample_count)
                start = skip - margin

                cut = wave[:, start:end].copy()
                processed_wave.append(cut)

                if end == sample_count:
                    break

        return processed_wave

    def pad_wave(self, wave: np.array) -> Tuple[np.array, int, int]:
        """
        Pad the wave array to match the required chunk size

        Args:
            wave: (np.array) Wave array to be padded

        Returns:
            tuple: (padded_wave, pad, trim)
                - padded_wave: Padded wave array
                - pad: Number of samples that were padded
                - trim: Number of samples that were trimmed
        """
        n_sample = wave.shape[1]
        trim = self.model.n_fft // 2
        gen_size = self.model.chunk_size - 2 * trim
        pad = gen_size - n_sample % gen_size

        # Padded wave
        wave_p = np.concatenate(
            (
                np.zeros((2, trim)),
                wave,
                np.zeros((2, pad)),
                np.zeros((2, trim)),
            ),
            1,
        )

        mix_waves = []
        for i in range(0, n_sample + pad, gen_size):
            waves = np.array(wave_p[:, i:i + self.model.chunk_size])
            mix_waves.append(waves)

        mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.device)

        return mix_waves, pad, trim
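    # Rough arithmetic for how segment() and pad_wave() interact (illustrative numbers
    # only, e.g. assuming hop=1024, dim_t=256, n_fft=6144):
    #   chunk_size = 1024 * (256 - 1) = 261120 samples  (~5.9 s at 44.1 kHz)
    #   trim       = 6144 // 2        = 3072 samples
    #   gen_size   = 261120 - 2*3072  = 254976 samples per step through the padded wave
    # segment(wave, combine=False, ...) cuts the track into pieces with margin_size
    # samples of extra context at the piece boundaries; segment(..., combine=True, ...)
    # trims those margins again and concatenates the processed pieces.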
    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int) -> np.array:
        """
        Process each wave segment in a multi-threaded environment

        Args:
            mix_waves: (torch.Tensor) Wave segments to be processed
            trim: (int) Number of samples trimmed during padding
            pad: (int) Number of samples padded during padding
            q: (queue.Queue) Queue to hold the processed wave segments
            _id: (int) Identifier of the processed wave segment

        Returns:
            numpy array: Processed wave segment
        """
        mix_waves = mix_waves.split(1)
        with torch.no_grad():
            pw = []
            for mix_wave in mix_waves:
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(
                    processed_spec.to(self.device)
                )
                processed_wav = (
                    processed_wav[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .cpu()
                    .numpy()
                )
                pw.append(processed_wav)
        processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
        q.put({_id: processed_signal})
        return processed_signal

    def process_wave(self, wave: np.array, mt_threads=1) -> np.array:
        """
        Process the wave array in a multi-threaded environment

        Args:
            wave: (np.array) Wave array to be processed
            mt_threads: (int) Number of threads to be used for processing

        Returns:
            numpy array: Processed wave array
        """
        self.prog = tqdm(total=0)
        chunk = wave.shape[-1] // mt_threads
        waves = self.segment(wave, False, chunk)

        # Create a queue to hold the processed wave segments
        q = queue.Queue()
        threads = []
        for c, batch in enumerate(waves):
            mix_waves, pad, trim = self.pad_wave(batch)
            self.prog.total = len(mix_waves) * mt_threads
            thread = threading.Thread(
                target=self._process_wave, args=(mix_waves, trim, pad, q, c)
            )
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        self.prog.close()

        processed_batches = []
        while not q.empty():
            processed_batches.append(q.get())
        processed_batches = [
            list(wave.values())[0]
            for wave in sorted(
                processed_batches, key=lambda d: list(d.keys())[0]
            )
        ]
        assert len(processed_batches) == len(
            waves
        ), "Incomplete processed batches, please reduce batch size!"
        return self.segment(processed_batches, True, chunk)
@spaces.GPU()
def run_mdx(model_params: Dict,
            input_filename: Path,
            output_dir: Path,
            model_path: Path,
            denoise: bool = False,
            m_threads: int = 2,
            device_base: str = "cuda",
            ) -> Tuple[Path, Path]:
    """
    Separate one stem using an MDX model; returns the main stem and its inverse as files.
    """
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        device_properties = torch.cuda.get_device_properties(device)
        vram_gb = device_properties.total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1
    print(f"device: {device}")

    model_hash = MDX.get_hash(model_path)  # type: str
    mp = model_params.get(model_hash)
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)

    wave, sr = librosa.load(input_filename, mono=False, sr=44100)
    # normalizing input wave gives better output
    peak = max(np.max(wave), abs(np.min(wave)))
    wave /= peak
    if denoise:
        # denoise trick: average the model output for the wave and its negation
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))  # type: np.array
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # return to previous peak (rescale the mix too, so the inverse stem lines up)
    wave_processed *= peak
    wave *= peak

    stem_name = model.stem_name

    # output main track
    main_filepath = output_dir / f"{input_filename.stem}_{stem_name}.wav"
    sf.write(main_filepath, wave_processed.T, sr)

    # output inverse track (mix minus the compensated main stem)
    invert_filepath = output_dir / f"{input_filename.stem}_{stem_name}_reverse.wav"
    sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)

    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()

    return main_filepath, invert_filepath


@spaces.GPU()
def run_mdx_return_np(model_params: Dict,
                      input_filename: Path,
                      model_path: Path,
                      denoise: bool = False,
                      m_threads: int = 2,
                      device_base: str = "cuda",
                      ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Separate one stem using an MDX model; returns the main stem and its inverse as arrays.
    """
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        device_properties = torch.cuda.get_device_properties(device)
        vram_gb = device_properties.total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1
    print(f"device: {device}")

    model_hash = MDX.get_hash(model_path)  # type: str
    mp = model_params.get(model_hash)
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)

    wave, sr = librosa.load(input_filename, mono=False, sr=44100)
    # normalizing input wave gives better output
    peak = max(np.max(wave), abs(np.min(wave)))
    wave /= peak
    if denoise:
        # denoise trick: average the model output for the wave and its negation
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))  # type: np.array
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # return to previous peak (rescale the mix too, so the inverse stem lines up)
    wave_processed *= peak
    wave *= peak

    stem_name = model.stem_name

    # main track
    main_track = wave_processed.T
    # inverse track (mix minus the compensated main stem)
    invert_track = (-wave_processed.T * model.compensation) + wave.T

    return main_track, invert_track
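# How the two outputs of run_mdx / run_mdx_return_np relate (sketch, assuming the
# compensation factor from model_data.json is close to 1.0):
#   invert = mix - compensation * main
# e.g. with a vocal model, main ≈ vocals and invert ≈ instrumental, so
#   main + invert ≈ mix  (up to the compensation scaling).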
"cuda") -> Path: """ Extract pure background music, remove vocals """ background_path, _ = run_mdx(model_params=mdx_model_params, input_filename=input_filename, output_dir=output_dir, model_path=model_bgm_path, denoise=False, device_base=device_base, ) return background_path def extract_vocal(mdx_model_params: Dict, input_filename: Path, model_basic_vocal_path: Path, model_main_vocal_path: Path, output_dir: Path, main_vocals_flag: bool = False, device_base: str = "cuda") -> Path: """ Extract vocals """ # First use UVR-MDX-NET-Voc_FT.onnx basic vocal separation model vocals_path, _ = run_mdx(mdx_model_params, input_filename, output_dir, model_basic_vocal_path, denoise=True, device_base=device_base, ) # If "main_vocals_flag" is enabled, use UVR_MDXNET_KARA_2.onnx to further separate main vocals (Main) from backup vocals/background vocals (Backup) if main_vocals_flag: time.sleep(2) backup_vocals_path, main_vocals_path = run_mdx(mdx_model_params, output_dir, model_main_vocal_path, vocals_path, denoise=True, device_base=device_base, ) vocals_path = main_vocals_path # If "dereverb_flag" is enabled, use Reverb_HQ_By_FoxJoy.onnx for dereverberation # deactived since Model license unknown # if dereverb_flag: # time.sleep(2) # _, vocals_dereverb_path = run_mdx(mdx_model_params, # output_dir, # mdxnet_models_dir/"Reverb_HQ_By_FoxJoy.onnx", # vocals_path, # denoise=True, # device_base=device_base, # ) # vocals_path = vocals_dereverb_path return vocals_path def process_uvr_task(input_file_path: Path, output_dir: Path, models_path: Dict[str, Path], main_vocals_flag: bool = False, # If "Main" is enabled, use UVR_MDXNET_KARA_2.onnx to further separate main and backup vocals ) -> Tuple[Path, Path]: device_base = "cuda" if torch.cuda.is_available() else "cpu" # load mdx model definition with open("./mdx_models/model_data.json") as infile: mdx_model_params = json.load(infile) # type: Dict output_dir.mkdir(parents=True, exist_ok=True) input_file_path = convert_to_stereo_and_wav(input_file_path) # type: Path # 1. Extract pure background music, remove vocals background_path = extract_bgm(mdx_model_params, input_file_path, models_path["bgm"], output_dir, device_base=device_base) # 2. 
def get_model_params(model_path: Path) -> Dict:
    """
    Load the MDX model parameter definitions (model_data.json) from the given directory
    """
    with open(model_path / "model_data.json") as infile:
        return json.load(infile)  # type: Dict


def inference_mdx(audio_file: str) -> Tuple[str, str]:
    mdx_model_params = get_model_params(Path("./mdx_models"))
    audio_file = convert_to_stereo_and_wav(Path(audio_file))  # resampling at 44100 Hz
    device_base = "cuda" if torch.cuda.is_available() else "cpu"

    output_dir = Path("./out/mdx")
    os.makedirs(output_dir, exist_ok=True)

    model_bgm_path = MODELS_PATH["bgm"]
    background_path, vocal_path = run_mdx(
        model_params=mdx_model_params,
        input_filename=audio_file,
        output_dir=output_dir,
        model_path=model_bgm_path,
        denoise=False,
        device_base=device_base,
    )
    return str(vocal_path), str(background_path)


if __name__ == "__main__":
    # zero = torch.Tensor([0]).cuda()
    # print(f"zero.device: {zero.device}")

    app = gr.Interface(
        fn=inference_mdx,
        inputs=gr.Audio(type="filepath", label="Input"),
        outputs=[gr.Audio(type="filepath", label="Vocals"),
                 gr.Audio(type="filepath", label="BGM")],
        title="MDXNET Music Source Separation",
        article="KUIELab-MDX-Net: A Two-Stream Neural Network for Music Demixing | Github Repo | MIT License",
        api_name="mdxnet_separation",
    )
    app.launch()