# Source: UVIS / utils/model_downloader.py
# Page metadata: DurgaDeepak - "Update utils/model_downloader.py" (84fb7e3, verified)
import logging
import os
import shutil

from huggingface_hub import hf_hub_download
# Logging setup: request INFO-level root logging with a timestamped format via
# basicConfig, then create the logger used throughout this module.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Registry of downloadable models.
# Maps a short model key to a (Hugging Face repo id, filename in that repo) pair.
# NOTE(review): the repo ids / filenames are not validated anywhere in this
# module - confirm each entry actually exists on the Hub.
HF_MODELS = {
    # --- Depth estimation ---
    "dpt_hybrid_384": ("isl-org/MiDaS", "dpt_hybrid_384.pt"),
    "midas_v21_small_256": ("isl-org/MiDaS", "midas_v21_small_256.pt"),
    "midas_v21_384": ("isl-org/MiDaS", "midas_v21_384.pt"),
    "dpt_swin2_large_384": ("isl-org/MiDaS", "dpt_swin2_large_384.pt"),
    "dpt_beit_large_512": ("isl-org/MiDaS", "dpt_beit_large_512.pt"),
    # --- Object detection ---
    "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
    "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
    "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
    "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    # --- Semantic segmentation ---
    "segformer_b0": (
        "nvidia/segformer-b0-finetuned-ade-512-512",
        "model.safetensors",
    ),
    "segformer_b5": (
        "nvidia/segformer-b5-finetuned-ade-512-512",
        "model.safetensors",
    ),
    "deeplabv3_resnet50": ("facebook/deeplabv3-resnet50", "pytorch_model.bin"),
}
def download_model_if_needed(model_key: str, save_path: str) -> None:
    """
    Ensure the model file for ``model_key`` exists at ``save_path``.

    If ``save_path`` already exists nothing is downloaded. Otherwise the file
    is fetched from the Hugging Face Hub (using the directory of ``save_path``
    as the Hub cache directory) and then copied to ``save_path`` itself.

    Args:
        model_key (str): Key from HF_MODELS dict.
        save_path (str): Local path to store the downloaded model.

    Raises:
        ValueError: If the model_key is not supported.
        Exception: Any error raised while downloading or copying the file is
            logged and re-raised unchanged.
    """
    if model_key not in HF_MODELS:
        logger.error(" Model key '%s' not found in registry.", model_key)
        raise ValueError(f"Unsupported model key: {model_key}")

    repo_id, filename = HF_MODELS[model_key]

    if os.path.exists(save_path):
        logger.info(
            " Model '%s' already exists at '%s'. Skipping download.",
            model_key,
            save_path,
        )
        return

    # Resolve via abspath so a bare filename (empty dirname) maps to the
    # current directory instead of making os.makedirs("") fail.
    target_dir = os.path.dirname(os.path.abspath(save_path))
    os.makedirs(target_dir, exist_ok=True)

    logger.info(" Downloading '%s' from Hugging Face Hub...", model_key)
    partial_path = f"{save_path}.part"
    try:
        # hf_hub_download stores the file inside its own cache layout under
        # cache_dir and returns that location; it does NOT write to save_path.
        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=target_dir,
            force_download=True,  # Set to False later if you want to cache
        )
        # Copy to a temporary name first, then rename, so an interrupted copy
        # never leaves a truncated file that the exists-check would accept.
        shutil.copyfile(cached_path, partial_path)
        os.replace(partial_path, save_path)
    except Exception as e:
        logger.error(" Download failed for '%s': %s", model_key, e)
        # Best-effort cleanup of a half-written temporary file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    logger.info(" Successfully downloaded '%s' to '%s'", model_key, save_path)