csukuangfj commited on
Commit
33085cc
·
1 Parent(s): 5856dbb

minor fixes

Browse files
Files changed (2) hide show
  1. app.py +19 -7
  2. separate.py +27 -21
app.py CHANGED
@@ -107,11 +107,11 @@ def process(model_name, in_filename: str):
107
  logging.info(f"model_name: {model_name}")
108
  logging.info(f"in_filename: {in_filename}")
109
 
110
- waveform = load_audio(in_filename)
111
- waveform = np.transpose(waveform)
112
- waveform = np.ascontiguousarray(waveform)
113
 
114
- duration = waveform.shape[1] / 44100 # in seconds
115
 
116
  sp = load_model(model_name)
117
 
@@ -121,7 +121,7 @@ def process(model_name, in_filename: str):
121
 
122
  start = time.time()
123
 
124
- output = sp.process(sample_rate=44100, samples=waveform)
125
 
126
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
127
  end = time.time()
@@ -154,6 +154,17 @@ def process(model_name, in_filename: str):
154
 
155
 
156
  title = "# Source separation with Next-gen Kaldi"
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # css style is copied from
159
  # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
@@ -172,9 +183,9 @@ with demo:
172
  gr.Markdown(title)
173
 
174
  model_dropdown = gr.Dropdown(
175
- choices=model_list[model_list[0]],
176
  label="Select a model",
177
- value=model_list[model_list[0]],
178
  )
179
 
180
  with gr.Tabs():
@@ -259,6 +270,7 @@ with demo:
259
  inputs=[model_dropdown, url_textbox],
260
  outputs=[url_vocals, url_non_vocals, url_html_info],
261
  )
 
262
 
263
  if __name__ == "__main__":
264
  formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
107
  logging.info(f"model_name: {model_name}")
108
  logging.info(f"in_filename: {in_filename}")
109
 
110
+ samples, sample_rate = load_audio(in_filename)
111
+ samples = np.transpose(samples)
112
+ samples = np.ascontiguousarray(samples)
113
 
114
+ duration = samples.shape[1] / sample_rate # in seconds
115
 
116
  sp = load_model(model_name)
117
 
 
121
 
122
  start = time.time()
123
 
124
+ output = sp.process(sample_rate=sample_rate, samples=samples)
125
 
126
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
127
  end = time.time()
 
154
 
155
 
156
  title = "# Source separation with Next-gen Kaldi"
157
+ description = """
158
+ This space shows how to do source separation with Next-gen Kaldi.
159
+
160
+ It is running on CPU within a docker container provided by Hugging Face.
161
+
162
+ See more information by visiting the following links:
163
+
164
+ - <https://github.com/k2-fsa/sherpa-onnx>
165
+
166
+ Everything is open-sourced.
167
+ """
168
 
169
  # css style is copied from
170
  # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
 
183
  gr.Markdown(title)
184
 
185
  model_dropdown = gr.Dropdown(
186
+ choices=model_list[0],
187
  label="Select a model",
188
+ value=model_list[0],
189
  )
190
 
191
  with gr.Tabs():
 
270
  inputs=[model_dropdown, url_textbox],
271
  outputs=[url_vocals, url_non_vocals, url_html_info],
272
  )
273
+ gr.Markdown(description)
274
 
275
  if __name__ == "__main__":
276
  formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
separate.py CHANGED
@@ -1,39 +1,45 @@
1
  #!/usr/bin/env python3
2
  # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
3
 
 
 
4
  from functools import lru_cache
5
 
6
- import ffmpeg
7
  import numpy as np
8
- from huggingface_hub import hf_hub_download
9
  import sherpa_onnx
 
 
 
10
 
11
 
12
- sample_rate = 44100
 
 
 
13
 
 
 
 
 
14
 
15
- def load_audio(filename):
16
- probe = ffmpeg.probe(filename)
17
- if "streams" not in probe or len(probe["streams"]) == 0:
18
- raise ValueError("No stream was found with ffprobe")
19
 
20
- metadata = next(
21
- stream for stream in probe["streams"] if stream["codec_type"] == "audio"
22
- )
23
- n_channels = metadata["channels"]
24
 
25
- process = (
26
- ffmpeg.input(filename)
27
- .output("pipe:", format="f32le", ar=sample_rate)
28
- .run_async(pipe_stdout=True, pipe_stderr=True)
29
- )
30
- buffer, _ = process.communicate()
31
- waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
 
 
32
 
33
- if n_channels > 2:
34
- waveform = waveform[:, :2]
 
35
 
36
- return waveform
37
 
38
 
39
  @lru_cache(maxsize=10)
 
1
  #!/usr/bin/env python3
2
  # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
3
 
4
+ import logging
5
+ import os
6
  from functools import lru_cache
7
 
 
8
  import numpy as np
 
9
  import sherpa_onnx
10
+ import soundfile as sf
11
+ from huggingface_hub import hf_hub_download
12
+ import uuid
13
 
14
 
15
def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a 44.1 kHz, 2-channel wave file.

    Args:
      in_filename:
        Path to the input audio file. Any format understood by ffmpeg works.

    Returns:
      Path of the generated wave file, which is the input path with
      ``.wav`` appended.
    """
    # Fix: removed a dead assignment (`out_filename = str(uuid.uuid4())`)
    # that was immediately overwritten by the line below.
    out_filename = f"{in_filename}.wav"

    logging.info(f"Converting '{in_filename}' to '{out_filename}'")

    # Fix: the command previously used "-ar 441000" (441 kHz), an obvious
    # typo for the intended 44100 Hz sample rate.
    # NOTE(review): a filename containing a single quote would break this
    # shell command; consider subprocess.run with a list of arguments.
    _ = os.system(
        f"ffmpeg -hide_banner -loglevel error -i '{in_filename}' -ar 44100 -ac 2 '{out_filename}' -y"
    )

    return out_filename
 
 
 
26
 
 
 
 
 
27
 
28
def load_audio(filename):
    """Load an audio file and return its samples with the sample rate.

    The file is first converted to a wave file (see convert_to_wav);
    the samples are returned as a float32 array of shape
    (num_channels, num_samples).
    """
    wav_filename = convert_to_wav(filename)

    data, rate = sf.read(wav_filename, dtype="float32", always_2d=True)

    # sf.read yields (num_samples, num_channels); transpose so that
    # channels come first.
    data = np.transpose(data)

    assert (
        data.shape[1] > data.shape[0]
    ), f"You should use (num_channels, num_samples). {data.shape}"

    assert (
        data.dtype == np.float32
    ), f"Expect np.float32 as dtype. Given: {data.dtype}"

    return data, rate
43
 
44
 
45
  @lru_cache(maxsize=10)