Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import cv2 | |
import torch | |
import numpy as np | |
import torchvision | |
from .utils import convert_to_numpy | |
class GDINOAnnotator:
    """Open-vocabulary object detector wrapping GroundingDINO's inference ``Model``.

    Detections can optionally be filtered with class-agnostic non-maximum
    suppression (NMS) via ``torchvision.ops.nms``.

    Config keys read from ``cfg``:
        CONFIG_PATH (required): GroundingDINO model config path.
        PRETRAINED_MODEL (required): GroundingDINO checkpoint path.
        TOKENIZER_PATH (optional): stored on the instance but not used yet.
        BOX_THRESHOLD (default 0.25): box threshold passed to the predict calls.
        TEXT_THRESHOLD (default 0.2): text threshold passed to the predict calls.
        IOU_THRESHOLD (default 0.5): IoU threshold used by NMS.
        USE_NMS (default True): whether NMS is applied to the detections.
    """

    def __init__(self, cfg, device=None):
        """Build the GroundingDINO model.

        Args:
            cfg: Mapping with the config keys listed in the class docstring.
            device: Device for the model. Defaults to CUDA when available,
                otherwise CPU.

        Raises:
            ImportError: If the ``groundingdino`` package is not installed.
            KeyError: If ``CONFIG_PATH`` or ``PRETRAINED_MODEL`` is missing.
        """
        try:
            from groundingdino.util.inference import Model
        except ImportError as err:
            # Fail here with an actionable message rather than the NameError
            # on `Model` that would otherwise follow a bare warning.
            raise ImportError(
                "please pip install groundingdino package, or you can refer to "
                "models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl"
            ) from err

        grounding_dino_config_path = cfg['CONFIG_PATH']
        grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL']
        # TODO: the tokenizer path is not consumed by the model yet.
        self.tokenizer_path = cfg.get('TOKENIZER_PATH')

        self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25)
        self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2)
        self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5)
        self.use_nms = cfg.get('USE_NMS', True)

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model = Model(model_config_path=grounding_dino_config_path,
                           model_checkpoint_path=grounding_dino_checkpoint_path,
                           device=self.device)

    def forward(self, image, classes=None, caption=None):
        """Detect objects in ``image`` from a class list or a free-text caption.

        ``classes`` takes precedence when both prompts are given.

        Args:
            image: Input image, converted with ``convert_to_numpy``. Its last
                axis is reversed to produce BGR, so the input is presumably
                RGB (confirm with callers).
            classes: A class name or a sequence of class names to detect.
            caption: Free-text prompt, used only when ``classes`` is None.

        Returns:
            dict with keys:
                ``boxes``: list of xyxy boxes (or None).
                ``confidences``: list of scores (or None).
                ``class_ids``: list of class indices (or None).
                ``class_names``: in classes mode, the name for each class id
                    (None where an id is None); in caption mode, the phrases
                    returned by ``predict_with_caption``.

        Raises:
            NotImplementedError: If neither ``classes`` nor ``caption`` is given.
        """
        # Reverse the channel axis to BGR. Copy into a C-contiguous buffer,
        # because negative-stride views are rejected by some array consumers
        # (e.g. torch.from_numpy).
        image_bgr = np.ascontiguousarray(convert_to_numpy(image)[..., ::-1])

        phrases = None
        if classes is not None:
            classes = [classes] if isinstance(classes, str) else classes
            detections = self.model.predict_with_classes(
                image=image_bgr,
                classes=classes,
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold
            )
        elif caption is not None:
            detections, phrases = self.model.predict_with_caption(
                image=image_bgr,
                caption=caption,
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold
            )
        else:
            raise NotImplementedError()

        if self.use_nms:
            keep = torchvision.ops.nms(
                torch.from_numpy(detections.xyxy),
                torch.from_numpy(detections.confidence),
                self.iou_threshold
            ).numpy().tolist()
            detections.xyxy = detections.xyxy[keep]
            detections.confidence = detections.confidence[keep]
            if detections.class_id is not None:
                detections.class_id = detections.class_id[keep]
            # Caption-mode phrases are parallel to the detections; filter them
            # with the same indices so names stay aligned with the kept boxes.
            if phrases is not None:
                phrases = [phrases[i] for i in keep]

        boxes = detections.xyxy
        confidences = detections.confidence
        class_ids = detections.class_id

        if classes is not None:
            # class_id entries may be None when a detected phrase matched no
            # requested class; map those to None instead of raising TypeError.
            class_names = (
                [classes[_id] if _id is not None else None for _id in class_ids]
                if class_ids is not None else None
            )
        else:
            class_names = phrases

        return {
            "boxes": boxes.tolist() if boxes is not None else None,
            "confidences": confidences.tolist() if confidences is not None else None,
            "class_ids": class_ids.tolist() if class_ids is not None else None,
            "class_names": class_names,
        }
class GDINORAMAnnotator:
    """Two-stage annotator: RAM proposes tags, GroundingDINO localizes them.

    ``cfg['RAM']`` configures the ``RAMAnnotator`` stage and ``cfg['GDINO']``
    configures the ``GDINOAnnotator`` stage; both receive ``device``.
    """

    def __init__(self, cfg, device=None):
        # Imported inside the constructor rather than at module top
        # (presumably to defer model-package imports / avoid an import cycle).
        from .ram import RAMAnnotator
        from .gdino import GDINOAnnotator

        ram_cfg = cfg['RAM']
        self.ram_model = RAMAnnotator(ram_cfg, device=device)
        gdino_cfg = cfg['GDINO']
        self.gdino_model = GDINOAnnotator(gdino_cfg, device=device)

    def forward(self, image):
        """Run RAM tagging on ``image``, then detect the predicted tags.

        If the RAM stage returns a dict, its ``'tag_e'`` entry is used as the
        class list; any other return value is passed through unchanged.
        Returns whatever the GroundingDINO stage's ``forward`` returns.
        """
        tags = self.ram_model.forward(image)
        if isinstance(tags, dict):
            tags = tags['tag_e']
        return self.gdino_model.forward(image, classes=tags)