# Spaces: Sleeping  (scraped page-status header; not part of the module)
from torch import nn
class ConvLayer(nn.Module):
    """3x3 convolution with optional pooling, batch norm, ReLU and optional dropout.

    The sub-modules are stored, in this order, in ``self.all_layers``::

        Conv2d -> [MaxPool2d] -> BatchNorm2d -> ReLU -> [Dropout]

    The order and the attribute name determine the ``state_dict`` keys, so
    they must stay stable for previously saved checkpoints to load.

    Args:
        input_c: Number of input channels of the convolution.
        output_c: Number of output channels of the convolution (and the
            number of features of the batch normalisation).
        bias: Whether the convolution has a bias term.
        stride: Stride of the convolution.
        padding: Padding of the convolution; borders are filled using
            ``padding_mode='replicate'``.
        pool: If true, a 2x2 max-pool with stride 2 is inserted directly
            after the convolution.
        dropout: Dropout probability. The dropout layer is only added when
            this value is greater than zero.
    """

    def __init__(self, input_c, output_c, bias=False, stride=1, padding=1, pool=False, dropout=0.):
        super().__init__()

        layers = [
            nn.Conv2d(
                input_c,
                output_c,
                kernel_size=3,
                bias=bias,
                stride=stride,
                padding=padding,
                padding_mode='replicate',
            ),
        ]
        if pool:
            # Pooling happens before normalisation/activation.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        layers.append(nn.BatchNorm2d(output_c))
        layers.append(nn.ReLU())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))

        self.all_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked sub-modules and return the result."""
        return self.all_layers(x)
class CustomLayer(nn.Module):
    """A ``ConvLayer`` optionally followed by a residual stack of ``ConvLayer``s.

    ``pool_block`` maps ``input_c`` channels to ``output_c`` channels
    (optionally max-pooling). When ``residue > 0``, ``res_block`` is a
    sequence of ``residue`` channel-preserving ``ConvLayer``s whose output
    is added to the output of ``pool_block`` (a skip connection). When
    ``residue`` is zero or negative, ``res_block`` is ``None`` and only
    ``pool_block`` is applied.

    Args:
        input_c: Number of input channels.
        output_c: Number of output channels.
        pool: Forwarded to the first ``ConvLayer``; enables its max-pool.
        residue: Number of ``ConvLayer``s in the residual branch.
        dropout: Dropout probability forwarded to every ``ConvLayer``.
    """

    def __init__(self, input_c, output_c, pool=True, residue=2, dropout=0.):
        super().__init__()
        self.pool_block = ConvLayer(input_c, output_c, pool=pool, dropout=dropout)

        self.res_block = None
        if residue > 0:
            self.res_block = nn.Sequential(
                *(
                    ConvLayer(output_c, output_c, pool=False, dropout=dropout)
                    for _ in range(residue)
                )
            )

    def forward(self, x):
        """Apply ``pool_block`` and, if present, add the residual branch output."""
        x = self.pool_block(x)
        if self.res_block is None:
            return x
        # Use an out-of-place add: the += operator causes in-place errors in
        # PyTorch if done right after a ReLU.
        return self.res_block(x) + x
class Model(nn.Module):
    """Convolutional classifier built from four ``CustomLayer`` stages.

    The stages widen the channels ``in_channels -> 64 -> 128 -> 256 -> 512``;
    the last three stages each halve the spatial size with a 2x2 max-pool.
    A final 4x4 max-pool, a flatten and a linear layer produce the class
    scores. The linear layer takes 512 input features, so the feature map
    must be 1x1 after the final pool (e.g. 32x32 inputs: 32 -> 16 -> 8 -> 4 -> 1).

    Args:
        dropout: Dropout probability forwarded to every ``CustomLayer``.
        num_classes: Number of outputs of the final linear layer.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, dropout=0.05, num_classes=10, in_channels=3):
        super().__init__()
        self.network = nn.Sequential(
            CustomLayer(in_channels, 64, pool=False, residue=0, dropout=dropout),
            CustomLayer(64, 128, pool=True, residue=2, dropout=dropout),
            CustomLayer(128, 256, pool=True, residue=0, dropout=dropout),
            CustomLayer(256, 512, pool=True, residue=2, dropout=dropout),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``."""
        return self.network(x)