masszhou commited on
Commit
b22128f
·
1 Parent(s): 62b370a

fixed client

Browse files
Files changed (3) hide show
  1. app.py +42 -14
  2. requirements.txt +4 -0
  3. utils.py +151 -0
app.py CHANGED
@@ -2,6 +2,18 @@ import gradio as gr
2
  import shutil
3
  import numpy as np
4
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def inference(audio_file):
@@ -10,6 +22,13 @@ def inference(audio_file):
10
  output_path1 = "downloaded_audio_1.wav"
11
  output_path2 = "downloaded_audio_2.wav"
12
 
 
 
 
 
 
 
 
13
  shutil.copy(audio_file, output_path1)
14
  shutil.copy(audio_file, output_path2)
15
 
@@ -21,6 +40,7 @@ def get_gui(theme, title, description):
21
  # Add title and description
22
  gr.Markdown(title)
23
  gr.Markdown(description)
 
24
 
25
  audio_input = gr.Audio(label="Audio file", type="filepath") # type: str | Path | bytes | tuple[int, np.ndarray] | None
26
  download_button = gr.Button("Inference")
@@ -39,24 +59,32 @@ if __name__ == "__main__":
39
  title = "<center><strong><font size='7'>Vocal BGM Separator</font></strong></center>"
40
  description = "This demo uses the MDX-Net models to perform Ultimate Vocal Remover (uvr) task for vocal and background sound separation."
41
  theme = "NoCrypt/miku"
42
- # app_gui = get_gui(theme, title, description)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  app_api = gr.Interface(
44
  fn=inference,
45
  inputs=gr.Audio(type="filepath"), # 接收文件路径(也可以换成 type="file")
46
  outputs=gr.File(file_count="multiple"), # 返回多个文件
47
  )
48
 
49
- # app = gr.TabbedInterface(
50
- # interface_list=[app_api, app_gui],
51
- # tab_names=["GUI", "API"]
52
- # )
53
-
54
- app_api.queue(default_concurrency_limit=40)
55
 
56
- app_api.launch(
57
- max_threads=40,
58
- share=False,
59
- show_error=True,
60
- quiet=False,
61
- debug=False,
62
- )
 
2
  import shutil
3
  import numpy as np
4
  from pathlib import Path
5
+ import os
6
+ from utils import get_hash
7
+ import time
8
+ import torch
9
+
10
+
11
+ def get_device_info():
12
+ if torch.cuda.is_available():
13
+ device = f"GPU ({torch.cuda.get_device_name(0)})"
14
+ else:
15
+ device = "CPU"
16
+ return f"当前运行环境: {device}"
17
 
18
 
19
  def inference(audio_file):
 
22
  output_path1 = "downloaded_audio_1.wav"
23
  output_path2 = "downloaded_audio_2.wav"
24
 
25
+ hash_audio = str(get_hash(audio_file))
26
+ media_dir = os.path.dirname(audio_file)
27
+
28
+ outputs = []
29
+
30
+ start_time = time.time()
31
+
32
  shutil.copy(audio_file, output_path1)
33
  shutil.copy(audio_file, output_path2)
34
 
 
40
  # Add title and description
41
  gr.Markdown(title)
42
  gr.Markdown(description)
43
+ gr.Markdown(get_device_info())
44
 
45
  audio_input = gr.Audio(label="Audio file", type="filepath") # type: str | Path | bytes | tuple[int, np.ndarray] | None
46
  download_button = gr.Button("Inference")
 
59
  title = "<center><strong><font size='7'>Vocal BGM Separator</font></strong></center>"
60
  description = "This demo uses the MDX-Net models to perform Ultimate Vocal Remover (uvr) task for vocal and background sound separation."
61
  theme = "NoCrypt/miku"
62
+
63
+ BASE_DIR = "." # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
64
+ mdxnet_models_dir = os.path.join(BASE_DIR, "mdx_models")
65
+ output_dir = os.path.join(BASE_DIR, "output")
66
+
67
+ # confirm entry points from client
68
+ # client_local = Client("http://127.0.0.1:7860")
69
+ # client = Client(f"{HF_USERNAME}/{HF_SPACENAME}", hf_token=HF_TOKEN)
70
+ # client_local.view_api()
71
+
72
+ # entry point for GUI
73
+ # predict(audio_file, api_name="/inference") -> result
74
+ app_gui = get_gui(theme, title, description)
75
+
76
+ # entry point for API
77
+ # predict(audio_file, api_name="/predict") -> output
78
  app_api = gr.Interface(
79
  fn=inference,
80
  inputs=gr.Audio(type="filepath"), # 接收文件路径(也可以换成 type="file")
81
  outputs=gr.File(file_count="multiple"), # 返回多个文件
82
  )
83
 
84
+ app = gr.TabbedInterface(
85
+ interface_list=[app_gui, app_api],
86
+ tab_names=["GUI", "API"]
87
+ )
 
 
88
 
89
+ app.queue(default_concurrency_limit=40)
90
+ app.launch()
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  gradio
2
  torch
