import logging

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

logger = logging.getLogger(__name__)


class Segmenter:
    """
    Generalized semantic segmentation wrapper for SegFormer and DeepLabV3.
    """

    def __init__(self, model_key="nvidia/segformer-b0-finetuned-ade-512-512", device="cpu"):
        """
        Initialize the segmentation model.

        Args:
            model_key (str): Model identifier: a Hugging Face SegFormer model id
                (any id containing "segformer") or "deeplabv3_resnet50".
            device (str): Inference device ("cpu" or "cuda").
        """
        logger.info("Initializing segmenter with model: %s", model_key)
        self.device = device
        self.model_key = model_key
        self.model, self.processor = self._load_model()

    def _load_model(self):
        """
        Load the segmentation model and, for SegFormer, its image processor.

        Returns:
            Tuple[torch.nn.Module, Optional[SegformerImageProcessor]]: The model in
            eval mode on ``self.device`` and its processor (None for DeepLabV3).

        Raises:
            ValueError: If ``model_key`` is not a supported model.
        """
        if "segformer" in self.model_key:
            # SegformerImageProcessor replaces the deprecated SegformerFeatureExtractor.
            model = SegformerForSemanticSegmentation.from_pretrained(self.model_key)
            model = model.to(self.device).eval()
            processor = SegformerImageProcessor.from_pretrained(self.model_key)
            return model, processor
        elif self.model_key == "deeplabv3_resnet50":
            # `pretrained=True` is deprecated since torchvision 0.13; "DEFAULT"
            # selects the best available pretrained weights.
            model = deeplabv3_resnet50(weights="DEFAULT").to(self.device).eval()
            return model, None
        else:
            raise ValueError(f"Unsupported model key: {self.model_key}")

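    def predict(self, image: Image.Image, **kwargs):
        """
        Perform semantic segmentation on the input image.

        Args:
            image (PIL.Image.Image): Input image; converted to RGB if needed.
            **kwargs: Currently unused.

        Returns:
            np.ndarray: HxW array of per-pixel class indices at the input
            image's resolution.
        """
        logger.info("Running segmentation")
        image = image.convert("RGB")

        if "segformer" in self.model_key:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # SegFormer logits are at 1/4 of the processed input resolution;
            # upsample them to the original size before taking the argmax.
            mask = self.processor.post_process_semantic_segmentation(
                outputs, target_sizes=[image.size[::-1]]  # PIL size is (W, H)
            )[0]
            return mask.cpu().numpy()

        elif self.model_key == "deeplabv3_resnet50":
            # Pretrained torchvision segmentation models expect inputs scaled to
            # [0, 1] and normalized with the ImageNet channel mean/std.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            inputs = transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(inputs)["out"]
            return outputs.argmax(dim=1).squeeze(0).cpu().numpy()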

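    def draw(self, image: Image.Image, mask: np.ndarray, alpha=0.5):
        """
        Overlay the segmentation mask on the input image as a grayscale map.

        Args:
            image (PIL.Image.Image): Original image.
            mask (np.ndarray): HxW array of class indices, e.g. from ``predict``.
            alpha (float): Blend strength in [0, 1]; 0 returns the image, 1 the mask.

        Returns:
            PIL.Image.Image: RGB image with the mask blended on top.
        """
        logger.info("Drawing segmentation overlay")
        image = image.convert("RGB")  # Image.blend requires matching modes
        # Scale class ids to 0-255 for display; guard against an all-zero mask.
        scale = 255.0 / max(int(mask.max()), 1)
        mask_img = Image.fromarray((mask * scale).astype(np.uint8))  # 2-D uint8 -> mode "L"
        # Nearest-neighbour resampling keeps class boundaries sharp.
        mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
        mask_rgb = Image.merge("RGB", (mask_img, mask_img, mask_img))
        return Image.blend(image, mask_rgb, alpha)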
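

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original Segmenter API: the helper names
# below (upscale_labels_nearest, make_palette, colorize_mask) are hypothetical.
# draw() renders classes as gray levels normalized by mask.max(), so the shade
# of a given class shifts with whichever other classes appear in the image.
# A fixed class-to-color palette keeps colors stable across images. The palette
# here uses the same bit-interleaving scheme as the widely used PASCAL VOC
# colormap; any fixed table would do. Label maps should be resized with
# nearest-neighbour sampling so no in-between (meaningless) class ids appear.
# ---------------------------------------------------------------------------
import numpy as np


def upscale_labels_nearest(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of an HxW label map to out_h x out_w."""
    in_h, in_w = mask.shape
    rows = (np.arange(out_h) * in_h) // out_h  # source row for each output row
    cols = (np.arange(out_w) * in_w) // out_w  # source column for each output column
    return mask[rows[:, None], cols[None, :]]


def make_palette(num_classes: int = 256) -> np.ndarray:
    """Build a num_classes x 3 uint8 color table; class 0 maps to black."""
    palette = np.zeros((num_classes, 3), dtype=np.uint8)
    for label in range(num_classes):
        c, r, g, b = label, 0, 0, 0
        for bit in range(8):
            # Deal the label's bits round-robin to R, G, B, filling each
            # channel from its most significant bit downward.
            r |= ((c >> 0) & 1) << (7 - bit)
            g |= ((c >> 1) & 1) << (7 - bit)
            b |= ((c >> 2) & 1) << (7 - bit)
            c >>= 3
        palette[label] = (r, g, b)
    return palette


def colorize_mask(mask: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map an HxW integer label map to an HxWx3 uint8 RGB array."""
    if mask.min() < 0 or mask.max() >= len(palette):
        raise ValueError("mask contains class ids outside the palette")
    return palette[mask]


# Possible usage (assumes `image` is a PIL image and `mask` came from Segmenter.predict):
#   color = Image.fromarray(colorize_mask(mask, make_palette()))
#   overlay = Image.blend(image.convert("RGB"), color.resize(image.size, Image.NEAREST), 0.5)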