Spaces:
Sleeping
Sleeping
from typing import List | |
import torch | |
import torch.nn as nn | |
class ConvBlock(nn.Module): | |
def __init__(self, ch_in, ch_out): | |
super(ConvBlock, self).__init__() | |
self.conv1 = nn.Conv2d( | |
ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False | |
) | |
self.conv2 = nn.Conv2d( | |
ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False | |
) | |
nn.init.normal_(self.conv1.weight, mean=0.0, std=0.02) | |
nn.init.normal_(self.conv2.weight, mean=0.0, std=0.02) | |
self.batchnorm1 = nn.BatchNorm2d(ch_out) | |
self.batchnorm2 = nn.BatchNorm2d(ch_out) | |
self.relu = nn.ReLU() | |
def forward(self, x): | |
h = self.relu(self.batchnorm1(self.conv1(x))) | |
h = self.relu(self.batchnorm2(self.conv2(h))) | |
return h | |
class EncodeBlock(nn.Module): | |
def __init__(self, ch_in, ch_out): | |
super(EncodeBlock, self).__init__() | |
self.conv = ConvBlock(ch_in, ch_out) | |
self.pool = nn.MaxPool2d((2, 2)) | |
def forward(self, x): | |
skip = self.conv(x) | |
h = self.pool(skip) | |
return h, skip | |
class DecodeBlock(nn.Module): | |
def __init__(self, ch_in, ch_out): | |
super(DecodeBlock, self).__init__() | |
self.up = nn.ConvTranspose2d( | |
ch_in, ch_out, kernel_size=2, stride=2, padding=0, bias=True | |
) | |
self.conv = ConvBlock(ch_out * 2, ch_out) | |
def forward(self, x, skip): | |
h = self.up(x) | |
h = self.conv(torch.cat([h, skip], dim=1)) | |
return h | |
class UNet(nn.Module): | |
def __init__(self, ch_in, ch_out): | |
super(UNet, self).__init__() | |
self.econv0 = nn.Conv2d( | |
ch_in, 64, kernel_size=1, stride=1, padding=0, bias=True | |
) | |
nn.init.normal_(self.econv0.weight, mean=0.0, std=0.02) | |
self.econv1 = self.make_downblock(64, 64) | |
self.econv2 = self.make_downblock(64, 128) | |
self.econv3 = self.make_downblock(128, 256) | |
self.econv4 = self.make_downblock(256, 512) | |
self.bottle = self.make_bottleblock(512, 1024) | |
self.dconv4 = self.make_upblock(1024, 512) | |
self.dconv3 = self.make_upblock(512, 256) | |
self.dconv2 = self.make_upblock(256, 128) | |
self.dconv1 = self.make_upblock(128, 64) | |
self.dconv0 = nn.Conv2d( | |
64, ch_out, kernel_size=1, stride=1, padding=0, bias=True | |
) | |
nn.init.normal_(self.dconv0.weight, mean=0.0, std=0.02) | |
def make_downblock(self, ch_in, ch_out): | |
return EncodeBlock(ch_in=ch_in, ch_out=ch_out) | |
def make_bottleblock(self, ch_in, ch_out): | |
return ConvBlock(ch_in=ch_in, ch_out=ch_out) | |
def make_upblock(self, ch_in, ch_out): | |
return DecodeBlock(ch_in=ch_in, ch_out=ch_out) | |
def forward(self, x): | |
x = self.econv0(x) | |
x, skip1 = self.econv1(x) | |
x, skip2 = self.econv2(x) | |
x, skip3 = self.econv3(x) | |
x, skip4 = self.econv4(x) | |
x = self.bottle(x) | |
x = self.dconv4(x, skip4) | |
x = self.dconv3(x, skip3) | |
x = self.dconv2(x, skip2) | |
x = self.dconv1(x, skip1) | |
x = self.dconv0(x) | |
return x | |