Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from transformers import SegformerForSemanticSegmentation | |
class SegFormerUNet(nn.Module):
    """SegFormer encoder followed by a transposed-convolution upsampling decoder.

    The encoder is taken from a pretrained ``SegformerForSemanticSegmentation``
    checkpoint. Its last hidden state is passed through five stride-2
    transposed convolutions (a 32x spatial upsampling) to produce the output
    logits.

    Args:
        model_name: Model id or path forwarded to
            ``SegformerForSemanticSegmentation.from_pretrained``.
        num_classes: Number of output channels of the final decoder layer.
    """

    def __init__(self, model_name="nvidia/segformer-b2-finetuned-ade-512-512", num_classes=1):
        super().__init__()

        # The complete pretrained model stays registered as a submodule so
        # parameter names (and thus existing checkpoints) are unchanged.
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(model_name)
        self.encoder = self.segformer.segformer.encoder

        # Read the width of the last encoder stage from the checkpoint's
        # config instead of hard-coding 512, so other SegFormer variants
        # work too. For the default b2 checkpoint this is expected to be 512.
        encoder_channels = self.segformer.config.hidden_sizes[-1]

        # Each layer doubles the spatial size. The sizes in the comments
        # assume a 16x16 encoder feature map (i.e. a 512x512 input image).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(encoder_channels, 256, kernel_size=2, stride=2),  # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 32 -> 64
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),  # 64 -> 128
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),  # 128 -> 256
            nn.ReLU(),
            nn.ConvTranspose2d(32, num_classes, kernel_size=2, stride=2),  # 256 -> 512
        )

    def forward(self, x):
        """Encode ``x`` and upsample the encoder features with the decoder.

        Args:
            x: Input batch handed directly to the encoder
                (presumably ``(B, 3, H, W)`` images -- confirm with callers).

        Returns:
            The decoder output; its spatial size is 32x that of the encoder's
            last hidden state.

        Raises:
            RuntimeError: If the encoder's last hidden state is not 4-D.
        """
        features = self.encoder(x).last_hidden_state

        # The decoder needs a (B, C, H, W) feature map. Fail fast with a clear
        # message if the encoder was configured to return flattened tokens.
        if features.dim() != 4:
            raise RuntimeError(
                "Expected a 4-D (B, C, H, W) encoder output, got shape "
                f"{tuple(features.shape)}"
            )

        return self.decoder(features)