import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from models.ProtoSAM import ModelWrapper
from segment_anything import sam_model_registry
from util.utils import rotate_tensor_no_crop, reverse_tensor, need_softmax, get_confidence_from_logits, get_connected_components, cca, plot_connected_components

class ProtoMedSAM(nn.Module):
    """Coarse prototype-based segmentation refined by MedSAM box prompts."""

    def __init__(self, image_size, coarse_segmentation_model: ModelWrapper, sam_pretrained_path="pretrained_model/medsam_vit_b.pth", debug=False, use_cca=False, coarse_pred_only=False):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.coarse_segmentation_model = coarse_segmentation_model
        self.get_sam(sam_pretrained_path)
        self.coarse_pred_only = coarse_pred_only
        self.debug = debug
        self.use_cca = use_cca
        
    
    def get_sam(self, checkpoint_path):
        model_type="vit_b" # TODO make generic?
        if 'vit_h' in checkpoint_path:
            model_type = "vit_h"
        self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()

        
    @torch.no_grad()
    def medsam_inference(self, img_embed, box_1024, H, W, query_label=None):
        """Prompt MedSAM with boxes on the 1024 grid and resize the mask to (H, W)."""
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]  # (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, conf = self.medsam.mask_decoder(
            image_embeddings=img_embed,  # (B, 256, 64, 64)
            image_pe=self.medsam.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=query_label is not None,
        )

        low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )  # (1, 1, gt.shape)
        low_res_pred = low_res_pred.squeeze().cpu()  # (H, W) for a single box
        
        low_res_pred = low_res_pred.numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        
        if query_label is not None:
            medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :]
        
        return medsam_seg, conf.cpu().detach().numpy()
    
    def get_iou(self, pred, label):
        """
        pred: np array of shape (h, w), dtype uint8
        label: np array of shape (h, w), dtype uint8
        """
        tp = np.logical_and(pred, label).sum()
        fp = np.logical_and(pred, 1 - label).sum()
        fn = np.logical_and(1 - pred, label).sum()
        union = tp + fp + fn
        # avoid dividing by zero when both masks are empty
        return tp / union if union > 0 else 0.0
    
    def get_best_mask(self, masks, labels):
        """

        masks np shape ( B, h, w)

        labels torch shape (1, H, W)

        """
        np_labels = labels[0].clone().detach().cpu().numpy()
        best_iou, best_mask = 0, None
        for mask in masks:
            iou = self.get_iou(mask, np_labels)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask
                
        return best_mask
    
    def get_bbox(self, pred):
        """

        pred is tensor of shape (H,W) - 1 is fg, 0 is bg.

        return bbox of pred s.t np.array([xmin, y_min, xmax, ymax])

        """
        if isinstance(pred, np.ndarray):
            pred = torch.from_numpy(pred)
        if pred.max() == 0:
            return None
        indices = torch.nonzero(pred)
        ymin, xmin = indices.min(dim=0)[0]
        ymax, xmax = indices.max(dim=0)[0]
        return np.array([xmin, ymin, xmax, ymax])
            
    
    def get_bbox_per_cc(self, conn_components):
        """

        conn_components: output of cca function

        return list of bboxes per connected component, each bbox is a list of 2d points

        """
        bboxes = []
        for i in range(1, conn_components[0]):
            # get the indices of the foreground points
            pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8)
            bboxes.append(self.get_bbox(pred))

        bboxes = np.array(bboxes)
        return bboxes

    def forward(self, query_image, coarse_model_input, degrees_rotate=0):
        """

        query_image: 3d tensor of shape (1, 3, H, W)

        images should be normalized with mean and std but not to [0, 1]?

        """
        original_size = query_image.shape[-2]  # int: assumes a square query image
        # rotate query_image by degrees_rotate
        rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
        # print(f"rotating query image took {time.time() - start_time} seconds")
        coarse_model_input.set_query_images(rotated_img)
        output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
        # print(f"ALPNet took {time.time() - start_time} seconds")
       
        if degrees_rotate != 0:
            output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
            # print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
        else:
            output_logits = output_logits_rot
        
        # check if softmax is needed 
        # output_p = output_logits.softmax(dim=1)
        output_p = output_logits
        pred = output_logits.argmax(dim=1)[0]
        if self.debug:
            _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
            plt.subplot(132)
            plt.imshow(query_image[0,0].detach().cpu())
            plt.imshow(_pred, alpha=0.5)
            plt.subplot(131)
            # plot heatmap of prob of being fg
            plt.imshow(output_p[0, 1].detach().cpu())
            # plot rotated query image and rotated pred
            output_p_rot = output_logits_rot.softmax(dim=1)
            _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
            _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
            plt.subplot(133)
            plt.imshow(rotated_img[0, 0].detach().cpu())
            plt.imshow(_pred_rot, alpha=0.5)
            plt.savefig('debug/coarse_pred.png')
            plt.close()
             
        if self.coarse_pred_only: 
            output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2] != original_size else output_logits
            pred = output_logits.argmax(dim=1)[0]
            conf = get_confidence_from_logits(output_logits) 
            if self.use_cca:
                _pred = np.array(pred.detach().cpu())
                _pred, conf = cca(_pred, output_logits, return_conf=True)
                pred = torch.from_numpy(_pred)
            if self.training:
                return output_logits, [conf]
            return pred, [conf]
        
        if query_image.shape[-2:] != self.image_size:
            query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
            output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
        if need_softmax(output_logits):
            output_logits = output_logits.softmax(dim=1)
        
        output_p = output_logits
        pred = output_p.argmax(dim=1)[0]
       
        _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu()) 
        if self.use_cca:
            conn_components = cca(_pred, output_logits, return_cc=True)
            conf = None
        else:
            conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
        if self.debug:
            plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf)
        # print(f"connected components took {time.time() - start_time} seconds")
        
        if _pred.max() == 0:
            if output_p.shape[-2] != original_size:
                output_p = F.interpolate(output_p, size=original_size, mode='bilinear')
            return output_p.argmax(dim=1)[0], [0]

        H, W = query_image.shape[-2:]
        # bbox = self.get_bbox(_pred)
        bbox = self.get_bbox_per_cc(conn_components)
        bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)
            
        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W)
        
        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig('debug/medsam_pred.png')
            plt.close()

        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2] != original_size:
            medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]

        return medsam_seg, [conf]
    
    def segment_all(self, query_image, query_label):
        H, W = query_image.shape[-2:]
        # bbox = self.get_bbox(_pred)
        # bbox = self.get_bbox_per_cc(conn_components)
        # bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
        bbox = np.array([[0, 0, W, H]])
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)
            
        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W, query_label)
        
        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1,2,0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1,2,0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig('debug/medsam_pred.png')
            plt.close()

        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != (H, W):
            medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest')[0][0]

        return medsam_seg.view(H,W), [conf]

    
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
    )
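

# A minimal sanity-check sketch for the geometry helpers, assuming no MedSAM
# checkpoint is available: get_bbox and get_iou never touch self, so they can
# be exercised as unbound methods on toy arrays. Shapes and values here are
# illustrative only.
if __name__ == "__main__":
    toy = np.zeros((8, 8), dtype=np.uint8)
    toy[2:5, 3:7] = 1
    bbox = ProtoMedSAM.get_bbox(None, toy)
    print(bbox)                                 # [3 2 6 4] -> (xmin, ymin, xmax, ymax)
    print(ProtoMedSAM.get_iou(None, toy, toy))  # 1.0 for identical masks
    # the same pixel-space box rescaled to MedSAM's 1024 prompt grid, as in forward()
    H = W = 8
    print(bbox / np.array([W, H, W, H]) * 1024)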