# UVIS / models/segmentation/segmenter.py
# Last commit by DurgaDeepak: "Update models/segmentation/segmenter.py"
# (commit 52aa226, verified)
import logging
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from transformers import (
SegformerForSemanticSegmentation,
SegformerFeatureExtractor,
AutoProcessor,
CLIPSegForImageSegmentation,
)
# Module-level logger used by the Segmenter class below.
logger = logging.getLogger(name=__name__)
# NOTE(review): this configures the *root* logger at import time, which affects
# the whole process; basicConfig is a no-op if the root logger already has
# handlers.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
class Segmenter:
    """
    Generalized Semantic Segmentation Wrapper for SegFormer, DeepLabV3, and CLIPSeg.

    The backend is selected from ``model_key`` and loaded lazily on the first
    call to :meth:`predict`.
    """

    # Per-channel ImageNet statistics expected by torchvision's pretrained
    # DeepLabV3 weights; without them the model sees out-of-distribution input.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_key="nvidia/segformer-b0-finetuned-ade-512-512", device="cpu"):
        """
        Args:
            model_key (str): HF model identifier (or local checkpoint path) or
                'deeplabv3_resnet50'. Matching against backends is case-insensitive.
            device (str): 'cpu' or 'cuda'.
        """
        logger.info(f"Initializing Segmenter for model '{model_key}' on {device}")
        # Keep the caller's exact spelling for from_pretrained(): lower-casing
        # would break local checkpoint paths on case-sensitive filesystems.
        self._model_id = model_key
        # Lower-cased copy used only to decide which backend to load.
        self.model_key = model_key.lower()
        self.device = device
        self.model = None
        self.processor = None  # for transformers-based models

    def _load_model(self):
        """
        Lazy-load the model & processor based on model_key.

        Does nothing if a model is already loaded.

        Raises:
            ValueError: If ``model_key`` matches none of the supported backends.
        """
        if self.model is not None:
            return
        # Fallback covers instances built without __init__ (e.g. unpickled
        # objects from an older version of this class).
        model_id = getattr(self, "_model_id", self.model_key)

        if "segformer" in self.model_key:
            self.model = SegformerForSemanticSegmentation.from_pretrained(model_id).to(self.device).eval()
            self.processor = SegformerFeatureExtractor.from_pretrained(model_id)
        elif self.model_key == "deeplabv3_resnet50":
            # NOTE(review): `pretrained=True` is deprecated in newer torchvision
            # in favour of `weights=...`; kept for compatibility with older versions.
            self.model = deeplabv3_resnet50(pretrained=True).to(self.device).eval()
            self.processor = None
        elif "clipseg" in self.model_key:
            self.model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(self.device).eval()
            self.processor = AutoProcessor.from_pretrained(model_id)
        else:
            raise ValueError(f"Unsupported segmentation model key: '{self.model_key}'")
        logger.info(f"Loaded segmentation model '{self.model_key}'")

    def predict(self, image: Image.Image, prompt: str = "", **kwargs) -> np.ndarray:
        """
        Perform segmentation.

        Args:
            image (PIL.Image.Image): Input. Non-RGB PIL images are converted to RGB.
            prompt (str): Only used for CLIPSeg.
            **kwargs: Optional ``threshold`` (float, CLIPSeg only). When given,
                ``sigmoid(logits) > threshold`` is returned as a ``uint8`` 0/1
                mask; otherwise the raw CLIPSeg logits are returned.

        Returns:
            np.ndarray: 2-D segmentation mask.
                - SegFormer: class indices at the resolution of the model's
                  logits, which may be smaller than the input image.
                - DeepLabV3: class indices.
                - CLIPSeg: float logits, or a binary mask if ``threshold`` is set.

        Raises:
            ValueError: If ``model_key`` matches no supported backend.
        """
        self._load_model()

        # RGBA / grayscale images would give the models the wrong channel count.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        # SegFormer path
        if "segformer" in self.model_key:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Take batch element 0 explicitly; a bare squeeze() would also drop
            # a spatial dimension of size 1.
            return outputs.logits.argmax(dim=1)[0].cpu().numpy()

        # DeepLabV3 path
        if self.model_key == "deeplabv3_resnet50":
            tf = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
            ])
            inp = tf(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                out = self.model(inp)["out"]
            return out.argmax(1)[0].cpu().numpy()

        # CLIPSeg path
        if "clipseg" in self.model_key:
            # CLIPSeg expects both text and image, each as a list (batch of 1).
            inputs = self.processor(
                text=[prompt],
                images=[image],
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Drop a leading batch dim of size 1 if present.
            logits = outputs.logits.squeeze(0)
            threshold = kwargs.get("threshold")
            if threshold is not None:
                # Logits are unbounded; threshold probabilities, not raw scores.
                return (torch.sigmoid(logits) > threshold).to(torch.uint8).cpu().numpy()
            return logits.cpu().numpy()

        raise RuntimeError("Unreachable segmentation branch")

    def draw(self, image: Image.Image, mask: np.ndarray, alpha=0.5) -> Image.Image:
        """
        Overlay the segmentation mask on the input image as a grayscale layer.

        Args:
            image (PIL.Image.Image): Original. Converted to RGB for blending.
            mask (np.ndarray): 2-D segmentation mask (labels, logits, or bool).
            alpha (float): Blend strength (0 = image only, 1 = mask only).

        Returns:
            PIL.Image.Image: Blended RGB output.
        """
        logger.info("Drawing segmentation overlay")
        values = np.asarray(mask)
        # Boolean arrays do not support subtraction in NumPy.
        if values.dtype == np.bool_:
            values = values.astype(np.uint8)

        # Normalize mask to 0–255. np.ptp (not the ndarray.ptp method, which was
        # removed in NumPy 2.0) gives the value range.
        span = np.ptp(values)
        if span > 0:
            gray = ((values - values.min()) / span * 255).astype(np.uint8)
        else:
            # Constant mask: avoid 0/0 -> NaN and render it as black.
            gray = np.zeros(values.shape, dtype=np.uint8)

        mask_img = Image.fromarray(gray).convert("L").resize(image.size)
        color_mask = Image.merge("RGB", (mask_img, mask_img, mask_img))
        # Image.blend requires both images to share the same mode.
        base = image if image.mode == "RGB" else image.convert("RGB")
        return Image.blend(base, color_mask, alpha)