source-separation / separate.py
csukuangfj's picture
first version
009ec32
raw
history blame
3.68 kB
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from functools import lru_cache
import ffmpeg
import numpy as np
from huggingface_hub import hf_hub_download
import sherpa_onnx
# Output sample rate (Hz) requested from ffmpeg via its `ar` option when
# load_audio() decodes a file.
sample_rate = 44100
def load_audio(filename):
    """Decode an audio file with ffmpeg into a matrix of float32 samples.

    The file is first probed to find its first audio stream and that
    stream's channel count. It is then decoded by ffmpeg into raw
    little-endian 32-bit floats, resampled to the module-level
    ``sample_rate``, and the interleaved samples are reshaped so that each
    row holds one sample frame.

    Args:
      filename: Input file handed to ``ffmpeg.probe`` and ``ffmpeg.input``.

    Returns:
      A 2-D numpy array with dtype ``<f4`` and shape
      ``(num_frames, num_channels)``. When the stream has more than two
      channels only the first two columns are returned.

    Raises:
      ValueError: If the probe reports no streams at all, or none of the
        reported streams is an audio stream.
      RuntimeError: If the ffmpeg decoding process exits with a non-zero
        status; the message carries ffmpeg's stderr output.
    """
    probe = ffmpeg.probe(filename)
    streams = probe.get("streams")
    if not streams:
        raise ValueError("No stream was found with ffprobe")

    # Use a default so that a file without any audio stream produces a clear
    # ValueError instead of leaking a bare StopIteration to the caller.
    metadata = next(
        (stream for stream in streams if stream.get("codec_type") == "audio"),
        None,
    )
    if metadata is None:
        raise ValueError("No audio stream was found with ffprobe")
    n_channels = metadata["channels"]

    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, stderr = process.communicate()
    if process.returncode != 0:
        # Without this check a failed decode would surface later as an empty
        # array or as an unrelated reshape error.
        details = stderr.decode("utf-8", errors="replace") if stderr else ""
        raise RuntimeError(f"ffmpeg failed to decode {filename}: {details}")

    # f32le output is interleaved: one float per channel per frame.
    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
    if n_channels > 2:
        waveform = waveform[:, :2]
    return waveform
