import torch
import torch.nn as nn
from transformers import SegformerForSemanticSegmentation
class SegFormerUNet(nn.Module):
    def __init__(self, model_name="nvidia/segformer-b2-finetuned-ade-512-512", num_classes=1):
        super(SegFormerUNet, self).__init__()

        # Load pretrained SegFormer (the decode head is loaded but not used below)
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(model_name)

        # Extract the hierarchical Transformer encoder
        self.encoder = self.segformer.segformer.encoder

        # U-Net style decoder: five stride-2 transposed convolutions
        # upsample the 16x16 encoder features back to the 512x512 input size
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),  # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 32x32 -> 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),   # 64x64 -> 128x128
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),    # 128x128 -> 256x256
            nn.ReLU(),
            nn.ConvTranspose2d(32, num_classes, kernel_size=2, stride=2),  # 256x256 -> 512x512
        )
    def forward(self, x):
        # Encoder: for a 512x512 input, last_hidden_state is already (B, 512, 16, 16),
        # so no permute is needed before the decoder
        encoder_output = self.encoder(x).last_hidden_state

        # Decoder: upsample back to the input resolution
        output = self.decoder(encoder_output)  # (B, num_classes, 512, 512)
        return output  # segmentation logits
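

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original file):
    # assumes a 512x512 RGB input; pretrained weights download on first use.
    model = SegFormerUNet(num_classes=1)
    dummy = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        mask = model(dummy)
    print(mask.shape)  # expected: torch.Size([1, 1, 512, 512])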