Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np | |
import torch | |
from scipy import ndimage | |
from .utils import convert_to_numpy | |
class SAMImageAnnotator:
    """Prompt a Segment Anything (SAM) predictor with a box, points or a mask.

    The prompt is built according to ``task_type``:

    * ``'input_box'``  - ``input_box`` is forwarded as the ``box`` prompt.
    * ``'mask_point'`` - the centroid of every connected scribble region
      (pixels ``>= 255``) in ``mask`` becomes a positive point prompt.
    * ``'mask_box'``   - the box spanned by those centroids becomes the
      ``box`` prompt.
    * ``'mask'``       - ``mask`` itself is forwarded as ``mask_input``.
    """

    def __init__(self, cfg, device=None):
        """Load the SAM checkpoint and wrap it in a ``SamPredictor``.

        Args:
            cfg: Mapping-like config. ``PRETRAINED_MODEL`` (checkpoint path)
                is required; ``MODEL_NAME`` (default ``'vit_b'``),
                ``TASK_TYPE`` (default ``'input_box'``) and ``RETURN_MASK``
                (default ``False``) are optional.
            device: Device the model is moved to. When ``None``, CUDA is
                used if available, otherwise CPU.

        Raises:
            ImportError: If the ``segment_anything`` package is missing.
        """
        try:
            from segment_anything import sam_model_registry, SamPredictor
            from segment_anything.utils.transforms import ResizeLongestSide
        except ImportError as err:
            # Fail here with the install hint instead of warning and then
            # hitting a NameError on the first use of the missing names.
            raise ImportError("please pip install sam package, or you can refer to models/VACE-Annotators/sam/segment_anything-1.0-py3-none-any.whl") from err

        self.task_type = cfg.get('TASK_TYPE', 'input_box')
        self.return_mask = cfg.get('RETURN_MASK', False)
        self.transform = ResizeLongestSide(1024)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        model_name = cfg.get('MODEL_NAME', 'vit_b')
        seg_model = sam_model_registry[model_name](checkpoint=cfg['PRETRAINED_MODEL'])
        seg_model = seg_model.eval().to(self.device)
        self.predictor = SamPredictor(seg_model)

    @staticmethod
    def _scribble_centers(mask):
        """Return the (x, y) centroids of the connected scribble regions.

        Args:
            mask: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` numpy array; only the
                first channel of a 3-D mask is used. Pixels ``>= 255`` are
                treated as scribble.

        Returns:
            Float array of shape ``(num_regions, 2)`` in ``(x, y)`` order.

        Raises:
            ValueError: If the mask contains no pixel ``>= 255``.
        """
        # Transpose so that axis 0 is x (width) and axis 1 is y (height);
        # center_of_mass then yields coordinates directly in (x, y) order.
        if len(mask.shape) == 3:
            scribble = mask.transpose(2, 1, 0)[0]
        else:
            scribble = mask.transpose(1, 0)  # (H, W) -> (W, H)
        labeled_array, num_features = ndimage.label(scribble >= 255)
        if num_features == 0:
            raise ValueError("mask contains no scribble region (no pixel >= 255)")
        centers = ndimage.center_of_mass(scribble, labeled_array,
                                         range(1, num_features + 1))
        return np.array(centers)

    def forward(self,
                image,
                input_box=None,
                mask=None,
                task_type=None,
                return_mask=None):
        """Run SAM on ``image`` with a prompt derived from the inputs.

        Args:
            image: Image handed to ``SamPredictor.set_image``.
            input_box: Box prompt for ``'input_box'``; a list is converted
                to a numpy array.
            mask: Scribble / mask input for the mask-driven task types; it
                is converted with ``convert_to_numpy`` first.
            task_type: Overrides the configured task type when not ``None``.
            return_mask: Overrides the configured flag when not ``None``.

        Returns:
            The highest-scoring mask when ``return_mask`` is truthy,
            otherwise a dict with ``"masks"``, ``"scores"`` and ``"logits"``
            sorted by descending score.

        Raises:
            ValueError: If a mask-driven task type is used without a mask,
                or the mask holds no scribble region.
            NotImplementedError: If ``task_type`` is not supported.
        """
        task_type = self.task_type if task_type is None else task_type
        return_mask = self.return_mask if return_mask is None else return_mask
        mask = convert_to_numpy(mask) if mask is not None else None

        if task_type in ('mask_point', 'mask_box', 'mask') and mask is None:
            raise ValueError(f"task_type '{task_type}' requires a mask")

        if task_type == 'mask_point':
            point_coords = self._scribble_centers(mask)
            sample = {
                'point_coords': point_coords,
                # Every scribble centroid is a positive (foreground) click.
                'point_labels': np.array([1] * len(point_coords)),
            }
        elif task_type == 'mask_box':
            centers = self._scribble_centers(mask)
            # (x1, y1, x2, y2)
            # NOTE(review): the box spans the region centroids, not the
            # scribble extents, so a single region yields a zero-area box.
            # Kept as-is to preserve behavior; confirm this is intended.
            x_min, y_min = centers.min(axis=0)
            x_max, y_max = centers.max(axis=0)
            sample = {'box': np.array([x_min, y_min, x_max, y_max])}
        elif task_type == 'input_box':
            if isinstance(input_box, list):
                input_box = np.array(input_box)
            sample = {'box': input_box}
        elif task_type == 'mask':
            # Add a leading axis; requires a 2-D mask.
            sample = {'mask_input': mask[None, :, :]}
        else:
            raise NotImplementedError

        self.predictor.set_image(image)
        masks, scores, logits = self.predictor.predict(
            multimask_output=False,
            **sample
        )
        # Order all outputs from best to worst score.
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]
        logits = logits[sorted_ind]

        if return_mask:
            return masks[0]
        return {
            "masks": masks,
            "scores": scores,
            "logits": logits
        }