@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = "2stems",
) -> str:
    """Download a single file from a Hugging Face Hub repository.

    This is a memoized pass-through to ``hf_hub_download``: the result for
    each distinct ``(repo_id, filename, subfolder)`` combination is kept in
    an LRU cache holding up to 10 entries.

    Args:
      repo_id: Repository identifier forwarded to ``hf_hub_download``.
      filename: Name of the file inside the repository.
      subfolder: Folder inside the repository that contains the file.

    Returns:
      Whatever ``hf_hub_download`` returns for the request (annotated here
      as ``str``).
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
@lru_cache(maxsize=30)
def load_model(name: str):
    """Build the source-separation model that matches ``name``.

    Dispatch is by case-sensitive substring: a name containing
    ``"spleeter"`` is handled by ``load_spleeter_model``; otherwise a name
    containing ``"UVR"`` is handled by ``load_uvr_model``. Successful
    results are memoized per name in an LRU cache of up to 30 entries.

    Args:
      name: Model name used both for dispatch and as the loader argument.

    Returns:
      The object returned by the selected loader.

    Raises:
      ValueError: If ``name`` contains neither ``"spleeter"`` nor ``"UVR"``.
    """
    is_spleeter = "spleeter" in name
    if not is_spleeter and "UVR" not in name:
        raise ValueError(f"Unsupported model name {name}")

    # "spleeter" takes precedence when both substrings are present.
    loader = load_spleeter_model if is_spleeter else load_uvr_model
    return loader(name)
def load_uvr_model(name: str):
    """Download a UVR model file and wrap it in a sherpa-onnx separator.

    The file called ``name`` is fetched through ``get_file`` from the
    ``source-separation-models`` folder of the
    ``k2-fsa/sherpa-onnx-models`` repository, then used to configure an
    ``OfflineSourceSeparation`` instance that runs on the CPU provider with
    two threads and debug output disabled.

    Args:
      name: File name of the UVR model inside the repository folder.

    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` built from the config.

    Raises:
      ValueError: If the assembled config fails its own ``validate()``.
    """
    model_path = get_file(
        repo_id="k2-fsa/sherpa-onnx-models",
        subfolder="source-separation-models",
        filename=name,
    )

    # Build the nested config from the inside out.
    uvr_config = sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
        model=model_path,
    )
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        uvr=uvr_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if config.validate():
        return sherpa_onnx.OfflineSourceSeparation(config)
    raise ValueError("Please check your config.")
def load_spleeter_model(name: str):
    """Download a Spleeter 2-stem model pair and build a separator from it.

    The vocals and accompaniment ONNX files are fetched through ``get_file``
    from the ``csukuangfj/<name>`` repository (using ``get_file``'s default
    subfolder). The file-name suffix is chosen from the model name:
    ``fp16.onnx`` when the name contains ``"fp16"``, ``int8.onnx`` when it
    contains ``"int8"``, and plain ``onnx`` otherwise. The resulting
    separator runs on the CPU provider with two threads and debug output
    disabled.

    Args:
      name: Repository name under the ``csukuangfj`` account; also selects
        which quantization variant of the files is requested.

    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` built from the config.

    Raises:
      ValueError: If the assembled config fails its own ``validate()``.
    """
    if "fp16" in name:
        suffix = "fp16.onnx"
    elif "int8" in name:
        suffix = "int8.onnx"
    else:
        # No leading dot here: the f-strings below already place a "."
        # between the stem and the suffix, so ".onnx" would request
        # "vocals..onnx" / "accompaniment..onnx".
        suffix = "onnx"

    repo_id = f"csukuangfj/{name}"
    vocals = get_file(repo_id=repo_id, filename=f"vocals.{suffix}")
    accompaniment = get_file(
        repo_id=repo_id, filename=f"accompaniment.{suffix}"
    )

    config = sherpa_onnx.OfflineSourceSeparationConfig(
        model=sherpa_onnx.OfflineSourceSeparationModelConfig(
            spleeter=sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(
                vocals=vocals,
                accompaniment=accompaniment,
            ),
            num_threads=2,
            debug=False,
            provider="cpu",
        )
    )
    if not config.validate():
        raise ValueError("Please check your config.")
    return sherpa_onnx.OfflineSourceSeparation(config)
# Spleeter entries: each name contains "spleeter", the substring that
# load_model() checks first when choosing a loader.
_SPLEETER_MODELS = (
    "sherpa-onnx-spleeter-2stems-fp16",
    "sherpa-onnx-spleeter-2stems-int8",
    "sherpa-onnx-spleeter-2stems",
)

# UVR entries: each name contains "UVR", the other substring that
# load_model() dispatches on.
_UVR_MODELS = (
    "UVR-MDX-NET-Inst_1.onnx",
    "UVR-MDX-NET-Inst_2.onnx",
    "UVR-MDX-NET-Inst_3.onnx",
    "UVR-MDX-NET-Inst_HQ_1.onnx",
    "UVR-MDX-NET-Inst_HQ_2.onnx",
    "UVR-MDX-NET-Inst_HQ_3.onnx",
    "UVR-MDX-NET-Inst_HQ_4.onnx",
    "UVR-MDX-NET-Inst_HQ_5.onnx",
    "UVR-MDX-NET-Inst_Main.onnx",
    "UVR-MDX-NET-Voc_FT.onnx",
    "UVR-MDX-NET_Crowd_HQ_1.onnx",
    "UVR_MDXNET_1_9703.onnx",
    "UVR_MDXNET_2_9682.onnx",
    "UVR_MDXNET_3_9662.onnx",
    "UVR_MDXNET_9482.onnx",
    "UVR_MDXNET_KARA.onnx",
    "UVR_MDXNET_KARA_2.onnx",
    "UVR_MDXNET_Main.onnx",
)

# Public list of model names, Spleeter entries first, then UVR entries.
model_list = [*_SPLEETER_MODELS, *_UVR_MODELS]