# LoGoSAM_demo/models/ProtoMedSAM.py
# Uploaded by quandn2003 via huggingface_hub (commit 427d150, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from models.ProtoSAM import ModelWrapper
from segment_anything import sam_model_registry
from util.utils import rotate_tensor_no_crop, reverse_tensor, need_softmax, get_confidence_from_logits, get_connected_components, cca, plot_connected_components
class ProtoMedSAM(nn.Module):
    """Two-stage segmenter: a coarse prototype-based model proposes a mask,
    and the bounding box of each connected component of that mask is used
    as a box prompt for MedSAM, which produces the refined segmentation.
    """

    def __init__(self, image_size, coarse_segmentation_model: "ModelWrapper",
                 sam_pretrained_path="pretrained_model/medsam_vit_b.pth",
                 debug=False, use_cca=False, coarse_pred_only=False):
        super().__init__()
        # A single int is expanded to a square (H, W).
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.coarse_segmentation_model = coarse_segmentation_model
        self.get_sam(sam_pretrained_path)
        self.coarse_pred_only = coarse_pred_only  # skip the MedSAM refinement stage
        self.debug = debug                        # dump intermediate plots under debug/
        self.use_cca = use_cca                    # connected-component filtering of the coarse mask

    def get_sam(self, checkpoint_path):
        """Load a SAM/MedSAM checkpoint in eval mode; ViT variant is inferred from the path."""
        model_type = "vit_b"  # TODO make generic?
        if 'vit_h' in checkpoint_path:
            model_type = "vit_h"
        self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
        # NOTE(fix): the original had a bare `torch.no_grad()` call here, which
        # builds and immediately discards the context manager (a no-op).
        # Gradients are now actually disabled via the @torch.no_grad()
        # decorator on medsam_inference.

    @torch.no_grad()
    def medsam_inference(self, img_embed, box_1024, H, W, query_label=None):
        """Run the MedSAM prompt encoder + mask decoder with box prompts.

        img_embed: image-encoder features, (B, 256, 64, 64).
        box_1024: box prompts scaled to SAM's input resolution, (B, 4) or (B, 1, 4).
        H, W: spatial size the predicted mask is upsampled to.
        query_label: optional GT label (1, H, W); when given, multi-mask output
            is requested and the candidate with the best IoU vs. the label is kept.
        Returns (binary mask as np.uint8, decoder confidence as a numpy array).
        """
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]  # (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, conf = self.medsam.mask_decoder(
            image_embeddings=img_embed,  # (B, 256, 64, 64)
            image_pe=self.medsam.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=query_label is not None,
        )

        low_res_pred = torch.sigmoid(low_res_logits)  # (B, K, 256, 256)
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )  # (B, K, H, W)
        low_res_pred = low_res_pred.squeeze().cpu().numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        if query_label is not None:
            # Multi-mask mode: keep the candidate closest to the label.
            medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :]
        return medsam_seg, conf.cpu().detach().numpy()

    def get_iou(self, pred, label):
        """IoU between two binary uint8 arrays of shape (h, w).

        Returns 0.0 when both masks are empty (the original divided by zero here).
        """
        tp = np.logical_and(pred, label).sum()
        fp = np.logical_and(pred, 1 - label).sum()
        fn = np.logical_and(1 - pred, label).sum()
        union = tp + fp + fn
        return tp / union if union > 0 else 0.0

    def get_best_mask(self, masks, labels):
        """Return the mask in `masks` with the highest IoU against `labels`.

        masks: np array of shape (B, h, w); labels: torch tensor of shape (1, H, W).
        Falls back to an all-zero mask when no candidate overlaps the label
        (the original returned None here, crashing the caller's `[None, :]`).
        """
        np_labels = labels[0].clone().detach().cpu().numpy()
        best_iou, best_mask = 0, None
        for mask in masks:
            iou = self.get_iou(mask, np_labels)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask
        if best_mask is None:
            best_mask = np.zeros_like(masks[0])
        return best_mask

    def get_bbox(self, pred):
        """Bounding box of the foreground of `pred` (H, W).

        Returns np.array([xmin, ymin, xmax, ymax]), or None when `pred` has no
        foreground at all.
        """
        if isinstance(pred, np.ndarray):
            pred = torch.from_numpy(pred)
        if pred.max() == 0:
            return None  # nothing to box
        indices = torch.nonzero(pred)  # rows are (y, x)
        ymin, xmin = indices.min(dim=0)[0]
        ymax, xmax = indices.max(dim=0)[0]
        return np.array([int(xmin), int(ymin), int(xmax), int(ymax)])

    def get_bbox_per_cc(self, conn_components):
        """One bounding box per connected component.

        conn_components: output of cca/get_connected_components; index 0 holds
        the number of labels (background included), index 1 the label map.
        Returns an np array of shape (num_components, 4).
        """
        bboxes = []
        for i in range(1, conn_components[0]):  # label 0 is background
            pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8)
            bboxes.append(self.get_bbox(pred))
        return np.array(bboxes)

    def forward(self, query_image, coarse_model_input, degrees_rotate=0):
        """
        query_image: tensor of shape (1, 3, H, W); assumed already normalized
            (mean/std) by the caller — NOTE(review): not necessarily in [0, 1].
        coarse_model_input: input wrapper consumed by the coarse model.
        degrees_rotate: optional test-time rotation applied before the coarse
            model and undone on its logits.
        Returns (prediction, [confidence]).
        """
        # NOTE(fix): the original stored only query_image.shape[-2] (an int),
        # so every `shape[-2:] != original_size` test compared a tuple with an
        # int, was always True, and forced redundant interpolations.
        original_size = tuple(query_image.shape[-2:])

        # Coarse stage: optionally rotate, segment, then undo the rotation.
        rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
        coarse_model_input.set_query_images(rotated_img)
        output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
        if degrees_rotate != 0:
            output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
        else:
            output_logits = output_logits_rot
        output_p = output_logits

        if self.debug:
            _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
            plt.subplot(132)
            plt.imshow(query_image[0, 0].detach().cpu())
            plt.imshow(_pred, alpha=0.5)
            plt.subplot(131)
            # Heatmap of the foreground channel.
            plt.imshow(output_p[0, 1].detach().cpu())
            # Rotated query image with the (pre-reversal) rotated prediction.
            output_p_rot = output_logits_rot.softmax(dim=1)
            _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
            _pred_rot = F.interpolate(
                torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(),
                size=original_size, mode='nearest')[0][0]
            plt.subplot(133)
            plt.imshow(rotated_img[0, 0].detach().cpu())
            plt.imshow(_pred_rot, alpha=0.5)
            plt.savefig('debug/coarse_pred.png')
            plt.close()

        if self.coarse_pred_only:
            # Return the coarse prediction without MedSAM refinement.
            if output_logits.shape[-2:] != original_size:
                output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear')
            pred = output_logits.argmax(dim=1)[0]
            conf = get_confidence_from_logits(output_logits)
            if self.use_cca:
                _pred = np.array(pred.detach().cpu())
                _pred, conf = cca(_pred, output_logits, return_conf=True)
                pred = torch.from_numpy(_pred)
            if self.training:
                return output_logits, [conf]
            return pred, [conf]

        # Bring image and logits to the SAM working resolution.
        if query_image.shape[-2:] != self.image_size:
            query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
        if output_logits.shape[-2:] != self.image_size:
            output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
        if need_softmax(output_logits):
            output_logits = output_logits.softmax(dim=1)
        output_p = output_logits
        pred = output_p.argmax(dim=1)[0]
        _pred = np.array(pred.detach().cpu())

        if self.use_cca:
            conn_components = cca(_pred, output_logits, return_cc=True)
            conf = None
        else:
            conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
        if self.debug:
            plot_connected_components(conn_components, query_image[0, 0].detach().cpu(), conf)

        if _pred.max() == 0:
            # Coarse model found no foreground: return its (empty) prediction.
            if output_p.shape[-2:] != original_size:
                output_p = F.interpolate(output_p, size=original_size, mode='bilinear')
            return output_p.argmax(dim=1)[0], [0]

        H, W = query_image.shape[-2:]
        # One box prompt per connected component, scaled to the SAM prompt space.
        bbox = self.get_bbox_per_cc(conn_components)
        bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)

        # MedSAM expects intensities in [0, 1]: min-max normalize.
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)
        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W)

        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig('debug/medsam_pred.png')
            plt.close()

        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != original_size:
            # .float() because nearest interpolation is not implemented for
            # integer tensors on all backends; result stays a 0/1 mask.
            medsam_seg = F.interpolate(
                medsam_seg.unsqueeze(0).unsqueeze(0).float(),
                size=original_size, mode='nearest')[0][0]
        return medsam_seg, [conf]

    def segment_all(self, query_image, query_label):
        """Segment with a single whole-image box prompt (no coarse stage).

        query_label picks the best of MedSAM's multi-mask outputs.
        Returns (mask of shape (H, W), [confidence]).
        """
        H, W = query_image.shape[-2:]
        bbox = np.array([[0, 0, W, H]])  # whole-image box prompt
        # MedSAM expects intensities in [0, 1]: min-max normalize.
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)
        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W, query_label)

        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig('debug/medsam_pred.png')
            plt.close()

        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != (H, W):
            # .float() for the same dtype reason as in forward().
            medsam_seg = F.interpolate(
                medsam_seg.unsqueeze(0).unsqueeze(0).float(),
                size=(H, W), mode='nearest')[0][0]
        return medsam_seg.view(H, W), [conf]
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on a matplotlib axis as a translucent RGBA image.

    mask: array whose last two dims are (h, w); ax: matplotlib axis.
    random_color picks a random RGB with alpha 0.6 instead of the default yellow.
    """
    if random_color:
        rgba = np.append(np.random.random(3), 0.6)
    else:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) color.
    overlay = mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_box(box, ax):
    """Draw a box given as [x0, y0, x1, y1] as a blue rectangle outline on ax."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor="blue",
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(rect)