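"""Gradio demo for ProtoSAM few-shot segmentation.

Loads an ALPNet coarse segmentation model, wraps it in ProtoSAM for SAM-based
refinement, and serves a web UI: upload a query image plus a support
image/mask pair to get a segmentation overlay and a confidence score.
"""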
import os
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
from models.ProtoSAM import ProtoSAM, ALPNetWrapper, InputFactory, TYPE_ALPNET
from models.grid_proto_fewshot import FewShotSeg

# Set environment variables for model caching
os.environ['TORCH_HOME'] = "./pretrained_model"

# Function to load the model
def load_model(config):
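    """Build the coarse ALPNet segmenter and wrap it in ProtoSAM (CUDA, eval mode)."""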
    # Initial segmentation model
    alpnet = FewShotSeg(
        config["input_size"][0],
        config["reload_model_path"],
        config["model"]
    )
    alpnet.cuda()
    base_model = ALPNetWrapper(alpnet)
    
    # ProtoSAM model
    sam_checkpoint = "pretrained_model/sam_vit_h.pth"
    model = ProtoSAM(
        image_size=(1024, 1024),
        coarse_segmentation_model=base_model,
        use_bbox=config["use_bbox"],
        use_points=config["use_points"],
        use_mask=config["use_mask"],
        debug=False,
        num_points_for_sam=1,
        use_cca=config["do_cca"],
        point_mode=config["point_mode"],
        use_sam_trans=True,
        coarse_pred_only=config["coarse_pred_only"],
        sam_pretrained_path=sam_checkpoint,
        use_neg_points=config["use_neg_points"],
    )
    model = model.to(torch.device("cuda"))
    model.eval()
    return model

# Function to preprocess images
def preprocess_image(image, transform):
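    """Convert a PIL image or numpy array to a normalized (1, 3, 1024, 1024) tensor."""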
    if isinstance(image, np.ndarray):
        image_np = image
    else:
        # Convert PIL Image to numpy array
        image_np = np.array(image)
        
    # Convert to 3-channel RGB: replicate grayscale, drop any alpha channel
    if len(image_np.shape) == 2:
        image_np = np.stack([image_np] * 3, axis=2)
    elif image_np.shape[2] == 1:
        image_np = np.concatenate([image_np] * 3, axis=2)
    elif image_np.shape[2] == 4:
        image_np = image_np[:, :, :3]
        
    # Apply transforms
    image_tensor = transform(image_np).unsqueeze(0)
    return image_tensor

# Function to create overlay visualization
def create_overlay(query_image, prediction, colormap='YlOrRd'):
    """
    Create an overlay of the prediction on the query image
    """
    # Convert tensors to numpy arrays for visualization
    if isinstance(query_image, torch.Tensor):
        query_image = query_image.cpu().squeeze().numpy()
        
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.cpu().squeeze().numpy()
    
    # Normalize image for visualization
    query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8)
    
    # Ensure binary mask
    prediction = (prediction > 0).astype(np.float32)
    
    # Colorize the mask as RGBA (plt.cm.get_cmap was removed in matplotlib 3.9;
    # the colormap registry is the stable replacement)
    mask_cmap = matplotlib.colormaps[colormap]
    pred_rgba = mask_cmap(prediction)
    pred_rgba[..., 3] = prediction * 0.7  # Alpha: 0.7 on the mask, transparent elsewhere
    
    # Create matplotlib figure for overlay
    fig, ax = plt.subplots(figsize=(10, 10))
    
    # Handle grayscale vs RGB images
    if len(query_image.shape) == 2:
        ax.imshow(query_image, cmap='gray')
    else:
        if query_image.shape[0] == 3:  # Channel-first format
            query_image = np.transpose(query_image, (1, 2, 0))
        ax.imshow(query_image)
        
    ax.imshow(pred_rgba)
    ax.axis('off')
    plt.tight_layout()
    
    # Render the figure and read back its RGBA buffer (canvas.tostring_rgb was
    # removed in matplotlib 3.10), keeping only the RGB channels
    fig.canvas.draw()
    img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    
    return img

# Model configuration
config = {
    "input_size": [224],
    "reload_model_path": "path/to/your/model.pth",  # Update with your model path
    "model": {"encoder": "resnet50", "decoder": "pspnet"},
    "use_bbox": True,
    "use_points": True,
    "use_mask": True,
    "do_cca": True,
    "point_mode": "extreme",
    "coarse_pred_only": False,
    "use_neg_points": False,
    "base_model": TYPE_ALPNET
}
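# Note: the prompt-related keys (use_bbox, use_points, use_mask, do_cca,
# coarse_pred_only) are overwritten on each request from the UI checkboxes.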

# Function to run inference
def run_inference(query_image, support_image, support_mask, use_bbox, use_points, use_mask, use_cca, coarse_pred_only):
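    """Run ProtoSAM on one query image given a single support image/mask pair.

    Returns (overlay image, info string); on failure returns (None, error message).
    """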
    try:
        # Update config based on user selections
        config["use_bbox"] = use_bbox
        config["use_points"] = use_points
        config["use_mask"] = use_mask
        config["do_cca"] = use_cca
        config["coarse_pred_only"] = coarse_pred_only
        
        # Check if CUDA is available
        if not torch.cuda.is_available():
            return None, "CUDA is not available. This demo requires GPU support."
        
        # Load the model (rebuilt on every request; cache it if startup cost matters)
        model = load_model(config)
        
        # Transform for images: resize to SAM's 1024x1024 input and normalize
        # with ImageNet statistics (ProtoSAM applies its own SAM transforms via
        # use_sam_trans, so no separate ResizeLongestSide is needed here)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((1024, 1024), antialias=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Process query image
        query_img_tensor = preprocess_image(query_image, transform)
        
        # Process support image
        support_img_tensor = preprocess_image(support_image, transform)
        
        # Process support mask: force single channel (Gradio may hand back RGB),
        # binarize, and resize with nearest-neighbor so the mask stays binary
        support_mask_np = np.array(support_mask.convert("L"))
        support_mask_np = (support_mask_np > 127).astype(np.float32)
        support_mask_tensor = torch.from_numpy(support_mask_np).unsqueeze(0).unsqueeze(0)
        support_mask_tensor = torch.nn.functional.interpolate(
            support_mask_tensor, size=(1024, 1024), mode='nearest'
        )
        
        # Prepare model inputs
        support_images = [support_img_tensor.cuda()]
        support_masks = [support_mask_tensor.cuda()]
        
        # Create model input
        coarse_model_input = InputFactory.create_input(
            input_type=config["base_model"],
            query_image=query_img_tensor.cuda(),
            support_images=support_images,
            support_labels=support_masks,
            isval=True,
            val_wsize=3,
            original_sz=query_img_tensor.shape[-2:],
            img_sz=query_img_tensor.shape[-2:],
            gts=None,
        )
        coarse_model_input.to(torch.device("cuda"))
        
        # Run inference
        with torch.no_grad():
            query_pred, scores = model(
                query_img_tensor.cuda(), coarse_model_input, degrees_rotate=0
            )
        
        # Create overlay visualization
        result_image = create_overlay(query_img_tensor, query_pred)
        
        # scores may be a GPU tensor depending on the model; move to CPU first
        if isinstance(scores, torch.Tensor):
            scores = scores.detach().cpu().numpy()
        confidence_score = float(np.mean(scores))
        return result_image, f"Confidence Score: {confidence_score:.4f}"
    
    except Exception as e:
        return None, f"Error during inference: {str(e)}"

# Define the Gradio interface
def create_interface():
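    """Build the Gradio Blocks UI: image inputs, prompt-type toggles, and outputs."""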
    with gr.Blocks(title="ProtoSAM Segmentation Demo") as demo:
        gr.Markdown("# ProtoSAM Segmentation Demo")
        gr.Markdown("Upload a query image, support image, and support mask to generate a segmentation prediction.")
        
        with gr.Row():
            with gr.Column():
                query_image = gr.Image(label="Query Image", type="pil")
                support_image = gr.Image(label="Support Image", type="pil")
                support_mask = gr.Image(label="Support Mask", type="pil")
            
            with gr.Column():
                result_image = gr.Image(label="Prediction Result")
                result_text = gr.Textbox(label="Result Information")
        
        with gr.Row():
            with gr.Column():
                use_bbox = gr.Checkbox(label="Use Bounding Box", value=True)
                use_points = gr.Checkbox(label="Use Points", value=True)
                use_mask = gr.Checkbox(label="Use Mask", value=True)
            
            with gr.Column():
                use_cca = gr.Checkbox(label="Use CCA", value=True)
                coarse_pred_only = gr.Checkbox(label="Coarse Prediction Only", value=False)
                run_button = gr.Button("Run Inference")
        
        run_button.click(
            fn=run_inference,
            inputs=[
                query_image, 
                support_image, 
                support_mask, 
                use_bbox,
                use_points,
                use_mask,
                use_cca,
                coarse_pred_only
            ],
            outputs=[result_image, result_text]
        )
    
    return demo

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)