import logging
import os

import numpy as np
import torch
from PIL import Image

from utils.model_downloader import download_model_if_needed

# Configure logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class DepthEstimator:
    """
    Generalized depth estimation model wrapper for MiDaS and DPT models.

    Supports: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, DPT Large.
    """

    def __init__(self, model_key="midas_v21_small_256", weights_dir="models/depth/weights", device="cpu"):
        """
        Initialize the depth estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        # Ensure local weights are cached; the model itself is loaded via torch.hub below.
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        download_model_if_needed(model_key, weights_path)

        logger.info(f"Loading depth model '{model_key}' from the MiDaS hub")
        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        self.midas = torch.hub.load("intel-isl/MiDaS", self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """Map a model_key to the corresponding MiDaS hub model type."""
        mapping = {
            "midas_v21_small_256": "MiDaS_small",
            "midas_v21_384": "MiDaS",
            "dpt_hybrid_384": "DPT_Hybrid",
            "dpt_large_384": "DPT_Large",
            "dpt_swin2_large_384": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
            "dpt_beit_large_512": "DPT_Large",   # fallback to DPT_Large if not explicitly supported
        }
        return mapping.get(model_key, "MiDaS_small")

    def _resolve_transform(self):
        """Return the transformation pipeline matching the loaded model type."""
        transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        if self.model_type == "MiDaS_small":
            return transforms.small_transform
        elif self.model_type in ("DPT_Hybrid", "DPT_Large"):
            # DPT models use their own resize/normalization pipeline.
            return transforms.dpt_transform
        else:
            return transforms.default_transform

    def predict(self, image: Image.Image, **kwargs):
        """
        Generate a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            np.ndarray: Depth map as a 2D numpy array.
        """
        logger.info("Running depth estimation")

        # The MiDaS hub transforms expect an RGB numpy array, not a PIL image.
        img = np.array(image.convert("RGB"))
        input_tensor = self.transform(img).to(self.device)

        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # Resize the raw prediction back to the original resolution.
            # PIL's size is (width, height); interpolate expects (height, width).
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        depth_map = prediction.cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map
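

# Minimal usage sketch. Assumptions not in the original module: an input image
# at "example.jpg" and an output path "depth_map.png" (both hypothetical), and
# that CUDA availability should be probed at runtime. The normalization below
# is only for visualization; MiDaS depth values are relative, not metric.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    estimator = DepthEstimator(model_key="midas_v21_small_256", device=device)

    image = Image.open("example.jpg")  # hypothetical input path
    depth_map = estimator.predict(image)

    # Scale the relative depth values to [0, 255] and save as a grayscale PNG.
    normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    Image.fromarray((normalized * 255).astype(np.uint8)).save("depth_map.png")
    logger.info(f"Saved depth map with shape {depth_map.shape} to depth_map.png")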