from torch import nn


class ConvLayer(nn.Module):
    """3x3 convolution block with optional pooling and dropout.

    The sub-modules are stored in ``self.all_layers`` in this order::

        Conv2d(3x3, replicate padding) -> [MaxPool2d(2, 2)] -> BatchNorm2d
        -> ReLU -> [Dropout]

    Args:
        input_c: Number of input channels of the convolution.
        output_c: Number of output channels (also the BatchNorm2d width).
        bias: Whether the convolution has a bias term.
        stride: Convolution stride.
        padding: Convolution padding (applied with ``padding_mode='replicate'``).
        pool: If true, insert a 2x2/stride-2 max-pool right after the convolution.
        dropout: Dropout probability; a ``Dropout`` layer is appended only
            when this is greater than 0.
    """

    def __init__(self, input_c, output_c, bias=False, stride=1, padding=1,
                 pool=False, dropout=0.):
        super().__init__()
        layers = [
            nn.Conv2d(input_c, output_c, kernel_size=3, bias=bias,
                      stride=stride, padding=padding,
                      padding_mode='replicate'),
        ]
        if pool:
            # Pooling happens before normalisation and activation.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        layers.append(nn.BatchNorm2d(output_c))
        layers.append(nn.ReLU())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        self.all_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block and return the result."""
        return self.all_layers(x)


class CustomLayer(nn.Module):
    """A ``ConvLayer`` optionally followed by a residual stack of ``ConvLayer``s.

    ``pool_block`` maps ``input_c`` to ``output_c`` channels (with optional
    pooling). When ``residue > 0``, ``res_block`` is a sequence of ``residue``
    channel-preserving, non-pooling ``ConvLayer``s whose output is added to
    the output of ``pool_block`` (skip connection). Otherwise ``res_block``
    is ``None`` and the layer is just ``pool_block``.

    Args:
        input_c: Number of input channels.
        output_c: Number of output channels.
        pool: Forwarded to ``pool_block`` as its ``pool`` flag.
        residue: Number of ``ConvLayer``s in the residual branch; values
            ``<= 0`` disable the residual branch.
        dropout: Dropout probability forwarded to every ``ConvLayer``.
    """

    def __init__(self, input_c, output_c, pool=True, residue=2, dropout=0.):
        super().__init__()
        self.pool_block = ConvLayer(input_c, output_c, pool=pool,
                                    dropout=dropout)
        self.res_block = None
        if residue > 0:
            self.res_block = nn.Sequential(*(
                ConvLayer(output_c, output_c, pool=False, dropout=dropout)
                for _ in range(residue)
            ))

    def forward(self, x):
        """Apply ``pool_block`` and, if present, add the residual branch."""
        x = self.pool_block(x)
        if self.res_block is None:
            return x
        skip = x
        x = self.res_block(x)
        # Deliberately out-of-place: `x += skip` right after the ReLU causes
        # in-place modification errors in PyTorch.
        return x + skip


class Model(nn.Module):
    """Small residual CNN classifier built from ``CustomLayer`` stages.

    Channel progression is 3 -> 64 -> 128 -> 256 -> 512, with 2x2 pooling in
    the last three stages, followed by a 4x4 max-pool, flatten and a linear
    classifier on 512 features.

    NOTE(review): ``Linear(512, ...)`` after ``MaxPool2d(4, 4)`` requires the
    feature map to be 1x1 at that point, which matches 3-channel 32x32 inputs
    (presumably CIFAR-10 style data) -- confirm against the data pipeline.

    Args:
        dropout: Dropout probability used in every ``CustomLayer``.
        num_classes: Number of output logits of the final linear layer.
            Defaults to 10, the previously hard-coded value.
    """

    def __init__(self, dropout=0.05, num_classes=10):
        super().__init__()
        self.network = nn.Sequential(
            CustomLayer(3, 64, pool=False, residue=0, dropout=dropout),
            CustomLayer(64, 128, pool=True, residue=2, dropout=dropout),
            CustomLayer(128, 256, pool=True, residue=0, dropout=dropout),
            CustomLayer(256, 512, pool=True, residue=2, dropout=dropout),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.network(x)