DurgaDeepak commited on
Commit
79acde6
·
verified ·
1 Parent(s): e999761

Update utils/model_downloader.py

Browse files
Files changed (1) hide show
  1. utils/model_downloader.py +67 -46
utils/model_downloader.py CHANGED
@@ -1,46 +1,67 @@
1
- import os
2
- import urllib.request
3
- import logging
4
-
5
- # Configure Logger
6
- logger = logging.getLogger(__name__)
7
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
8
-
9
- # Model URLs for downloading if not present locally
10
- MODEL_URLS = {
11
- "dpt_hybrid_384": "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt",
12
- "midas_v21_small_256": "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt",
13
- "yolov5n-seg": "https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-seg.pt",
14
- "yolov5s-seg": "https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s-seg.pt",
15
- }
16
-
17
-
18
- def download_model_if_needed(model_key: str, save_path: str):
19
- """
20
- Downloads a model file if it does not already exist.
21
-
22
- Args:
23
- model_key (str): The key representing the model in MODEL_URLS.
24
- save_path (str): The local path where the model should be saved.
25
-
26
- Raises:
27
- ValueError: If the model_key does not exist in MODEL_URLS.
28
- """
29
- url = MODEL_URLS.get(model_key)
30
-
31
- if not url:
32
- logger.error(f"Model key '{model_key}' is not defined in MODEL_URLS.")
33
- raise ValueError(f"No URL configured for model key: {model_key}")
34
-
35
- if os.path.exists(save_path):
36
- logger.info(f"Model '{model_key}' already exists at '{save_path}'. Skipping download.")
37
- return
38
-
39
- try:
40
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
41
- logger.info(f"Downloading '{model_key}' from '{url}' to '{save_path}'")
42
- urllib.request.urlretrieve(url, save_path)
43
- logger.info(f"Successfully downloaded '{model_key}' to '{save_path}'")
44
- except Exception as e:
45
- logger.error(f"Failed to download '{model_key}': {e}")
46
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ # Configure logger
6
+ logger = logging.getLogger(__name__)
7
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
8
+
9
+ # Supported Hugging Face models (key repo + file)
10
+ HF_MODELS = {
11
+ # Depth Estimation
12
+ "dpt_hybrid_384": ("isl-org/MiDaS", "dpt_hybrid_384.pt"),
13
+ "midas_v21_small_256": ("isl-org/MiDaS", "midas_v21_small_256.pt"),
14
+ "midas_v21_384": ("isl-org/MiDaS", "midas_v21_384.pt"),
15
+ "dpt_swin2_large_384": ("isl-org/MiDaS", "dpt_swin2_large_384.pt"),
16
+ "dpt_beit_large_512": ("isl-org/MiDaS", "dpt_beit_large_512.pt"),
17
+
18
+ # Object Detection
19
+ "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
20
+ "yolov5s": ("ultralytics/yolov5", "yolov5s.pt"),
21
+ "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
22
+ "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
23
+ "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
24
+ "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
25
+
26
+ # Semantic Segmentation
27
+ "segformer_b0": ("nvidia/segformer-b0-finetuned-ade-512-512", "model.safetensors"),
28
+ "segformer_b5": ("nvidia/segformer-b5-finetuned-ade-512-512", "model.safetensors"),
29
+ "deeplabv3_resnet50": ("facebook/deeplabv3-resnet50", "pytorch_model.bin"),
30
+ }
31
+
32
+
33
+ def download_model_if_needed(model_key: str, save_path: str):
34
+ """
35
+ Downloads the model from Hugging Face Hub if it's not already present.
36
+
37
+ Args:
38
+ model_key (str): Key from HF_MODELS dict.
39
+ save_path (str): Local path to store the downloaded model.
40
+
41
+ Raises:
42
+ ValueError: If the model_key is not supported.
43
+ """
44
+ if model_key not in HF_MODELS:
45
+ logger.error(f" Model key '{model_key}' not found in registry.")
46
+ raise ValueError(f"Unsupported model key: {model_key}")
47
+
48
+ repo_id, filename = HF_MODELS[model_key]
49
+
50
+ if os.path.exists(save_path):
51
+ logger.info(f"✅ Model '{model_key}' already exists at '{save_path}'. Skipping download.")
52
+ return
53
+
54
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
55
+ logger.info(f"⬇️ Downloading '{model_key}' from Hugging Face Hub...")
56
+
57
+ try:
58
+ hf_hub_download(
59
+ repo_id=repo_id,
60
+ filename=filename,
61
+ cache_dir=os.path.dirname(save_path),
62
+ force_download=True # Set to False later if you want to cache
63
+ )
64
+ logger.info(f"✅ Successfully downloaded '{model_key}' to '{save_path}'")
65
+ except Exception as e:
66
+ logger.error(f"❌ Download failed for '{model_key}': {e}")
67
+ raise