# SegNetX / arch.py
# Uploaded by deekshithabandam ("Upload 4 files", commit 99cf037).
import torch
import torch.nn as nn
from transformers import SegformerForSemanticSegmentation
class SegFormerUNet(nn.Module):
    """Pretrained SegFormer encoder followed by a transposed-convolution decoder.

    Only the deepest encoder stage feeds the decoder. It is upsampled by five
    stride-2 ``ConvTranspose2d`` layers (32x in total) and then, if needed,
    resized to the exact input resolution. Despite the name, there are no
    U-Net skip connections.

    Args:
        model_name: Hugging Face checkpoint id passed to
            ``SegformerForSemanticSegmentation.from_pretrained``.
        num_classes: Number of output channels (1 for binary segmentation).

    ``forward`` returns raw logits; no sigmoid/softmax is applied here.
    """

    def __init__(self, model_name="nvidia/segformer-b2-finetuned-ade-512-512", num_classes=1):
        super().__init__()
        # Load the full pretrained model. Only its encoder is used in forward(),
        # but the whole module (including the unused decode head) stays
        # registered so state_dict keys match checkpoints saved by this class.
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(model_name)
        # Alias of the same encoder module (registered under two names).
        self.encoder = self.segformer.segformer.encoder
        # Width of the deepest encoder stage, read from the config instead of
        # hard-coding 512 (the value for the default B2 checkpoint) so other
        # SegFormer variants passed via model_name also work.
        encoder_channels = self.segformer.config.hidden_sizes[-1]
        self.decoder = self._build_decoder(encoder_channels, num_classes)

    @staticmethod
    def _build_decoder(in_channels, num_classes):
        """Build the 32x upsampling decoder: five stride-2 transposed convs with ReLUs between.

        Layer order and indices match the original hard-coded Sequential so
        previously saved decoder weights load unchanged.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, 256, kernel_size=2, stride=2),  # 1/32 -> 1/16
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 1/16 -> 1/8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),  # 1/8 -> 1/4
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),  # 1/4 -> 1/2
            nn.ReLU(),
            nn.ConvTranspose2d(32, num_classes, kernel_size=2, stride=2),  # 1/2 -> full size
        )

    def forward(self, x):
        """Predict segmentation logits for a batch of images.

        Args:
            x: Image batch passed to the SegFormer encoder as ``pixel_values``.
                Presumably (B, 3, H, W), normalised as the checkpoint expects
                (TODO: confirm with callers).

        Returns:
            Logits of shape (B, num_classes, H, W), matching x's spatial size.
        """
        # Assumes last_hidden_state is already 4-D (B, C, h, w), as the original
        # code did (SegformerEncoder reshapes the last stage by default).
        features = self.encoder(x).last_hidden_state
        logits = self.decoder(features)  # exactly 32x the feature map's h, w
        # When H or W is not a multiple of 32, the encoder's strided patch
        # embeddings round the feature size, so 32x upsampling misses the input
        # size. Resize so the mask aligns pixel-for-pixel with x.
        if logits.shape[-2:] != x.shape[-2:]:
            logits = nn.functional.interpolate(
                logits, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return logits