import logging
import os
import shutil

from huggingface_hub import hf_hub_download

# Configure logger
logger = logging.getLogger(__name__)
# NOTE: this runs at import time and configures the root logger; it is a
# no-op when the root logger already has handlers installed.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Supported Hugging Face models (key → repo + file)
# NOTE(review): the repo ids and filenames below are taken as-is; they have
# not been verified against the Hub from this file — confirm before relying
# on a given entry.
HF_MODELS = {
    # Depth Estimation
    "dpt_hybrid_384": ("isl-org/MiDaS", "dpt_hybrid_384.pt"),
    "midas_v21_small_256": ("isl-org/MiDaS", "midas_v21_small_256.pt"),
    "midas_v21_384": ("isl-org/MiDaS", "midas_v21_384.pt"),
    "dpt_swin2_large_384": ("isl-org/MiDaS", "dpt_swin2_large_384.pt"),
    "dpt_beit_large_512": ("isl-org/MiDaS", "dpt_beit_large_512.pt"),
    # Object Detection
    "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
    "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
    "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
    "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    # Semantic Segmentation
    "segformer_b0": ("nvidia/segformer-b0-finetuned-ade-512-512", "model.safetensors"),
    "segformer_b5": ("nvidia/segformer-b5-finetuned-ade-512-512", "model.safetensors"),
    "deeplabv3_resnet50": ("facebook/deeplabv3-resnet50", "pytorch_model.bin"),
}


def download_model_if_needed(model_key: str, save_path: str) -> None:
    """
    Download a registered model from the Hugging Face Hub unless it is
    already present at ``save_path``.

    The file is fetched with ``hf_hub_download`` (using the directory of
    ``save_path`` as the Hub cache directory) and then copied to
    ``save_path`` itself, so that a later call finds the file and skips
    the download.

    Args:
        model_key (str): Key from HF_MODELS dict.
        save_path (str): Local file path to store the downloaded model.

    Raises:
        ValueError: If the model_key is not supported.
        Exception: Any error raised while downloading or copying the file
            is logged and re-raised unchanged.
    """
    if model_key not in HF_MODELS:
        logger.error(f" Model key '{model_key}' not found in registry.")
        raise ValueError(f"Unsupported model key: {model_key}")

    repo_id, filename = HF_MODELS[model_key]

    if os.path.exists(save_path):
        logger.info(f" Model '{model_key}' already exists at '{save_path}'. Skipping download.")
        return

    # os.path.dirname() returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError; fall back to the current directory.
    target_dir = os.path.dirname(save_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    logger.info(f" Downloading '{model_key}' from Hugging Face Hub...")
    try:
        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=target_dir,
            force_download=True,  # Set to False later if you want to cache
        )
        # hf_hub_download stores the file inside its own cache layout under
        # cache_dir and returns that location; it does not write save_path.
        # Copy via a temporary name and rename, so an interrupted copy never
        # leaves a truncated file that the exists() check above would accept.
        partial_path = f"{save_path}.part"
        shutil.copyfile(cached_path, partial_path)
        os.replace(partial_path, save_path)
    except Exception as e:
        logger.error(f" Download failed for '{model_key}': {e}")
        raise

    logger.info(f" Successfully downloaded '{model_key}' to '{save_path}'")