# models/depth/depth_estimator.py
import logging
import os

import numpy as np
import torch
from PIL import Image

from utils.model_downloader import download_model_if_needed

# Configure module-level logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class DepthEstimator:
    """
    Generalized depth estimation wrapper for MiDaS and DPT models loaded
    from the intel-isl/MiDaS torch.hub repository.

    Supported keys: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, and
    DPT Large; the Swin2-Large and BEiT-Large keys fall back to DPT_Large.
    See the __main__ sketch at the bottom of this file for example usage.
    """

    def __init__(self, model_key="midas_v21_small_256", weights_dir="models/depth/weights", device="cpu"):
        """
        Initialize the depth estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        # Pre-fetch the checkpoint into the local weights directory; the model
        # itself is still instantiated through torch.hub below.
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        download_model_if_needed(model_key, weights_path)

        logger.info(f"Loading depth model '{model_key}' from the MiDaS hub")
        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        self.midas = torch.hub.load("intel-isl/MiDaS", self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """
        Maps a model_key to the corresponding MiDaS hub model type.
        """
        mapping = {
            "midas_v21_small_256": "MiDaS_small",
            "midas_v21_384": "MiDaS",
            "dpt_hybrid_384": "DPT_Hybrid",
            "dpt_large_384": "DPT_Large",
            "dpt_swin2_large_384": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
            "dpt_beit_large_512": "DPT_Large",   # fallback to DPT_Large if not explicitly supported
        }
        return mapping.get(model_key, "MiDaS_small")

    def _resolve_transform(self):
        """
        Returns the hub transform pipeline matching the model type.
        """
        transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        if self.model_type == "MiDaS_small":
            return transforms.small_transform
        elif self.model_type in ("DPT_Hybrid", "DPT_Large"):
            # DPT models expect the dedicated DPT transform, not the
            # MiDaS v2.1 default_transform.
            return transforms.dpt_transform
        else:
            return transforms.default_transform

    def predict(self, image: Image.Image, **kwargs):
        """
        Generates a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            np.ndarray: Depth map as a 2D numpy array (relative inverse depth).
        """
        logger.info("Running depth estimation")
        # The hub transforms expect an HxWx3 RGB numpy array, not a PIL image.
        input_tensor = self.transform(np.array(image.convert("RGB"))).to(self.device)

        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # Upsample the (N, H, W) prediction back to the input resolution;
            # PIL's .size is (width, height), so reverse it for interpolate.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        depth_map = prediction.cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map
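

# A minimal usage sketch, not part of the module's public API. It assumes an
# "example.jpg" file exists in the working directory and that
# utils.model_downloader can resolve the default model key; both are
# illustrative assumptions.
if __name__ == "__main__":
    estimator = DepthEstimator(model_key="midas_v21_small_256", device="cpu")
    img = Image.open("example.jpg")
    depth = estimator.predict(img)

    # Normalize the relative depth values to [0, 255] for a quick visual check.
    norm = (255 * (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)).astype(np.uint8)
    Image.fromarray(norm).save("example_depth.png")
    logger.info(f"Saved depth map with shape {depth.shape}")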