import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from models.ProtoSAM import ModelWrapper
from segment_anything import sam_model_registry
from util.utils import (
    rotate_tensor_no_crop,
    reverse_tensor,
    need_softmax,
    get_confidence_from_logits,
    get_connected_components,
    cca,
    plot_connected_components,
)
class ProtoMedSAM(nn.Module):
    def __init__(
        self,
        image_size,
        coarse_segmentation_model: ModelWrapper,
        sam_pretrained_path="pretrained_model/medsam_vit_b.pth",
        debug=False,
        use_cca=False,
        coarse_pred_only=False,
    ):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.coarse_segmentation_model = coarse_segmentation_model
        self.get_sam(sam_pretrained_path)
        self.coarse_pred_only = coarse_pred_only
        self.debug = debug
        self.use_cca = use_cca

    def get_sam(self, checkpoint_path):
        # Infer the SAM backbone from the checkpoint name; defaults to ViT-B.
        model_type = "vit_b"  # TODO: make generic?
        if "vit_h" in checkpoint_path:
            model_type = "vit_h"
        self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()

    @torch.no_grad()
    def medsam_inference(self, img_embed, box_1024, H, W, query_label=None):
        """
        Run the MedSAM prompt encoder and mask decoder on box prompts.

        img_embed: precomputed image embedding, (B, 256, 64, 64)
        box_1024: box prompts scaled to the encoder's input resolution, (B, 4)
        H, W: size the output mask is resized to
        query_label: optional (1, H, W) label; if given, multimask output is
            enabled and the candidate with the best IoU against it is kept.
        """
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]  # (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, conf = self.medsam.mask_decoder(
            image_embeddings=img_embed,  # (B, 256, 64, 64)
            image_pe=self.medsam.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=query_label is not None,
        )

        low_res_pred = torch.sigmoid(low_res_logits)  # (B, 1, 256, 256)
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W) or (B, H, W)
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        if query_label is not None:
            medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :]
        return medsam_seg, conf.cpu().detach().numpy()
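
    # Illustrative sketch (not part of the original file): how this method is
    # driven from a precomputed embedding. The box must first be scaled to the
    # encoder's input resolution, which forward() does below via
    # bbox / [W, H, W, H] * max(self.image_size). Shapes and values here are
    # assumptions for illustration only:
    #
    #   emb = self.medsam.image_encoder(img)    # img: (1, 3, 1024, 1024), in [0, 1]
    #   box = np.array([[100, 120, 300, 340]])  # (xmin, ymin, xmax, ymax)
    #   seg, conf = self.medsam_inference(emb, box, H=1024, W=1024)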
    def get_iou(self, pred, label):
        """
        pred: np array of shape (h, w), dtype uint8
        label: np array of shape (h, w), dtype uint8
        """
        tp = np.logical_and(pred, label).sum()
        fp = np.logical_and(pred, 1 - label).sum()
        fn = np.logical_and(1 - pred, label).sum()
        iou = tp / (tp + fp + fn)
        return iou
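
    # A minimal worked example (illustrative, not from the original): for
    #   pred  = [[1, 1],        label = [[1, 0],
    #            [0, 0]]                 [1, 0]]
    # tp = 1 (pixel (0, 0)), fp = 1 (pixel (0, 1)), fn = 1 (pixel (1, 0)),
    # so IoU = 1 / (1 + 1 + 1) = 1/3.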
    def get_best_mask(self, masks, labels):
        """
        masks: np array of shape (B, h, w)
        labels: torch tensor of shape (1, H, W)
        Returns the mask with the highest IoU against the label,
        or None if no mask overlaps the label at all.
        """
        np_labels = labels[0].clone().detach().cpu().numpy()
        best_iou, best_mask = 0, None
        for mask in masks:
            iou = self.get_iou(mask, np_labels)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask
        return best_mask
    def get_bbox(self, pred):
        """
        pred: tensor of shape (H, W), where 1 is foreground and 0 is background.
        Returns the bounding box of pred as np.array([xmin, ymin, xmax, ymax]),
        or None if pred is empty.
        """
        if isinstance(pred, np.ndarray):
            pred = torch.from_numpy(pred)
        if pred.max() == 0:
            return None
        indices = torch.nonzero(pred)  # (N, 2) rows of (y, x)
        ymin, xmin = indices.min(dim=0)[0]
        ymax, xmax = indices.max(dim=0)[0]
        return np.array([xmin, ymin, xmax, ymax])
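
    # Toy example (illustrative): for a (4, 4) mask with foreground at
    # (y, x) = (1, 1) and (2, 3), nonzero() yields rows [[1, 1], [2, 3]], so
    # the column-wise min/max give the box [xmin, ymin, xmax, ymax] = [1, 1, 3, 2].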
    def get_bbox_per_cc(self, conn_components):
        """
        conn_components: output of the cca function.
        Returns an (N, 4) array with one [xmin, ymin, xmax, ymax] box per
        connected component (background label 0 is skipped).
        """
        bboxes = []
        for i in range(1, conn_components[0]):
            # binary mask of the i-th connected component
            pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8)
            bboxes.append(self.get_bbox(pred))
        bboxes = np.array(bboxes)
        return bboxes
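
    # Assumption (for illustration): the indexing above matches the
    # cv2.connectedComponents convention, i.e. conn_components[0] is the number
    # of labels including background and conn_components[1] is the (H, W) label
    # map, so foreground labels run from 1 to num_labels - 1:
    #
    #   num_labels, label_map = conn_components[0], conn_components[1]
    #   boxes = self.get_bbox_per_cc(conn_components)  # up to (num_labels - 1, 4)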
    def forward(self, query_image, coarse_model_input, degrees_rotate=0):
        """
        query_image: tensor of shape (1, 3, H, W). Images are expected to be
            normalized with mean and std, not scaled to [0, 1].
        """
        original_size = query_image.shape[-2:]
        # rotate query_image by degrees_rotate and run the coarse model
        rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
        coarse_model_input.set_query_images(rotated_img)
        output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
        if degrees_rotate != 0:
            # rotate the logits back so they align with the unrotated image
            output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
        else:
            output_logits = output_logits_rot

        output_p = output_logits
        pred = output_logits.argmax(dim=1)[0]
        if self.debug:
            _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
            plt.subplot(132)
            plt.imshow(query_image[0, 0].detach().cpu())
            plt.imshow(_pred, alpha=0.5)
            plt.subplot(131)
            # heatmap of the foreground probability
            plt.imshow(output_p[0, 1].detach().cpu())
            # rotated query image and rotated prediction
            output_p_rot = output_logits_rot.softmax(dim=1)
            _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
            _pred_rot = F.interpolate(
                torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(),
                size=original_size,
                mode="nearest",
            )[0][0]
            plt.subplot(133)
            plt.imshow(rotated_img[0, 0].detach().cpu())
            plt.imshow(_pred_rot, alpha=0.5)
            plt.savefig("debug/coarse_pred.png")
            plt.close()
        if self.coarse_pred_only:
            if output_logits.shape[-2:] != original_size:
                output_logits = F.interpolate(output_logits, size=original_size, mode="bilinear")
            pred = output_logits.argmax(dim=1)[0]
            conf = get_confidence_from_logits(output_logits)
            if self.use_cca:
                _pred = np.array(pred.detach().cpu())
                _pred, conf = cca(_pred, output_logits, return_conf=True)
                pred = torch.from_numpy(_pred)
            if self.training:
                return output_logits, [conf]
            return pred, [conf]
        if query_image.shape[-2:] != self.image_size:
            query_image = F.interpolate(query_image, size=self.image_size, mode="bilinear")
            output_logits = F.interpolate(output_logits, size=self.image_size, mode="bilinear")

        if need_softmax(output_logits):
            output_logits = output_logits.softmax(dim=1)
        output_p = output_logits
        pred = output_p.argmax(dim=1)[0]
        _pred = np.array(pred.detach().cpu())

        if self.use_cca:
            conn_components = cca(_pred, output_logits, return_cc=True)
            conf = None
        else:
            conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
        if self.debug:
            plot_connected_components(conn_components, query_image[0, 0].detach().cpu(), conf)

        if _pred.max() == 0:
            # the coarse model found no foreground; fall back to its prediction
            if output_p.shape[-2:] != original_size:
                output_p = F.interpolate(output_p, size=original_size, mode="bilinear")
            return output_p.argmax(dim=1)[0], [0]

        H, W = query_image.shape[-2:]
        # one box prompt per connected component, scaled to SAM's input resolution
        bbox = self.get_bbox_per_cc(conn_components)
        bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
        # MedSAM expects intensities in [0, 1]
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)
        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W)
        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig("debug/medsam_pred.png")
            plt.close()
        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != original_size:
            medsam_seg = F.interpolate(
                medsam_seg.unsqueeze(0).unsqueeze(0), size=original_size, mode="nearest"
            )[0][0]
        return medsam_seg, [conf]
    def segment_all(self, query_image, query_label):
        """Segment with a single box prompt that covers the whole image."""
        H, W = query_image.shape[-2:]
        bbox = np.array([[0, 0, W, H]])
        # MedSAM expects intensities in [0, 1]
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)
        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W, query_label)
        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig("debug/medsam_pred.png")
            plt.close()
        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != (H, W):
            medsam_seg = F.interpolate(
                medsam_seg.unsqueeze(0).unsqueeze(0), size=(H, W), mode="nearest"
            )[0][0]
        return medsam_seg.view(H, W), [conf]
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
    )
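
# Minimal usage sketch (illustrative only; the checkpoint path, the coarse
# model construction, and the input shapes are assumptions, not part of this
# module -- ModelWrapper's constructor and coarse_model_input come from the
# surrounding pipeline):
#
# if __name__ == "__main__":
#     coarse = ModelWrapper(...)  # any coarse segmentation model wrapper
#     model = ProtoMedSAM(
#         image_size=1024,
#         coarse_segmentation_model=coarse,
#         sam_pretrained_path="pretrained_model/medsam_vit_b.pth",
#     )
#     query = torch.randn(1, 3, 1024, 1024)  # mean/std-normalized query image
#     pred, conf = model(query, coarse_model_input)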