# CV-Agent / tool_utils/clip_segmentation.py
# Last commit 9b2e4f0 by Samarth991: "added segmentation mask overlay"
import cv2
from matplotlib import pyplot as plt
import torch
import numpy as np
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from segmentation_mask_overlay import overlay_masks
from typing import List
import logging
class CLIPSEG:
    """Text-prompted image segmentation with CLIPSeg, rendered as a colored overlay.

    Wraps a Hugging Face ``CLIPSegProcessor`` / ``CLIPSegForImageSegmentation``
    pair. ``get_segmentation_mask`` predicts one binary mask per text prompt,
    overlays all masks on the image and writes the result to an image file.
    """

    def __init__(self, model_name="CIDAS/clipseg-rd64-refined", threshould=0.60, device="cpu"):
        """Load the CLIPSeg processor and model.

        Args:
            model_name: model id or local path passed to ``from_pretrained``.
            threshould: probability cut-off; a pixel belongs to a mask when its
                sigmoid score is strictly greater than this value. (The
                misspelling is kept for backward compatibility.)
            device: torch device the model and its inputs are placed on.
        """
        self.clip_processor = CLIPSegProcessor.from_pretrained(model_name)
        self.clip_model = CLIPSegForImageSegmentation.from_pretrained(model_name)
        self.threshould = threshould
        self.device = device
        self.clip_model.to(device)

    @staticmethod
    def create_rgb_mask(mask, color=None):
        """Turn a single-channel mask into a 3-channel image with a colored foreground.

        Args:
            mask: 2-D array; pixels equal to 255 are painted with ``color``,
                every other pixel keeps its gray value in all three channels.
            color: 3-element color for the foreground. When ``None`` a random
                color is drawn.

        Returns:
            A new ``uint8`` array of shape ``mask.shape + (3,)``; ``mask`` itself
            is not modified.
        """
        if color is None:
            color = tuple(np.random.choice(range(0, 256), size=3))
        # np.stack (rather than cv2.merge) also accepts int64 masks such as the
        # ones np.where(..., 255, 0) produces.
        gray_3_channel = np.stack((mask, mask, mask), axis=-1)
        gray_3_channel[mask == 255] = color
        return gray_3_channel.astype(np.uint8)

    @staticmethod
    def _logits_to_masks(logits, num_prompts, threshold):
        """Convert CLIPSeg logits into a boolean array of shape (num_prompts, H, W)."""
        # The CLIPSeg decoder squeezes its output, so a single prompt can come
        # back as (H, W) instead of (1, H, W); reshaping restores the prompt axis
        # in both cases.
        logits = logits.reshape(num_prompts, *logits.shape[-2:])
        return (torch.sigmoid(logits) > threshold).cpu().numpy()

    @staticmethod
    def _save_overlay(overlay, output_path):
        """Write ``overlay`` to ``output_path`` and return a status message."""
        # NOTE(review): the overlay is built from an RGB image while cv2.imwrite
        # expects BGR, so red and blue are probably swapped in the saved file.
        # Confirm the return format of overlay_masks before adding a conversion.
        try:
            written = cv2.imwrite(output_path, overlay)
        except Exception as e:  # best-effort save: report instead of crashing the caller
            logging.error("Error while saving the final mask : %s", e)
            return "unable to create a mask image "
        # imwrite reports some failures (e.g. an unwritable directory) by
        # returning False rather than raising.
        if not written:
            logging.error("Error while saving the final mask : cv2.imwrite failed for %s", output_path)
            return "unable to create a mask image "
        return f"Segmentation image created : {output_path}"

    def get_segmentation_mask(self, image_path: str, object_prompts: List, output_path: str = "final_mask.png"):
        """Segment each prompted object in an image and save a colored overlay.

        Args:
            image_path: path of the image to segment (read with OpenCV).
            object_prompts: text prompts; one mask is predicted per prompt.
            output_path: file the overlay image is written to.

        Returns:
            ``"Segmentation image created : <output_path>"`` on success, or
            ``"unable to create a mask image "`` if the overlay could not be saved.

        Raises:
            ValueError: if ``object_prompts`` is empty.
            FileNotFoundError: if ``image_path`` cannot be read as an image.
        """
        if not object_prompts:
            raise ValueError("object_prompts must contain at least one prompt")
        bgr_image = cv2.imread(image_path)
        if bgr_image is None:  # imread signals failure by returning None, not by raising
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        logging.info("objects found out from the image :%s", object_prompts)

        # One (prompt, image) pair per prompt so the model returns one mask each.
        inputs = self.clip_processor(
            text=object_prompts,
            images=[image] * len(object_prompts),
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():  # inference only; gradients are not needed
            outputs = self.clip_model(**inputs)
        masks = self._logits_to_masks(outputs.logits, len(object_prompts), self.threshould)

        # Masks are at the model's output resolution; resize the image to match
        # (cv2.resize takes (width, height)).
        mask_height, mask_width = masks.shape[1:]
        resize_image = cv2.resize(image, (mask_width, mask_height))
        mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(object_prompts)]
        # tab20 has a fixed number of colors; cycle through them so labels past
        # the last one do not all collapse onto the same color. Drop alpha.
        cmap = plt.cm.tab20(np.arange(len(mask_labels)) % plt.cm.tab20.N)[..., :-1]
        final_mask = overlay_masks(
            resize_image,
            np.stack(masks, axis=-1),  # (N, H, W) -> (H, W, N)
            labels=mask_labels,
            colors=cmap,
            alpha=0.5,
            beta=0.7,
        )
        return self._save_overlay(final_mask, output_path)