import warnings import torch import torch.nn as nn from torch.nn import functional as F import matplotlib.pyplot as plt import numpy as np from models.grid_proto_fewshot import FewShotSeg from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor from models.SamWrapper import SamWrapper from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits from util.lora import inject_trainable_lora from models.segment_anything.utils.transforms import ResizeLongestSide import cv2 import time from abc import ABC, abstractmethod CONF_MODE="conf" CENTROID_MODE="centroid" BOTH_MODE="both" POINT_MODES=(CONF_MODE, CENTROID_MODE, BOTH_MODE) TYPE_ALPNET="alpnet" TYPE_SAM="sam" def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"): num_labels, labels, stats, centroids = cca_output # Create an output image with random colors for each component output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8) for label in range(1, num_labels): # Start from 1 to skip the background mask = labels == label output_image[mask] = np.random.randint(0, 255, size=3) # Plotting the original and the colored components image plt.figure(figsize=(10, 5)) plt.subplot(121), plt.imshow(original_image), plt.title('Original Image') plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components') if confidences is not None: # Plot the axes color chart with the confidences, use the same colors as the connected components plt.subplot(122) scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet') plt.colorbar(scatter) plt.savefig(title) plt.close() class SegmentationInput(ABC): @abstractmethod def set_query_images(self, query_images): pass def to(self, device): pass class SegmentationOutput(ABC): @abstractmethod def get_prediction(self): pass class ALPNetInput(SegmentationInput): # for alpnet def __init__(self, support_images:list, support_labels:list, query_images:torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None): self.supp_imgs = [support_images] self.fore_mask = [support_labels] self.back_mask = [[1 - sup_labels for sup_labels in support_labels]] self.qry_imgs = [query_images] self.isval = isval self.val_wsize = val_wsize self.show_viz = show_viz self.supp_fts = supp_fts def set_query_images(self, query_images): self.qry_imgs = [query_images] def to(self, device): self.supp_imgs = [[supp_img.to(device) for way in self.supp_imgs for supp_img in way]] self.fore_mask = [[fore_mask.to(device) for way in self.fore_mask for fore_mask in way]] self.back_mask = [[back_mask.to(device) for way in self.back_mask for back_mask in way]] self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs] if self.supp_fts is not None: self.supp_fts = self.supp_fts.to(device) class ALPNetOutput(SegmentationOutput): def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts): self.pred = pred self.align_loss = align_loss self.sim_maps = sim_maps self.assign_maps = assign_maps self.proto_grid = proto_grid self.supp_fts = supp_fts self.qry_fts = qry_fts def get_prediction(self): return self.pred class SAMWrapperInput(SegmentationInput): def __init__(self, image, image_labels): self.image = image self.image_labels = image_labels def set_query_images(self, query_images): B, C, H, W = query_images.shape if isinstance(query_images, torch.Tensor): query_images = query_images.cpu().detach().numpy() assert B == 1, "batch size must be 1" query_images = (query_images - query_images.min()) / (query_images.max() - query_images.min()) * 255 query_images = query_images.astype(np.uint8) self.image = np.transpose(query_images[0], (1, 2, 0)) def to(self, device): pass class InputFactory(ABC): @staticmethod def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None): if input_type == TYPE_ALPNET: return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts) elif input_type == TYPE_SAM: qimg = np.array(query_image.detach().cpu()) B,C,H,W = qimg.shape assert B == 1, "batch size must be 1" gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H,W) assert np.unique(gts).shape[0] <= 2, "support labels must be binary" gts[gts > 0] = 1 qimg = qimg.reshape(H,W,C) qimg = (qimg - qimg.min()) / (qimg.max() - qimg.min()) * 255 qimg = qimg.astype(np.uint8) return SAMWrapperInput(qimg, gts) else: raise ValueError(f"input_type not supported") class ModelWrapper(ABC): def __init__(self, model): self.model = model def __call__(self, input_data: SegmentationInput)->SegmentationOutput: pass def state_dict(self): return self.model.state_dict() def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) def eval(self): self.model.eval() def train(self): self.model.train() def parameters(self): pass class ALPNetWrapper(ModelWrapper): def __init__(self, model: FewShotSeg): super().__init__(model) def __call__(self, input_data: ALPNetInput): output = self.model(**input_data.__dict__) output = ALPNetOutput(*output) return output.pred def parameters(self): return self.model.encoder.parameters() def train(self): self.model.encoder.train() class SamWrapperWrapper(ModelWrapper): def __init__(self, model:SamWrapper): super().__init__(model) def __call__(self, input_data: SAMWrapperInput): pred = self.model(**input_data.__dict__) # make pred look like logits pred = torch.tensor(pred).float()[None, None, ...] pred = torch.cat([1-pred, pred], dim=1) return pred def to(self, device): self.model.sam.to(device) class ProtoSAM(nn.Module): def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/sam_default.pth", num_points_for_sam=1, use_points=True, use_bbox=False, use_mask=False, debug=False, use_cca=False, point_mode=CONF_MODE, use_sam_trans=True, coarse_pred_only=False, alpnet_image_size=None, use_neg_points=False, ): super().__init__() if isinstance(image_size, int): image_size = (image_size, image_size) self.image_size = image_size self.coarse_segmentation_model = coarse_segmentation_model self.get_sam(sam_pretrained_path, use_sam_trans) self.num_points_for_sam = num_points_for_sam self.use_points = use_points self.use_bbox = use_bbox # if False then uses points self.use_mask = use_mask self.use_neg_points = use_neg_points assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask" self.use_cca = use_cca self.point_mode = point_mode if self.point_mode not in POINT_MODES: raise ValueError(f"point mode must be one of {POINT_MODES}") self.debug=debug self.coarse_pred_only = coarse_pred_only def get_sam(self, checkpoint_path, use_sam_trans): model_type="vit_b" # TODO make generic? if 'vit_h' in checkpoint_path: model_type = "vit_h" self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval() self.predictor = SamPredictor(self.sam) self.sam.requires_grad_(False) if use_sam_trans: # sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size, pixel_mean=[0], pixel_std=[1]) sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size) sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1) sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1) else: sam_trans = None self.sam_trans = sam_trans def get_bbox(self, pred): ''' pred tensor of shape (H, W) where 1 represents foreground and 0 represents background returns a list of 2d points representing the bbox ''' if isinstance(pred, np.ndarray): pred = torch.tensor(pred) # get the indices of the foreground points indices = torch.nonzero(pred) # get the min and max of the indices min_x = indices[:, 1].min() max_x = indices[:, 1].max() min_y = indices[:, 0].min() max_y = indices[:, 0].max() # get the bbox bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]] return bbox def get_bbox_per_cc(self, conn_components): """ conn_components: output of cca function return list of bboxes per connected component, each bbox is a list of 2d points """ bboxes = [] for i in range(1, conn_components[0]): # get the indices of the foreground points indices = torch.nonzero(torch.tensor(conn_components[1] == i)) # get the min and max of the indices min_x = indices[:, 1].min() max_x = indices[:, 1].max() min_y = indices[:, 0].min() max_y = indices[:, 0].max() # get the bbox # bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]] # bbox = [[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y]] # bbox should be in a XYXY format bbox = [min_x, min_y, max_x, max_y] bboxes.append(bbox) bboxes = np.array(bboxes) return bboxes def get_most_conf_points(self, output_p_fg, pred, k): ''' get the k most confident points from pred output_p: 3d tensor of shape (H, W) pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background ''' # Create a mask where pred is 1 mask = pred.bool() # Apply the mask to output_p_fg masked_output_p_fg = output_p_fg[mask] if masked_output_p_fg.numel() == 0: return None, None # Get the top k probabilities and their indices confidences, indices = torch.topk(masked_output_p_fg, k) # Get the locations of the top k points in xy format locations = torch.nonzero(mask)[indices] # convert locations to xy format locations = locations[:, [1, 0]] # convert locations to list of lists # points = [loc.tolist() for loc in locations] return locations.numpy(), [float(conf.item()) for conf in confidences] def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None): ''' points: np array of shape (N, 2) where each row is a point in xy format pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background image: 2d tensor of shape (H,W) representing the image bbox: list or np array of shape (N, 4) where each row is a bbox in xyxy format ''' warnings.filterwarnings('ignore', category=UserWarning) if isinstance(pred, torch.Tensor): pred = pred.cpu().detach().numpy() if len(image.shape) == 3 and image.shape[0] == 3: image = image.permute(1, 2, 0) if title is None: title="debug/most_conf_points.png" fig = plt.figure() image = (image - image.min()) / (image.max() - image.min()) plt.imshow(image) plt.imshow(pred, alpha=0.5) for i, point in enumerate(points): plt.scatter(point[0][0], point[0][1], cmap='viridis', marker='*', c='red') if confidences is not None: plt.text(point[0], point[1], f"{confidences[i]:.3f}", fontsize=12, color='red') # assume points is a list of lists if bboxes is not None: for bbox in bboxes: if bbox is None: continue bbox = np.array(bbox) # plt.scatter(bbox[:, 1], bbox[:, 0], c='red') # plot a line connecting the points box = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]) box = np.vstack([box, box[0]]) plt.plot(box[:, 0], box[:, 1], c='green') plt.colorbar() fig.savefig(title) plt.close(fig) def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None): if len(image.shape) == 3: image = image.permute(1, 2, 0) image = (image - image.min()) / (image.max() - image.min()) for i, (mask, score) in enumerate(zip(masks, scores)): plt.figure(figsize=(10,10)) plt.imshow(image) show_mask(mask, plt.gca()) if input_point is not None: show_points(input_point, input_label, plt.gca()) if input_box is not None: show_box(input_box, plt.gca()) plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) # plt.axis('off') plt.savefig(f'debug/sam_mask_{i+1}.png') plt.close() if i > 5: break def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1): """ args: conn_components: output of cca function output_p: 3d tensor of shape (1, 2, H, W) get_neg_points: bool, if True then return the negative points l: int, number of negative points to get """ sam_input_points = [] sam_neg_points = [] fg_p = output_p[0, 1].detach().cpu() if get_neg_points: # get global negative points bg_p = output_p[0, 0].detach().cpu() bg_p[bg_p < 0.95] = 0 bg_pred = torch.where(bg_p > 0, 1, 0) glob_neg_points, _ = self.get_most_conf_points(bg_p, bg_pred, 1) if self.debug: # plot the bg_p as a heatmap plt.figure() plt.imshow(bg_p) plt.colorbar() plt.savefig('debug/bg_p_heatmap.png') plt.close() for i, cc_id in enumerate(np.unique(conn_components[1])): # get self.num_points_for_sam most confident points from pred if cc_id == 0: continue # skip background pred = torch.tensor(conn_components[1] == cc_id).float() if self.point_mode == CONF_MODE: points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) # (N, 2) elif self.point_mode == CENTROID_MODE: points = conn_components[3][cc_id][None, :] # (1, 2) confidences = [1 for _ in range(len(points))] elif self.point_mode == BOTH_MODE: points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) point = conn_components[3][cc_id][None, :] points = np.vstack([points, point]) # (N+1, 2) confidences.append(1) else: raise NotImplementedError(f"point mode {self.point_mode} not implemented") sam_input_points.append(np.array(points)) if get_neg_points: pred_uint8 = (pred.numpy() * 255).astype(np.uint8) # Dilate the mask to expand it kernel_size = 3 # Size of the dilation kernel, adjust accordingly kernel = np.ones((kernel_size, kernel_size), np.uint8) dilation_iterations = 10 # Number of times dilation is applied, adjust as needed dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=dilation_iterations) # Subtract the original mask from the dilated mask # This will give a boundary that is only outside the original mask outside_boundary = dilated_mask - pred_uint8 # Convert back to torch tensor and normalize boundary = torch.tensor(outside_boundary).float() / 255 try: bg_p = output_p[0, 0].detach().cpu() neg_points, neg_confidences = self.get_most_conf_points(bg_p, boundary, l) except RuntimeError as e: # make each point (None, None) neg_points = None # append global negative points to the negative points if neg_points is not None and glob_neg_points is not None: neg_points = np.vstack([neg_points, glob_neg_points]) else: neg_points = glob_neg_points if neg_points is None else neg_points if self.debug and neg_points is not None: # draw an image with 2 subplots, one is the pred and the other is the boundary plt.figure() plt.subplot(121) plt.imshow(pred) plt.imshow(boundary, alpha=0.5) # plot the neg points plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red') plt.subplot(122) plt.imshow(pred) plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red') plt.savefig('debug/pred_and_boundary.png') plt.close() sam_neg_points.append(neg_points) else: # create a list of None same shape as points sam_neg_points = [None for _ in range(len(sam_input_points))] sam_input_labels = np.array([l+1 for l, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))]) sam_input_points = np.stack(sam_input_points) # should be of shape (num_connected_components, num_points_for_sam, 2) # if get_neg_points: sam_neg_input_points = np.stack(sam_neg_points) if sam_neg_points is not None else None if sam_neg_input_points is not None: sam_neg_input_points = sam_neg_points sam_neg_input_labels = np.array([0] * len(sam_neg_input_points) ) else: sam_neg_input_points = None sam_neg_input_labels = None return sam_input_points, sam_input_labels, sam_neg_input_points, sam_neg_input_labels def get_sam_input_mask(self, conn_components): sam_input_masks = [] sam_input_mask_lables = [] for i, cc_id in enumerate(np.unique(conn_components[1])): # get self.num_points_for_sam most confident points from pred if cc_id == 0: continue pred = torch.tensor(conn_components[1] == cc_id).float() sam_input_masks.append(pred) sam_input_mask_lables.append(cc_id) sam_input_masks = np.stack(sam_input_masks) sam_input_mask_lables = np.array(sam_input_mask_lables) return sam_input_masks, sam_input_mask_lables def predict_w_masks(self, sam_input_masks, qry_img, original_size): masks = [] scores = [] for in_mask in sam_input_masks: in_mask = cv2.resize(in_mask, (256, 256), interpolation=cv2.INTER_NEAREST) in_mask[in_mask == 1] = 10 in_mask[in_mask == 0] = -8 assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8 self.predictor.set_image(qry_img) mask, score, _ = self.predictor.predict( mask_input=in_mask[None, ...].astype(np.uint8), multimask_output=True) # get max index from score if self.debug: # plot each channel of mask fig, ax = plt.subplots(1, 4, figsize=(15, 5)) for i in range(mask.shape[0]): ax[i].imshow(qry_img) ax[i].imshow(mask[i], alpha=0.5) ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18) # ax[i].axis('off') ax[-1].imshow(cv2.resize(in_mask, original_size, interpolation=cv2.INTER_NEAREST)) fig.savefig(f'debug/sam_mask_from_mask_prompts.png') plt.close(fig) max_index = score.argmax() masks.append(mask[max_index]) scores.append(score[max_index]) return masks, scores def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False): masks, scores = [], [] self.predictor.set_image(qry_img) # if sam_input_points is None: # sam_input_points = [None for _ in range(len(bboxes))] for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points): assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8 points = point point_labels = np.array([1] * len(point)) if point is not None else None if self.use_neg_points: neg_points = [npoint for npoint in neg_point if None not in npoint] points = np.vstack([point, *neg_points]) point_labels = np.array([1] * len(point) + [0] * len(neg_points)) if self.debug: self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None,...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png") # TODO add plots for all points not just the first set of points mask, score, _ = self.predictor.predict( point_coords=points, point_labels=point_labels, # box=bbox_xyxy[None, :] if bbox_xyxy is not None else None, box = bbox_xyxy if bbox_xyxy is not None else None, # mask_input=sam_mask_input, return_logits=return_logits, multimask_output=False if self.use_cca else True ) # best_pred_idx = np.argmax(score) best_pred_idx = 0 masks.append(mask[best_pred_idx]) scores.append(score[best_pred_idx]) if self.debug: # pass self.plot_sam_preds(mask, score, qry_img[...,0], points.reshape(-1,2) if sam_input_points is not None else None, point_labels, input_box=bbox_xyxy if bbox_xyxy is not None else None) return masks, scores def forward(self, query_image, coarse_model_input, degrees_rotate=0): """ query_image: 3d tensor of shape (1, 3, H, W) images should be normalized with mean and std but not to [0, 1]? """ original_size = query_image.shape[-2] # rotate query_image by degrees_rotate start_time = time.time() rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate) # print(f"rotating query image took {time.time() - start_time} seconds") start_time = time.time() coarse_model_input.set_query_images(rotated_img) output_logits_rot = self.coarse_segmentation_model(coarse_model_input) # print(f"ALPNet took {time.time() - start_time} seconds") if degrees_rotate != 0: start_time = time.time() output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate) # print(f"reversing rotated output_logits took {time.time() - start_time} seconds") else: output_logits = output_logits_rot # check if softmax is needed output_p = output_logits.softmax(dim=1) # output_p = output_logits pred = output_logits.argmax(dim=1)[0] if self.debug: _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu()) plt.subplot(132) plt.imshow(query_image[0,0].detach().cpu()) plt.imshow(_pred, alpha=0.5) plt.subplot(131) # plot heatmap of prob of being fg plt.imshow(output_p[0, 1].detach().cpu()) # plot rotated query image and rotated pred output_p_rot = output_logits_rot.softmax(dim=1) _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu()) _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0] plt.subplot(133) plt.imshow(rotated_img[0, 0].detach().cpu()) plt.imshow(_pred_rot, alpha=0.5) plt.savefig('debug/coarse_pred.png') plt.close() if self.coarse_pred_only: output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits pred = output_logits.argmax(dim=1)[0] conf = get_confidence_from_logits(output_logits) if self.use_cca: _pred = np.array(pred.detach().cpu()) _pred, conf = cca(_pred, output_logits, return_conf=True) pred = torch.from_numpy(_pred) if self.training: return output_logits, [conf] # Ensure pred is a float tensor for consistent visualization return pred.float(), [conf] if query_image.shape[-2:] != self.image_size: query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear') output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear') # if need_softmax(output_logits): # output_logits = output_logits.softmax(dim=1) # output_p = output_logits output_p = output_logits.softmax(dim=1) pred = output_p.argmax(dim=1)[0] _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu()) start_time = time.time() if self.use_cca: conn_components = cca(_pred, output_logits, return_cc=True) conf=None else: conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True) if self.debug: plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf) # print(f"connected components took {time.time() - start_time} seconds") if _pred.max() == 0: return output_p.argmax(dim=1)[0], [0] # get bbox from pred if self.use_bbox: start_time = time.time() try: bboxes = self.get_bbox_per_cc(conn_components) except: bboxes = [None] * conn_components[0] else: bboxes = [None] * conn_components[0] # print(f"getting bboxes took {time.time() - start_time} seconds") start_time = time.time() if self.use_points: sam_input_points, sam_input_point_labels, sam_neg_input_points, sam_neg_input_labels = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1) else: sam_input_points = [None] * conn_components[0] sam_input_point_labels = [None] * conn_components[0] sam_neg_input_points = [None] * conn_components[0] sam_neg_input_labels = [None] * conn_components[0] # print(f"getting sam input points took {time.time() - start_time} seconds") if self.use_mask: sam_input_masks, sam_input_mask_labels = self.get_sam_input_mask(conn_components) else: sam_input_masks = None sam_input_mask_labels = None if self.debug and sam_input_points is not None: title = f'debug/most_conf_points.png' if self.use_cca: title = f'debug/most_conf_points_cca.png' # convert points to a list where each item is a list of 2 elements in xy format self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title) # TODO add plots for all points not just the first set of points # self.sam_trans = None if self.sam_trans is None: query_image = query_image.permute(1, 2, 0).detach().cpu().numpy() else: query_image = self.sam_trans.apply_image_torch(query_image[0]) query_image = self.sam_trans.preprocess(query_image) query_image = query_image.permute(1, 2, 0).detach().cpu().numpy() # mask = self.sam_trans.preprocess(mask) query_image = ((query_image - query_image.min()) / (query_image.max() - query_image.min()) * 255).astype(np.uint8) if self.use_mask: masks, scores = self.predict_w_masks(sam_input_masks, query_image, original_size) start_time = time.time() if self.use_points or self.use_bbox: masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, query_image, pred, return_logits=True if self.training else False) # print(f"predicting w points/bbox took {time.time() - start_time} seconds") pred = sum(masks) if not self.training: pred = pred > 0 pred = torch.tensor(pred).float().to(output_p.device) # pred = torch.tensor(masks[0]).float().cuda() # resize pred to the size of the input pred = F.interpolate(pred.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0] return pred, scores def show_mask(mask, ax, random_color=False): if random_color: color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) else: color = np.array([30/255, 144/255, 255/255, 0.6]) h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) def show_points(coords, labels, ax, marker_size=375): pos_points = coords[labels==1] neg_points = coords[labels==0] ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) def show_box(box, ax): x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) def need_softmax(tensor, dim=1): return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0))