File size: 1,825 Bytes
b23251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import urllib.request
import logging

# Send INFO-and-above records through the root logger using a timestamped
# format. Note: logging.basicConfig is a no-op if the root logger already has
# handlers, so an application that configured logging first keeps its setup.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Remote download locations, keyed by model name, used by
# download_model_if_needed() when the weights file is not present locally.
MODEL_URLS = {
    # isl-org/MiDaS GitHub release assets
    "dpt_hybrid_384": "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt",
    "midas_v21_small_256": "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt",
    # ultralytics/yolov5 v7.0 GitHub release assets
    "yolov5n-seg": "https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-seg.pt",
    "yolov5s-seg": "https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s-seg.pt",
}


def download_model_if_needed(model_key: str, save_path: str) -> None:
    """
    Download a model file listed in MODEL_URLS unless it already exists locally.

    The data is first written to a temporary ``<save_path>.<pid>.part`` file
    next to the destination and moved onto ``save_path`` only after the
    download has completed. An interrupted or failed download therefore never
    leaves a truncated file at ``save_path``. Such a file would otherwise be
    mistaken for a finished model by the existence check on later calls.

    Args:
        model_key (str): The key representing the model in MODEL_URLS.
        save_path (str): The local path where the model should be saved.
            Missing parent directories are created. A bare file name (no
            directory component) is saved in the current working directory.

    Raises:
        ValueError: If the model_key does not exist in MODEL_URLS.
        Exception: Any error raised while creating directories, downloading
            or moving the file into place is logged and re-raised unchanged.
    """
    url = MODEL_URLS.get(model_key)

    if not url:
        logger.error("Model key '%s' is not defined in MODEL_URLS.", model_key)
        raise ValueError(f"No URL configured for model key: {model_key}")

    if os.path.exists(save_path):
        logger.info("Model '%s' already exists at '%s'. Skipping download.", model_key, save_path)
        return

    # os.path.dirname("model.pt") == "" and os.makedirs("") raises, so only
    # create directories when the path actually contains one.
    target_dir = os.path.dirname(save_path)
    # Same directory as the target, so os.replace below is a same-filesystem
    # rename; the pid keeps concurrent processes from sharing a temp file.
    tmp_path = f"{save_path}.{os.getpid()}.part"

    try:
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        logger.info("Downloading '%s' from '%s' to '%s'", model_key, url, save_path)
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, save_path)
        logger.info("Successfully downloaded '%s' to '%s'", model_key, save_path)
    except Exception as e:
        logger.error("Failed to download '%s': %s", model_key, e)
        raise
    finally:
        # Best-effort removal of a partial download. After a successful
        # os.replace the temp file no longer exists, so this is a no-op.
        try:
            os.remove(tmp_path)
        except OSError:
            pass