File size: 3,076 Bytes
8601497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import numpy as np
from PIL import Image
import logging
from utils.model_downloader import download_model_if_needed

# Configure Logger
# Module-level logger named after this module; used by DepthEstimator below.
logger = logging.getLogger(__name__)
# NOTE(review): this runs at import time and configures the *root* logger
# (INFO level, "time - level - message" format). logging.basicConfig() does
# nothing if the root logger already has handlers, so an application that
# configures logging before importing this module keeps its own settings.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class DepthEstimator:
    """
    Generalized Depth Estimation Model Wrapper for MiDaS and DPT models.
    Supports: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, DPT Large.

    Both the network and its matching preprocessing pipeline are obtained
    through ``torch.hub`` from the ``intel-isl/MiDaS`` repository.
    """

    def __init__(self, model_key="midas_v21_small_256", weights_dir="models/depth/weights", device="cpu"):
        """
        Initialize the Depth Estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
                Unrecognised keys fall back to the "MiDaS_small" hub model.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        # NOTE(review): the file at weights_path is only downloaded here; the
        # network below is built by torch.hub, which manages its own weights.
        # Confirm whether the local checkpoint is meant to be loaded instead.
        download_model_if_needed(model_key, weights_path)

        logger.info("Loading Depth model '%s' from MiDaS hub", model_key)
        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        self.midas = torch.hub.load("intel-isl/MiDaS", self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """
        Maps model_key to MiDaS hub model type.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.

        Returns:
            str: Name of the torch.hub entry point to load. Unknown keys map
            to "MiDaS_small" (a warning is logged in that case).
        """
        mapping = {
            "midas_v21_small_256": "MiDaS_small",
            "midas_v21_384": "MiDaS",
            "dpt_hybrid_384": "DPT_Hybrid",
            "dpt_large_384": "DPT_Large",
            "dpt_swin2_large_384": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
            "dpt_beit_large_512": "DPT_Large",   # fallback to DPT_Large if not explicitly supported
        }
        fallback = "MiDaS_small"
        if model_key not in mapping:
            # Keep the lenient fallback, but make the substitution visible so a
            # typo in the key does not silently select the smallest model.
            logger.warning(
                "Unknown depth model key '%s'; falling back to '%s'", model_key, fallback
            )
        return mapping.get(model_key, fallback)

    def _resolve_transform(self):
        """
        Returns the correct transformation pipeline based on model type.

        The MiDaS hub ships one preprocessing pipeline per model family:
        ``small_transform`` for MiDaS_small, ``dpt_transform`` for the DPT
        models and ``default_transform`` for the original MiDaS network.

        Returns:
            Callable: Transform mapping an RGB numpy image to a batched
            input tensor for ``self.midas``.
        """
        transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        if self.model_type == "MiDaS_small":
            return transforms.small_transform
        if self.model_type in ("DPT_Large", "DPT_Hybrid"):
            # DPT models use their own resize/normalization settings; feeding
            # them default_transform output degrades the prediction.
            return transforms.dpt_transform
        return transforms.default_transform

    def predict(self, image: Image.Image, **kwargs):
        """
        Generates a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image. It is converted to RGB
                before being passed to the preprocessing pipeline.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            np.ndarray: Depth map as a 2D numpy array of shape
            (image height, image width). Values are the network output
            resized to the input resolution; no normalization is applied.
        """
        logger.info("Running depth estimation")
        # The hub transforms start with `img / 255.0`, so they need an
        # H x W x 3 RGB numpy array; a PIL image does not support that.
        rgb_array = np.asarray(image.convert("RGB"))
        input_tensor = self.transform(rgb_array).to(self.device)

        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # PIL reports size as (width, height); interpolate wants (H, W).
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            # Drop only the batch and channel axes. A bare squeeze() would
            # also collapse a spatial axis for images that are 1 pixel
            # high or wide.
            prediction = prediction[0, 0]

        depth_map = prediction.cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map