"""EgoHackZero: depth-then-segmentation demo (adding a segmentation step on top of MiDaS depth)."""
import numpy as np
import torch
import cv2
from PIL import Image
from transformers import pipeline
import gradio as gr
# ===== Device Setup =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_index = 0 if torch.cuda.is_available() else -1  # transformers pipelines take an int device index (-1 = CPU)
# ===== MiDaS Depth Estimation Setup =====
# Load MiDaS model and transforms
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device).eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform
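# Optional swap (assumption, not in the original script): DPT_Large gives the
# best quality but is slow on CPU-only hosts. The MiDaS hub also ships a
# lighter model with its own transform, e.g.:
#
#   midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
#   midas.to(device).eval()
#   transform = midas_transforms.small_transform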
# ===== Segmentation Setup =====
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
)
# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
    """Downscale so the longer side is at most max_size, preserving aspect ratio."""
    width, height = img.size
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_size = (int(width * ratio), int(height * ratio))
        return img.resize(new_size, Image.LANCZOS)
    return img
# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    """Run MiDaS on an image and return the depth map as a grayscale PIL image."""
    # Ensure input is an RGB PIL Image (np.ndarray inputs are wrapped first)
    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    img = img.convert("RGB")
    img_np = np.array(img)

    # Convert to the format expected by MiDaS (the DPT transform already
    # returns a batched tensor, so the unsqueeze is only a fallback)
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor

    # Predict depth and resize back to the original resolution
    with torch.no_grad():
        prediction = midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    # Normalize to 0-255 (epsilon guards against a constant depth map)
    depth_map = prediction.cpu().numpy()
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    depth_map = (depth_map * 255).astype(np.uint8)
    return Image.fromarray(depth_map)
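# Optional helper (an assumption, not part of the original pipeline): a
# color-mapped depth image is often easier to inspect than raw grayscale.
# This sketch reuses the already-imported cv2; applyColorMap returns BGR,
# so the result is converted back to RGB before wrapping it in PIL.
def colorize_depth(depth_img: Image.Image) -> Image.Image:
    depth_u8 = np.array(depth_img.convert("L"), dtype=np.uint8)
    colored_bgr = cv2.applyColorMap(depth_u8, cv2.COLORMAP_JET)
    return Image.fromarray(cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB))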
# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    """Run SegFormer and blend a random color over each predicted mask."""
    img = img.convert('RGB')
    img_resized = resize_image(img)
    results = segmenter(img_resized)
    overlay = np.array(img_resized, dtype=np.uint8)
    for res in results:
        mask = np.array(res["mask"], dtype=bool)  # pipeline masks are binary PIL images
        color = np.random.randint(50, 255, 3, dtype=np.uint8)
        overlay[mask] = (overlay[mask] * 0.6 + color * 0.4).astype(np.uint8)
    return Image.fromarray(overlay)
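# Debugging sketch (hypothetical helper, not in the original script): the HF
# image-segmentation pipeline also returns a class name per mask under the
# "label" key, which helps inspect what SegFormer actually sees in a depth map.
def segmentation_labels(img: Image.Image) -> list[str]:
    results = segmenter(resize_image(img.convert("RGB")))
    return [res["label"] for res in results]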
# ===== Gradio App =====
def predict_fn(input_img: Image.Image) -> Image.Image:
    # 1. Compute the depth map
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map (note: SegFormer was trained on RGB photos,
    #    so its predictions on a grayscale depth rendering are exploratory)
    seg_img = segment_image(depth_img)
    return seg_img
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
    title="Depth-then-Segmentation Pipeline",
    description=(
        "Upload an image. The app first computes a depth map with MiDaS, "
        "then applies SegFormer segmentation to that depth map and overlays the predicted masks."
    ),
)
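# Deployment note (assumption, not in the original): DPT_Large inference can
# take several seconds per image, so enabling Gradio's request queue may help
# on shared hosts, e.g.:
#
#   iface.queue()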
if __name__ == "__main__":
    iface.launch()