3
  torchaudio
 
 
 
 
 
1
  gradio
2
  torch
3
  torchaudio
4
+ librosa
5
+ onnxruntime
6
+ numpy
7
+ tqdm
utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, zipfile, shutil, subprocess, shlex, sys # noqa
2
+ from urllib.parse import urlparse
3
+ import re
4
+ import logging
5
+ import hashlib
6
+
7
+
8
+ def load_file_from_url(
9
+ url: str,
10
+ model_dir: str,
11
+ file_name: str | None = None,
12
+ overwrite: bool = False,
13
+ progress: bool = True,
14
+ ) -> str:
15
+ """Download a file from `url` into `model_dir`,
16
+ using the file present if possible.
17
+
18
+ Returns the path to the downloaded file.
19
+ """
20
+ os.makedirs(model_dir, exist_ok=True)
21
+ if not file_name:
22
+ parts = urlparse(url)
23
+ file_name = os.path.basename(parts.path)
24
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
25
+
26
+ # Overwrite
27
+ if os.path.exists(cached_file):
28
+ if overwrite or os.path.getsize(cached_file) == 0:
29
+ remove_files(cached_file)
30
+
31
+ # Download
32
+ if not os.path.exists(cached_file):
33
+ logger.info(f'Downloading: "{url}" to {cached_file}\n')
34
+ from torch.hub import download_url_to_file
35
+
36
+ download_url_to_file(url, cached_file, progress=progress)
37
+ else:
38
+ logger.debug(cached_file)
39
+
40
+ return cached_file
41
+
42
+
43
+ def friendly_name(file: str):
44
+ if file.startswith("http"):
45
+ file = urlparse(file).path
46
+
47
+ file = os.path.basename(file)
48
+ model_name, extension = os.path.splitext(file)
49
+ return model_name, extension
50
+
51
+
52
+ def download_manager(
53
+ url: str,
54
+ path: str,
55
+ extension: str = "",
56
+ overwrite: bool = False,
57
+ progress: bool = True,
58
+ ):
59
+ url = url.strip()
60
+
61
+ name, ext = friendly_name(url)
62
+ name += ext if not extension else f".{extension}"
63
+
64
+ if url.startswith("http"):
65
+ filename = load_file_from_url(
66
+ url=url,
67
+ model_dir=path,
68
+ file_name=name,
69
+ overwrite=overwrite,
70
+ progress=progress,
71
+ )
72
+ else:
73
+ filename = path
74
+
75
+ return filename
76
+
77
+
78
+ def remove_files(file_list):
79
+ if isinstance(file_list, str):
80
+ file_list = [file_list]
81
+
82
+ for file in file_list:
83
+ if os.path.exists(file):
84
+ os.remove(file)
85
+
86
+
87
+ def remove_directory_contents(directory_path):
88
+ """
89
+ Removes all files and subdirectories within a directory.
90
+
91
+ Parameters:
92
+ directory_path (str): Path to the directory whose
93
+ contents need to be removed.
94
+ """
95
+ if os.path.exists(directory_path):
96
+ for filename in os.listdir(directory_path):
97
+ file_path = os.path.join(directory_path, filename)
98
+ try:
99
+ if os.path.isfile(file_path):
100
+ os.remove(file_path)
101
+ elif os.path.isdir(file_path):
102
+ shutil.rmtree(file_path)
103
+ except Exception as e:
104
+ logger.error(f"Failed to delete {file_path}. Reason: {e}")
105
+ logger.info(f"Content in '{directory_path}' removed.")
106
+ else:
107
+ logger.error(f"Directory '{directory_path}' does not exist.")
108
+
109
+
110
+ # Create directory if not exists
111
+ def create_directories(directory_path):
112
+ if isinstance(directory_path, str):
113
+ directory_path = [directory_path]
114
+ for one_dir_path in directory_path:
115
+ if not os.path.exists(one_dir_path):
116
+ os.makedirs(one_dir_path)
117
+ logger.debug(f"Directory '{one_dir_path}' created.")
118
+
119
+
120
+ def setup_logger(name_log):
121
+ logger = logging.getLogger(name_log)
122
+ logger.setLevel(logging.INFO)
123
+
124
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
125
+ _default_handler.flush = sys.stderr.flush
126
+ logger.addHandler(_default_handler)
127
+
128
+ logger.propagate = False
129
+
130
+ handlers = logger.handlers
131
+
132
+ for handler in handlers:
133
+ formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
134
+ handler.setFormatter(formatter)
135
+
136
+ # logger.handlers
137
+
138
+ return logger
139
+
140
+
141
+ logger = setup_logger("ss")
142
+ logger.setLevel(logging.INFO)
143
+
144
+
145
+ def get_hash(filepath):
146
+ with open(filepath, 'rb') as f:
147
+ file_hash = hashlib.blake2b()
148
+ while chunk := f.read(8192):
149
+ file_hash.update(chunk)
150
+
151
+ return file_hash.hexdigest()[:18]