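# Spaces: Sleeping
# (Appears to be a Hugging Face Spaces status banner captured when the page
# was extracted; it is not part of the module.)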
import torch | |
import torch.nn as nn | |
import numpy as np | |
from models.segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator | |
from models.segment_anything.utils.transforms import ResizeLongestSide | |
import cv2 | |
def get_iou(mask, label):
    """Return the intersection-over-union (IoU) of two binary masks.

    Args:
        mask: Predicted binary mask holding 0/1 values. Any object that
            supports element-wise ``*``, ``-`` and ``.sum()`` works
            (e.g. a NumPy array).
        label: Ground-truth binary mask, broadcast-compatible with ``mask``.

    Returns:
        ``tp / (tp + fp + fn)``. When both inputs are entirely empty the
        union is zero, and ``0.0`` is returned instead of dividing by zero.
    """
    tp = (mask * label).sum()
    fp = (mask * (1 - label)).sum()
    fn = ((1 - mask) * label).sum()
    union = tp + fp + fn
    # An empty union would otherwise produce NaN, and NaN compares False
    # against everything, which silently breaks "keep the best IoU" loops.
    if union == 0:
        return 0.0
    return tp / union
class SamWrapper(nn.Module):
    """Run SAM's automatic mask generator and keep the mask closest to a label."""

    def __init__(self, sam_args):
        """Build the SAM model, its automatic mask generator and resize transform.

        Args:
            sam_args: dict that should include the following keys::

                {
                    "model_type": "vit_h",
                    "sam_checkpoint": "pretrained_model/sam_vit_h.pth",
                }

            ``model_type`` selects the builder in ``sam_model_registry`` and
            ``sam_checkpoint`` is passed to it as ``checkpoint``.
        """
        super().__init__()
        self.sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint'])
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)
        self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)

    def forward(self, image, image_labels):
        """Generate masks for one image and return the one that best fits the label.

        Args:
            image (np.ndarray): The image to generate masks for, in HWC uint8
                format.
            image_labels (np.ndarray): Binary label mask scored against every
                generated mask with ``get_iou``; it must be
                broadcast-compatible with the generated segmentations.

        Returns:
            The ``'segmentation'`` entry of the generated mask with the
            highest IoU against ``image_labels``. If the generator produces
            no masks, an all-False mask with the resized image's height and
            width is returned instead.
        """
        # NOTE(review): the image is resized here before being handed to the
        # generator. Upstream SamAutomaticMaskGenerator resizes internally as
        # well, so the masks presumably come back at this resized resolution
        # rather than the caller's original one -- confirm that image_labels
        # is supplied at the matching size.
        image = self.transform.apply_image(image)
        masks = self.mask_generator.generate(image)

        # Without this guard best_index stays None and masks[None] raises an
        # opaque TypeError.
        if len(masks) == 0:
            return np.zeros(image.shape[:2], dtype=bool)

        best_index, best_iou = None, 0
        for i, mask in enumerate(masks):
            segmentation = mask['segmentation']
            iou = get_iou(segmentation.astype(np.uint8), image_labels)
            # The "is None" test guarantees the first mask is selected even
            # when its IoU is 0, so a valid index always exists afterwards.
            if best_index is None or iou > best_iou:
                best_index = i
                best_iou = iou
        return masks[best_index]['segmentation']

    def to(self, device):
        """Move the SAM model (and any movable helpers) to ``device``.

        Args:
            device: Target device, forwarded to ``.to``.

        Returns:
            ``self``, matching the ``nn.Module.to`` convention so that
            ``model = model.to(device)`` keeps working.
        """
        self.sam.to(device)
        # The generator and its predictor are only moved when they actually
        # expose ``to``; calling it unconditionally raises AttributeError on
        # helper objects that do not define it.
        generator = self.mask_generator
        for helper in (generator, getattr(generator, "predictor", None)):
            move = getattr(helper, "to", None)
            if callable(move):
                move(device)
        return self
if __name__ == "__main__":
    # Smoke test: segment a sample image and pick the mask that best matches
    # a synthetic binary label.
    sam_args = {
        "model_type": "vit_h",
        "sam_checkpoint": "pretrained_model/sam_vit_h.pth",
    }

    # Read the image before loading the checkpoint so a bad path fails fast.
    image_path = "./Kheops-Pyramid.jpg"
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    # OpenCV loads images as BGR; SAM presumably expects RGB, as in the
    # upstream segment-anything examples -- confirm against the vendored copy.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    sam_wrapper = SamWrapper(sam_args)
    # Use the GPU when one is present instead of hard-requiring CUDA.
    if torch.cuda.is_available():
        sam_wrapper = sam_wrapper.cuda()

    # forward() resizes the image before generating masks, so the label is
    # built at that resized height/width to stay broadcast-compatible with
    # the masks (assumes masks match the resized image size -- TODO confirm).
    label_h, label_w = sam_wrapper.transform.apply_image(image).shape[:2]
    image_labels = np.zeros((label_h, label_w), dtype=np.uint8)
    # Synthetic binary label: a centered rectangle covering the middle half.
    image_labels[label_h // 4:3 * label_h // 4, label_w // 4:3 * label_w // 4] = 1

    best_mask = sam_wrapper(image, image_labels)
    print(best_mask.shape)