import torch
import torch.nn as nn
import numpy as np
from models.segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from models.segment_anything.utils.transforms import ResizeLongestSide
import cv2
def get_iou(mask, label):
    """Compute the IoU between a binary mask and a binary label (both np.ndarray)."""
    tp = (mask * label).sum()
    fp = (mask * (1 - label)).sum()
    fn = ((1 - mask) * label).sum()
    iou = tp / (tp + fp + fn)
    return iou
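# Toy example (hypothetical 2x2 arrays): mask = [[1, 1], [0, 0]], label = [[1, 0], [0, 0]]
# gives tp = 1, fp = 1, fn = 0, so get_iou returns 1 / (1 + 1 + 0) = 0.5.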
class SamWrapper(nn.Module):
    def __init__(self, sam_args):
        """
        sam_args: dict with the following keys
        {
            "model_type": "vit_h",
            "sam_checkpoint": "path to the pretrained checkpoint, e.g. pretrained_model/sam_vit_h.pth"
        }
        """
        super().__init__()
        self.sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint'])
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)
        # Transform that resizes inputs so their longest side matches the SAM image
        # encoder's expected size (the automatic mask generator also resizes internally).
        self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)
    def forward(self, image, image_labels):
        """
        Generate masks for a single image and return the mask with the largest
        IoU against the image label.

        Args:
            image (np.ndarray): The image to generate masks for, in HWC uint8 format.
            image_labels (np.ndarray): Binary ground-truth mask of shape HW, with values in {0, 1}.
        """
        # SamAutomaticMaskGenerator resizes the image internally, so the
        # original-resolution image is passed here; the returned masks then match
        # the label's spatial size. Each entry of `masks` is a dict with keys such
        # as 'segmentation', 'area', 'bbox' and 'predicted_iou'.
        masks = self.mask_generator.generate(image)
        best_index, best_iou = None, 0
        for i, mask in enumerate(masks):
            segmentation = mask['segmentation']
            iou = get_iou(segmentation.astype(np.uint8), image_labels)
            if best_index is None or iou > best_iou:
                best_index = i
                best_iou = iou
        return masks[best_index]['segmentation']
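    # Example shapes for forward (hypothetical): for a 768x1024 RGB image, `image` is
    # (768, 1024, 3) uint8, `image_labels` is (768, 1024) with values in {0, 1},
    # and the returned segmentation is a (768, 1024) boolean array.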
    def to(self, device):
        # SamAutomaticMaskGenerator and SamPredictor are not nn.Modules; moving the
        # underlying SAM model is enough, since the generator's predictor holds a
        # reference to the same model.
        self.sam.to(device)
        return self
if __name__ == "__main__":
    sam_args = {
        "model_type": "vit_h",
        "sam_checkpoint": "pretrained_model/sam_vit_h.pth"
    }
    sam_wrapper = SamWrapper(sam_args).cuda()
    # cv2 loads images as BGR; SAM expects RGB uint8.
    image = cv2.imread("./Kheops-Pyramid.jpg")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.uint8)
    # Dummy binary label matching the image's spatial size, for demonstration only.
    image_labels = (np.random.rand(image.shape[0], image.shape[1]) > 0.5).astype(np.uint8)
    best_mask = sam_wrapper(image, image_labels)
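    # Optional follow-up sketch (the output path is an assumption): save the
    # selected mask for visual inspection.
    cv2.imwrite("best_mask.png", best_mask.astype(np.uint8) * 255)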