import torch
import torch.nn as nn
from transformers import SegformerForSemanticSegmentation
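
# SegFormerUNet: the pretrained SegFormer (MiT) encoder paired with a small
# transposed-convolution decoder that restores full input resolution.
# Despite the class name, there are no U-Net-style skip connections; the
# decoder is a plain stack of stride-2 transposed convolutions.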

class SegFormerUNet(nn.Module):
    def __init__(self, model_name="nvidia/segformer-b2-finetuned-ade-512-512", num_classes=1):
        super().__init__()

        # Load the pretrained SegFormer and keep only its hierarchical MiT
        # encoder (the base SegformerModel lives under the .segformer attribute)
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(model_name)
        self.encoder = self.segformer.segformer.encoder

        # Decoder: five stride-2 transposed convolutions upsample the final
        # encoder feature map 32x back to the input resolution. Channel widths
        # assume the b2 backbone (512-channel final stage); the spatial sizes
        # in the comments below assume a 512x512 input.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),  # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 32x32 -> 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),   # 64x64 -> 128x128
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),    # 128x128 -> 256x256
            nn.ReLU(),
            nn.ConvTranspose2d(32, num_classes, kernel_size=2, stride=2)  # 256x256 -> 512x512
        )

    def forward(self, x):
        # Encoder: the final stage of segformer-b2 yields features of shape
        # (B, 512, H/32, W/32), e.g. (B, 512, 16, 16) for a 512x512 input.
        # With the default reshape_last_stage=True config, these are already
        # channels-first (B, C, H, W), so no permute is needed.
        encoder_output = self.encoder(x).last_hidden_state

        # Decoder: upsample 32x back to the input resolution
        output = self.decoder(encoder_output)  # (B, num_classes, 512, 512)

        return output  # raw segmentation logits (no sigmoid/softmax applied)
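

# Usage sketch: a minimal smoke test assuming a 512x512 RGB input and a single
# foreground class. Apply torch.sigmoid to the logits for probabilities (or
# train directly against them with nn.BCEWithLogitsLoss).
if __name__ == "__main__":
    model = SegFormerUNet(num_classes=1)
    model.eval()

    dummy = torch.randn(1, 3, 512, 512)  # (B, C, H, W)
    with torch.no_grad():
        logits = model(dummy)

    print(logits.shape)  # expected: torch.Size([1, 1, 512, 512])
    probs = torch.sigmoid(logits)  # per-pixel foreground probabilities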