import numpy as np
import torch
from PIL import Image
from transformers import pipeline
import gradio as gr
# ===== Device Setup =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_index = 0 if torch.cuda.is_available() else -1
# ===== MiDaS Depth Estimation Setup =====
# Load MiDaS model and transforms
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device).eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform
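# Note: DPT_Large is accurate but heavy. A lighter alternative (a sketch; the
# "MiDaS_small" hub entry pairs with midas_transforms.small_transform):
#   midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device).eval()
#   transform = midas_transforms.small_transform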
# ===== Segmentation Setup =====
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
)
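# The pipeline returns a list of dicts, one per detected region, each with a
# "label" (ADE20K class name), a PIL "mask", and a "score" (which may be None
# for purely semantic models). A quick way to inspect what was found (sketch;
# `some_img` stands for any RGB PIL image):
#   for r in segmenter(some_img):
#       print(r["label"])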
# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
    """Downscale img so its longest side is at most max_size, keeping aspect ratio."""
    width, height = img.size
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_size = (int(width * ratio), int(height * ratio))
        return img.resize(new_size, Image.LANCZOS)
    return img
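# Worked example: a 1024x768 input gives ratio 512/1024 = 0.5, so it comes back
# as 512x384; anything already within 512px is returned unchanged.
#   resize_image(Image.new("RGB", (1024, 768))).size == (512, 384)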
# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    """Run MiDaS on an image and return the relative depth map as 8-bit grayscale."""
    # Ensure the input is an RGB PIL image (Gradio may hand over a numpy array)
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    img = image.convert("RGB")
    img_np = np.array(img)
    # Convert to the format expected by MiDaS; the transform may already batch
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor
    # Predict depth and resize back to the input resolution
    with torch.no_grad():
        prediction = midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    # Normalize to 0-255 (the epsilon guards against a constant depth map)
    depth_map = prediction.cpu().numpy()
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    depth_map = (depth_map * 255).astype(np.uint8)
    return Image.fromarray(depth_map)
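# MiDaS predicts relative inverse depth: larger values mean closer surfaces.
# Standalone usage (sketch; assumes a local "example.jpg" exists):
#   predict_depth(Image.open("example.jpg")).save("example_depth.png")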
# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    """Segment img with SegFormer and alpha-blend a random color over each region."""
    img = img.convert("RGB")
    img_resized = resize_image(img)
    results = segmenter(img_resized)
    overlay = np.array(img_resized, dtype=np.uint8)
    # Blend each region's random color into the image (60% original, 40% color)
    for res in results:
        mask = np.array(res["mask"], dtype=bool)
        color = np.random.randint(50, 255, 3, dtype=np.uint8)
        overlay[mask] = (overlay[mask] * 0.6 + color * 0.4).astype(np.uint8)
    return Image.fromarray(overlay)
# ===== Gradio App =====
def predict_fn(input_img: Image.Image) -> Image.Image:
    # 1. Compute depth map
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map
    seg_img = segment_image(depth_img)
    return seg_img
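# Design note: SegFormer here runs on the depth map (converted to RGB), not on
# the original photo, so regions group by depth structure rather than appearance.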
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
    title="Depth-then-Segmentation Pipeline",
    description="Upload an image. The app first computes a depth map with MiDaS, then runs SegFormer segmentation on that depth map.",
)
if __name__ == "__main__":
    iface.launch()
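# To expose the app beyond localhost, extra launch flags can be passed, e.g.
# (sketch; share=True requests a temporary public Gradio link):
#   iface.launch(server_name="0.0.0.0", share=True)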