import os
import logging

import numpy as np
import torch
from PIL import Image

from utils.model_downloader import download_model_if_needed

# Configure module-level logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class DepthEstimator:
    """
    Generalized depth estimation model wrapper for MiDaS and DPT models.

    Supports: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, and DPT Large.
    """

    def __init__(self, model_key="midas_v21_small_256", weights_dir="models/depth/weights", device="cpu"):
        """
        Initialize the depth estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        # Ensure the checkpoint is cached locally; the torch.hub call below
        # loads the architecture (and fetches weights on first use) from the MiDaS repo.
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        download_model_if_needed(model_key, weights_path)

        logger.info(f"Loading depth model '{model_key}' from the MiDaS hub")
        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        self.midas = torch.hub.load("intel-isl/MiDaS", self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """
        Maps a model_key to the corresponding MiDaS hub model type.
        """
        mapping = {
            "midas_v21_small_256": "MiDaS_small",
            "midas_v21_384": "MiDaS",
            "dpt_hybrid_384": "DPT_Hybrid",
            "dpt_large_384": "DPT_Large",
            "dpt_swin2_large_384": "DPT_Large",  # fall back to DPT_Large if not explicitly supported
            "dpt_beit_large_512": "DPT_Large",   # fall back to DPT_Large if not explicitly supported
        }
        return mapping.get(model_key, "MiDaS_small")

    def _resolve_transform(self):
        """
        Returns the transformation pipeline matching the loaded model type.
        """
        transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        if self.model_type == "MiDaS_small":
            return transforms.small_transform
        elif self.model_type in ("DPT_Hybrid", "DPT_Large"):
            # DPT models use their own normalization; default_transform is for MiDaS v2.1.
            return transforms.dpt_transform
        else:
            return transforms.default_transform

    def predict(self, image: Image.Image, **kwargs):
        """
        Generates a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            np.ndarray: Depth map as a 2D numpy array (relative inverse depth).
        """
        logger.info("Running depth estimation")
        # The MiDaS transforms expect an RGB numpy array, not a PIL image.
        img_np = np.array(image.convert("RGB"))
        input_tensor = self.transform(img_np).to(self.device)
        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # Upsample the (B, H, W) prediction back to the original resolution.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.size[::-1],  # PIL size is (W, H); interpolate expects (H, W)
                mode="bicubic",
                align_corners=False,
            ).squeeze()
        depth_map = prediction.cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map
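

# A minimal usage sketch, not part of the original module: it assumes a local
# RGB image at "example.jpg" (hypothetical path) and shows one common way to
# normalize the relative depth map for saving as an 8-bit grayscale image.
if __name__ == "__main__":
    estimator = DepthEstimator(model_key="midas_v21_small_256", device="cpu")
    img = Image.open("example.jpg").convert("RGB")
    depth = estimator.predict(img)

    # Scale relative depth to [0, 255]; the epsilon guards against a flat map.
    depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    Image.fromarray((depth_norm * 255).astype(np.uint8)).save("depth_map.png")
    logger.info(f"Saved normalized depth map with shape {depth.shape}")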