File size: 3,676 Bytes
009ec32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

from functools import lru_cache

import ffmpeg
import numpy as np
from huggingface_hub import hf_hub_download
import sherpa_onnx


# Target sample rate (Hz); load_audio asks ffmpeg to resample to this via `ar`.
sample_rate = 44_100


def load_audio(filename):
    """Decode an audio file into a float32 waveform resampled to ``sample_rate``.

    The file is probed with ffprobe to find the channel count of its first
    audio stream, then decoded by ffmpeg into raw little-endian float32
    (``f32le``) samples at ``sample_rate`` Hz.

    Args:
      filename: Path of the file to decode.

    Returns:
      A ``float32`` array of shape ``(num_samples, num_channels)`` where
      ``num_channels`` is at most 2 (channels beyond the second are dropped).

    Raises:
      ValueError: If ffprobe reports no streams, or no audio stream.
      RuntimeError: If ffmpeg exits with a non-zero status while decoding.
    """
    probe = ffmpeg.probe(filename)
    streams = probe.get("streams") or []
    if not streams:
        raise ValueError("No stream was found with ffprobe")

    # A container may also hold video/subtitle streams; take the first audio one.
    metadata = next(
        (stream for stream in streams if stream.get("codec_type") == "audio"),
        None,
    )
    if metadata is None:
        raise ValueError("No audio stream was found with ffprobe")
    n_channels = metadata["channels"]

    process = (
        ffmpeg.input(filename)
        # `ac` forces the output channel count to match the probed stream, so
        # the reshape below stays consistent even if ffmpeg selects a
        # different audio stream by default.
        .output("pipe:", format="f32le", ar=sample_rate, ac=n_channels)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, stderr = process.communicate()
    if process.returncode != 0:
        # Without this check a failed decode would yield an empty waveform.
        details = stderr.decode("utf-8", errors="replace") if stderr else ""
        raise RuntimeError(f"ffmpeg failed to decode {filename}: {details}")

    # Interleaved samples -> one row per frame, one column per channel.
    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)

    if n_channels > 2:
        waveform = waveform[:, :2]

    return waveform


@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = "2stems",
) -> str:
    """Fetch ``subfolder/filename`` from the Hugging Face Hub repo ``repo_id``.

    Calls are memoized (up to 10 distinct argument combinations), so asking
    for the same file again returns the cached path without calling the hub.

    Returns:
      The path string returned by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )


@lru_cache(maxsize=30)
def load_model(name: str):
    """Return a source-separation engine for ``name``, caching up to 30 of them.

    Names containing "spleeter" go to ``load_spleeter_model``. Otherwise,
    names containing "UVR" go to ``load_uvr_model``.

    Raises:
      ValueError: If ``name`` contains neither "spleeter" nor "UVR".
    """
    is_spleeter = "spleeter" in name
    if not is_spleeter and "UVR" not in name:
        raise ValueError(f"Unsupported model name {name}")

    loader = load_spleeter_model if is_spleeter else load_uvr_model
    return loader(name)


def load_uvr_model(name: str):
    """Build a CPU ``OfflineSourceSeparation`` engine from a UVR ONNX model.

    Args:
      name: Model filename inside the ``source-separation-models`` folder of
        the ``k2-fsa/sherpa-onnx-models`` Hugging Face repo.

    Raises:
      ValueError: If ``config.validate()`` rejects the assembled config.
    """
    model_path = get_file(
        repo_id="k2-fsa/sherpa-onnx-models",
        subfolder="source-separation-models",
        filename=name,
    )

    uvr_config = sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
        model=model_path,
    )
    # Fixed runtime settings: 2 CPU threads, debug output off.
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        uvr=uvr_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if not config.validate():
        raise ValueError("Please check your config.")
    return sherpa_onnx.OfflineSourceSeparation(config)


def load_spleeter_model(name: str):
    """Build a CPU ``OfflineSourceSeparation`` engine from a Spleeter 2-stem model.

    Args:
      name: Hugging Face repo name under ``csukuangfj``. If it contains "fp16"
        or "int8", the matching quantized files are used. Otherwise the files
        without a precision tag are used.

    The ``vocals`` and ``accompaniment`` ONNX files are fetched with
    ``get_file`` using its default subfolder.

    Raises:
      ValueError: If ``config.validate()`` rejects the assembled config.
    """
    # Filenames follow "<stem>.<suffix>", e.g. "vocals.fp16.onnx".
    if "fp16" in name:
        suffix = "fp16.onnx"
    elif "int8" in name:
        suffix = "int8.onnx"
    else:
        # No precision tag: "<stem>.onnx". (".onnx" here would produce the
        # bogus name "vocals..onnx".)
        suffix = "onnx"

    vocals = get_file(repo_id=f"csukuangfj/{name}", filename=f"vocals.{suffix}")
    accompaniment = get_file(
        repo_id=f"csukuangfj/{name}", filename=f"accompaniment.{suffix}"
    )

    config = sherpa_onnx.OfflineSourceSeparationConfig(
        model=sherpa_onnx.OfflineSourceSeparationModelConfig(
            spleeter=sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(
                vocals=vocals,
                accompaniment=accompaniment,
            ),
            num_threads=2,
            debug=False,
            provider="cpu",
        )
    )
    if not config.validate():
        raise ValueError("Please check your config.")

    return sherpa_onnx.OfflineSourceSeparation(config)


# Names offered to load_model(): every Spleeter entry contains "spleeter" and
# every UVR entry (an ONNX filename) contains "UVR", matching its dispatch.
model_list = [
    # Spleeter 2-stem repos: fp16, int8, then the variant without a tag.
    *(f"sherpa-onnx-spleeter-2stems{tag}" for tag in ("-fp16", "-int8", "")),
    # UVR MDX-Net models, listed by filename.
    "UVR-MDX-NET-Inst_1.onnx",
    "UVR-MDX-NET-Inst_2.onnx",
    "UVR-MDX-NET-Inst_3.onnx",
    "UVR-MDX-NET-Inst_HQ_1.onnx",
    "UVR-MDX-NET-Inst_HQ_2.onnx",
    "UVR-MDX-NET-Inst_HQ_3.onnx",
    "UVR-MDX-NET-Inst_HQ_4.onnx",
    "UVR-MDX-NET-Inst_HQ_5.onnx",
    "UVR-MDX-NET-Inst_Main.onnx",
    "UVR-MDX-NET-Voc_FT.onnx",
    "UVR-MDX-NET_Crowd_HQ_1.onnx",
    "UVR_MDXNET_1_9703.onnx",
    "UVR_MDXNET_2_9682.onnx",
    "UVR_MDXNET_3_9662.onnx",
    "UVR_MDXNET_9482.onnx",
    "UVR_MDXNET_KARA.onnx",
    "UVR_MDXNET_KARA_2.onnx",
    "UVR_MDXNET_Main.onnx",
]