DurgaDeepak committed on
Commit
8601497
·
verified ·
1 Parent(s): 08d991a

Update models/depth/depth_estimator.py

Browse files
Files changed (1) hide show
  1. models/depth/depth_estimator.py +85 -85
models/depth/depth_estimator.py CHANGED
@@ -1,85 +1,85 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- from PIL import Image
5
- import logging
6
- from utils.model_downloader import download_model_if_needed
7
-
8
- # Create the module-level logger and configure root logging (INFO level, timestamped format) at import time
9
- logger = logging.getLogger(__name__)
10
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
11
-
12
-
13
class DepthEstimator:
    """
    Generalized Depth Estimation Model Wrapper for MiDaS and DPT models.
    Supports: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, DPT Large.

    The network and its preprocessing pipeline are both loaded through
    ``torch.hub`` from the ``intel-isl/MiDaS`` repository.
    """

    # torch.hub repository that provides both the models and the transforms.
    _HUB_REPO = "intel-isl/MiDaS"

    # Hub entry point used when a model_key is not present in the table below.
    _DEFAULT_MODEL_TYPE = "MiDaS_small"

    # model_key (as used by model_downloader.py) -> MiDaS hub entry point name.
    _HUB_MODEL_TYPES = {
        "midas_v21_small_256": "MiDaS_small",
        "midas_v21_384": "MiDaS",
        "dpt_hybrid_384": "DPT_Hybrid",
        "dpt_large_384": "DPT_Large",
        "dpt_swin2_large_384": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
        "dpt_beit_large_512": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
    }

    def __init__(self, model_key="midas_v21_small_256", weights_dir="models/depth/weights", device="cpu"):
        """
        Initialize the Depth Estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        # NOTE(review): the file at weights_path is fetched here, but the model
        # below is loaded via torch.hub and this path is never handed to it.
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        download_model_if_needed(model_key, weights_path)

        logger.info("Loading Depth model '%s' from MiDaS hub", model_key)
        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        # eval() puts the network in inference mode.
        self.midas = torch.hub.load(self._HUB_REPO, self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """
        Maps model_key to MiDaS hub model type.

        Unknown keys fall back to "MiDaS_small"; a warning is logged so the
        substitution is visible instead of silent.
        """
        model_type = self._HUB_MODEL_TYPES.get(model_key)
        if model_type is None:
            logger.warning(
                "Unknown depth model key '%s'; falling back to '%s'",
                model_key,
                self._DEFAULT_MODEL_TYPE,
            )
            return self._DEFAULT_MODEL_TYPE
        return model_type

    @staticmethod
    def _transform_attr(model_type):
        """
        Returns the attribute name of the hub transform matching model_type.
        """
        if model_type == "MiDaS_small":
            return "small_transform"
        if model_type in ("DPT_Large", "DPT_Hybrid"):
            # DPT models use their own preprocessing in the MiDaS hub;
            # default_transform is the pipeline for the MiDaS v2.1 model.
            return "dpt_transform"
        return "default_transform"

    def _resolve_transform(self):
        """
        Returns the correct transformation pipeline based on model type.
        """
        transforms = torch.hub.load(self._HUB_REPO, "transforms")
        return getattr(transforms, self._transform_attr(self.model_type))

    def predict(self, image: "Image.Image"):
        """
        Generates a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image (any mode; converted to RGB).

        Returns:
            np.ndarray: Depth map as a 2D numpy array of shape (height, width).
        """
        logger.info("Running depth estimation")

        # The hub transforms start with `img / 255.0`, which a PIL image does
        # not support, so hand them an RGB array instead.
        rgb = np.asarray(image.convert("RGB"))
        height, width = rgb.shape[:2]
        input_tensor = self.transform(rgb).to(self.device)

        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # Resize the prediction back to the input resolution.
            resized = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(height, width),
                mode="bicubic",
                align_corners=False,
            )
            # Index batch and channel explicitly: a bare squeeze() would also
            # drop height/width when either of them is 1.
            prediction = resized[0, 0]

        depth_map = prediction.cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import logging
6
+ from utils.model_downloader import download_model_if_needed
7
+
8
+ # Create the module-level logger and configure root logging (INFO level, timestamped format) at import time
9
+ logger = logging.getLogger(__name__)
10
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
11
+
12
+
13
class DepthEstimator:
    """
    Generalized Depth Estimation Model Wrapper for MiDaS and DPT models.
    Supports: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, DPT Large.

    Both the network and its preprocessing transform come from the
    ``intel-isl/MiDaS`` repository via ``torch.hub``.
    """

    # torch.hub repository serving the models and their transforms.
    _HUB_REPO = "intel-isl/MiDaS"

    # Hub entry point substituted for model keys missing from the table below.
    _DEFAULT_MODEL_TYPE = "MiDaS_small"

    # model_key (as used by model_downloader.py) -> MiDaS hub entry point name.
    _HUB_MODEL_TYPES = {
        "midas_v21_small_256": "MiDaS_small",
        "midas_v21_384": "MiDaS",
        "dpt_hybrid_384": "DPT_Hybrid",
        "dpt_large_384": "DPT_Large",
        "dpt_swin2_large_384": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
        "dpt_beit_large_512": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
    }

    def __init__(self, model_key="midas_v21_small_256", weights_dir="models/depth/weights", device="cpu"):
        """
        Initialize the Depth Estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        # NOTE(review): weights are downloaded to weights_path, yet the model is
        # loaded through torch.hub below without being given this path.
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        download_model_if_needed(model_key, weights_path)

        logger.info("Loading Depth model '%s' from MiDaS hub", model_key)
        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        # eval() switches the network to inference mode.
        self.midas = torch.hub.load(self._HUB_REPO, self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """
        Maps model_key to MiDaS hub model type.

        Keys that are not in the table resolve to "MiDaS_small"; the fallback
        is logged as a warning rather than applied silently.
        """
        model_type = self._HUB_MODEL_TYPES.get(model_key)
        if model_type is None:
            logger.warning(
                "Unknown depth model key '%s'; falling back to '%s'",
                model_key,
                self._DEFAULT_MODEL_TYPE,
            )
            return self._DEFAULT_MODEL_TYPE
        return model_type

    @staticmethod
    def _transform_attr(model_type):
        """
        Returns the attribute name of the hub transform matching model_type.
        """
        if model_type == "MiDaS_small":
            return "small_transform"
        if model_type in ("DPT_Large", "DPT_Hybrid"):
            # The MiDaS hub ships a dedicated transform for DPT models;
            # default_transform belongs to the MiDaS v2.1 model.
            return "dpt_transform"
        return "default_transform"

    def _resolve_transform(self):
        """
        Returns the correct transformation pipeline based on model type.
        """
        transforms = torch.hub.load(self._HUB_REPO, "transforms")
        return getattr(transforms, self._transform_attr(self.model_type))

    def predict(self, image: "Image.Image", **kwargs):
        """
        Generates a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image (any mode; converted to RGB).
            **kwargs: Accepted so callers may pass extra options; ignored here.

        Returns:
            np.ndarray: Depth map as a 2D numpy array of shape (height, width).
        """
        logger.info("Running depth estimation")

        # The hub transforms begin with `img / 255.0`, which is undefined for a
        # PIL image, so convert to an RGB array first.
        rgb = np.asarray(image.convert("RGB"))
        height, width = rgb.shape[:2]
        input_tensor = self.transform(rgb).to(self.device)

        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # Bring the prediction back to the input resolution.
            resized = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(height, width),
                mode="bicubic",
                align_corners=False,
            )
            # Select batch 0 / channel 0 explicitly; a bare squeeze() would
            # also remove a height or width of 1 and break the 2D contract.
            prediction = resized[0, 0]

        depth_map = prediction.cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map