import numpy as np
import torch
import cv2  # not used directly in this file, but the MiDaS transforms require OpenCV to be installed
from PIL import Image
from transformers import pipeline
import gradio as gr

# ===== Device Setup =====
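# torch.device is used for the MiDaS model; the transformers pipeline instead
# expects an integer device index (0 = first GPU, -1 = CPU).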
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_index = 0 if torch.cuda.is_available() else -1

# ===== MiDaS Depth Estimation Setup =====
# Load MiDaS model and transforms
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device).eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform
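# dpt_transform matches the DPT_Large model: it rescales/normalizes an HxWx3 RGB array
# and returns a float tensor (batched as (1, 3, H, W) in current MiDaS releases;
# the ndim check in predict_depth below covers either case).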

# ===== Segmentation Setup =====
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
)
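# The segmentation pipeline returns one dict per detected segment, each with
# "label", "score", and a PIL "mask" the same size as the input image.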

# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
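    """Downscale so the longer side is at most max_size, keeping aspect ratio.

    The default of 512 matches the resolution the SegFormer checkpoint was fine-tuned at (512x512).
    """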
    width, height = img.size
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_size = (int(width * ratio), int(height * ratio))
        return img.resize(new_size, Image.LANCZOS)
    return img

# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    # Gradio (type="pil") passes a PIL image, but accept array-like input too; force 3-channel RGB
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    img = image.convert('RGB')
    img_np = np.array(img)

    # Convert to the format expected by MiDaS
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor

    # Predict depth
    with torch.no_grad():
        prediction = midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False
        ).squeeze()

    # MiDaS predicts relative (inverse) depth, so normalize per-image to 0-255 for display
    depth_map = prediction.cpu().numpy()
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    depth_map = (depth_map * 255).astype(np.uint8)
    return Image.fromarray(depth_map)
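
# Quick standalone check (illustrative only; "example.jpg" is a hypothetical path):
#   depth = predict_depth(Image.open("example.jpg"))
#   depth.save("example_depth.png")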

# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    img = img.convert('RGB')
    img_resized = resize_image(img)
    results = segmenter(img_resized)

    overlay = np.array(img_resized, dtype=np.uint8)
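    # Alpha-blend each segment onto the image: 60% original pixel, 40% random per-segment color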
    for res in results:
        mask = np.array(res["mask"], dtype=bool)
        color = np.random.randint(50, 255, 3, dtype=np.uint8)
        overlay[mask] = (overlay[mask] * 0.6 + color * 0.4).astype(np.uint8)

    return Image.fromarray(overlay)

# ===== Gradio App =====
def predict_fn(input_img: Image.Image) -> Image.Image:
    # 1. Compute depth map
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map
    seg_img = segment_image(depth_img)
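    #    (the depth map is single-channel; segment_image converts it to RGB before
    #    running SegFormer, which was trained on natural RGB images)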
    return seg_img

iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
    title="Depth-then-Segmentation Pipeline",
    description="Upload an image. The app first computes a depth map with MiDaS, then runs SegFormer segmentation on that depth map."
)

if __name__ == "__main__":
    iface.launch()
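    # To expose a temporary public URL (e.g. when running in a notebook), use iface.launch(share=True) instead.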