kemuriririn committed
Commit 506ecd3 · 1 Parent(s): 71f87cb

Download models

Files changed (2)
  1. tools/download_files.py +111 -0
  2. webui.py +3 -2
tools/download_files.py ADDED
@@ -0,0 +1,111 @@
+ import requests
+ import zipfile
+ import os
+ import argparse
+
+ def download_file_from_google_drive(file_id, destination):
+     """
+     Download a shared Google Drive file by its file ID.
+
+     Args:
+         file_id (str): ID of the Google Drive file
+         destination (str): local path to save the file to
+     """
+     # Base download URL
+     URL = "https://docs.google.com/uc?export=download"
+
+     session = requests.Session()
+
+     # Issue the initial GET request
+     response = session.get(URL, params={'id': file_id}, stream=True)
+     token = get_confirm_token(response)  # get the confirmation token from the response, if any
+
+     if token:  # confirmation required (large files)
+         params = {'id': file_id, 'confirm': token}
+         response = session.get(URL, params=params, stream=True)
+
+     # Write the response content to the destination file
+     save_response_content(response, destination)
+
+ def get_confirm_token(response):
+     """
+     Check the response cookies for a download confirmation token.
+
+     Args:
+         response (requests.Response): response object
+
+     Returns:
+         str: value of the confirmation token if present, otherwise None
+     """
+     for key, value in response.cookies.items():
+         if key.startswith('download_warning'):  # the confirmation cookie usually starts with this prefix
+             return value
+     return None
+
+ def save_response_content(response, destination, chunk_size=32768):
+     """
+     Stream the response content to a file, so that large downloads are supported.
+
+     Args:
+         response (requests.Response): streaming response object
+         destination (str): local path to save the file to
+         chunk_size (int, optional): chunk size written per iteration. Defaults to 32768.
+     """
+     with open(destination, "wb") as f:
+         for chunk in response.iter_content(chunk_size):
+             if chunk:  # filter out keep-alive chunks
+                 f.write(chunk)
+
+ def download_model_from_modelscope(destination, hf_cache_dir):
+     """
+     Download the models from ModelScope.
+
+     Args:
+         destination (str): local directory for the IndexTTS-2 checkpoints
+         hf_cache_dir (str): local cache directory for the auxiliary models
+     """
+     print(f"[ModelScope] Downloading models to {destination}, model cache dir={hf_cache_dir}")
+     from modelscope import snapshot_download
+     snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
+     snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir, "models--amphion--MaskGCT"))
+     snapshot_download("facebook/w2v-bert-2.0", local_dir=os.path.join(hf_cache_dir, "models--facebook--w2v-bert-2.0"))
+     snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x", local_dir=os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"))
+     # models--funasr--campplus
+
+ def download_model_from_huggingface(destination, hf_cache_dir):
+     """
+     Download the models from HuggingFace.
+
+     Args:
+         destination (str): local directory for the IndexTTS-2 checkpoints
+         hf_cache_dir (str): local cache directory for the auxiliary models
+     """
+     print(f"[HuggingFace] Downloading models to {destination}, model cache dir={hf_cache_dir}")
+     from huggingface_hub import snapshot_download
+     snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
+     snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir, "models--amphion--MaskGCT"))
+     snapshot_download("facebook/w2v-bert-2.0", local_dir=os.path.join(hf_cache_dir, "models--facebook--w2v-bert-2.0"))
+     snapshot_download("nvidia/bigvgan_v2_22khz_80band_256x", local_dir=os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"))
+     snapshot_download("funasr/campplus", local_dir=os.path.join(hf_cache_dir, "models--funasr--campplus"))
+
+ # Usage example
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="File and model download tool")
+     parser.add_argument('--model_source', choices=['modelscope', 'huggingface'], default=None, help='Model download source')
+     args = parser.parse_args()
+
+     if args.model_source == 'modelscope':
+         download_model_from_modelscope("checkpoints", os.path.join("checkpoints", "hf_cache"))
+     elif args.model_source == 'huggingface':
+         download_model_from_huggingface("checkpoints", os.path.join("checkpoints", "hf_cache"))
+
+     print("Downloading example files from Google Drive...")
+     file_id = "1o_dCMzwjaA2azbGOxAE7-4E7NbJkgdgO"
+     destination = "example_wavs.zip"  # change to your preferred local path
+     download_file_from_google_drive(file_id, destination)
+     print(f"File downloaded to: {destination}")
+     # Extract the downloaded zip into the examples directory
+     examples_dir = "examples"
+     with zipfile.ZipFile(destination, 'r') as zip_ref:
+         zip_ref.extractall(examples_dir)
+     print(f"File extracted to: {examples_dir}")
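For reference, a minimal usage sketch of the new helper once this commit is applied (run from the repository root; the repo IDs and paths simply mirror the defaults in the script):

    # Equivalent to: python tools/download_files.py --model_source huggingface
    import os
    from tools.download_files import (
        download_file_from_google_drive,
        download_model_from_huggingface,
    )

    # IndexTTS-2 checkpoints into checkpoints/, auxiliary models into checkpoints/hf_cache/
    download_model_from_huggingface("checkpoints", os.path.join("checkpoints", "hf_cache"))

    # Example WAV archive from Google Drive, as in the script's __main__ block
    download_file_from_google_drive("1o_dCMzwjaA2azbGOxAE7-4E7NbJkgdgO", "example_wavs.zip")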
webui.py CHANGED
@@ -24,8 +24,9 @@ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the
  parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
  parser.add_argument("--is_fp16", action="store_true", default=False, help="Fp16 infer")
  cmd_args = parser.parse_args()
- from huggingface_hub import snapshot_download
- snapshot_download(repo_id="IndexTeam/IndexTTS-2", local_dir="./checkpoints")
+
+ from tools.download_files import download_model_from_huggingface
+ download_model_from_huggingface("checkpoints", os.path.join(current_dir, "hf_cache"))

  if not os.path.exists(cmd_args.model_dir):
      print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
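webui.py only calls the HuggingFace path; the ModelScope variant remains available through the standalone script. A minimal sketch mirroring the defaults in its __main__ block:

    # Standalone ModelScope download (not invoked by webui.py)
    import os
    from tools.download_files import download_model_from_modelscope

    download_model_from_modelscope("checkpoints", os.path.join("checkpoints", "hf_cache"))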