Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from typing import List, Tuple | |
class SharedMLP(nn.Sequential):
    """Chain of ``Conv2d`` layers mapping ``args[k]`` -> ``args[k + 1]`` channels.

    One ``Conv2d`` child is registered per consecutive pair in ``args``, named
    ``name + "layer<k>"``. All children receive the same ``preact`` and
    ``instance_norm`` settings.

    Args:
        args: channel sizes; ``len(args) - 1`` layers are built (none if fewer
            than two entries).
        bn: request batch-norm in each layer.
        activation: activation module handed to every layer. Note the default
            is a single module instance, so it is shared by all layers.
        preact: build pre-activation layers (norm/activation before the conv).
        first: together with ``preact``, drop batch-norm and activation from
            the very first layer.
        name: prefix for the child module names.
        instance_norm: request instance-norm in each layer.
    """

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = "",
            instance_norm: bool = False,
    ):
        super().__init__()
        channel_pairs = zip(args[:-1], args[1:])
        for idx, (c_in, c_out) in enumerate(channel_pairs):
            # Only the leading layer of a "first" pre-activation stack is bare.
            skip_pre = first and preact and idx == 0
            layer = Conv2d(
                c_in,
                c_out,
                bn=(not skip_pre) and bn,
                activation=None if skip_pre else activation,
                preact=preact,
                instance_norm=instance_norm,
            )
            self.add_module(name + "layer" + str(idx), layer)
| class _ConvBase(nn.Sequential): | |
| def __init__( | |
| self, | |
| in_size, | |
| out_size, | |
| kernel_size, | |
| stride, | |
| padding, | |
| activation, | |
| bn, | |
| init, | |
| conv=None, | |
| batch_norm=None, | |
| bias=True, | |
| preact=False, | |
| name="", | |
| instance_norm=False, | |
| instance_norm_func=None | |
| ): | |
| super().__init__() | |
| bias = bias and (not bn) | |
| conv_unit = conv( | |
| in_size, | |
| out_size, | |
| kernel_size=kernel_size, | |
| stride=stride, | |
| padding=padding, | |
| bias=bias | |
| ) | |
| init(conv_unit.weight) | |
| if bias: | |
| nn.init.constant_(conv_unit.bias, 0) | |
| if bn: | |
| if not preact: | |
| bn_unit = batch_norm(out_size) | |
| else: | |
| bn_unit = batch_norm(in_size) | |
| if instance_norm: | |
| if not preact: | |
| in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) | |
| else: | |
| in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) | |
| if preact: | |
| if bn: | |
| self.add_module(name + 'bn', bn_unit) | |
| if activation is not None: | |
| self.add_module(name + 'activation', activation) | |
| if not bn and instance_norm: | |
| self.add_module(name + 'in', in_unit) | |
| self.add_module(name + 'conv', conv_unit) | |
| if not preact: | |
| if bn: | |
| self.add_module(name + 'bn', bn_unit) | |
| if activation is not None: | |
| self.add_module(name + 'activation', activation) | |
| if not bn and instance_norm: | |
| self.add_module(name + 'in', in_unit) | |
| class _BNBase(nn.Sequential): | |
| def __init__(self, in_size, batch_norm=None, name=""): | |
| super().__init__() | |
| self.add_module(name + "bn", batch_norm(in_size)) | |
| nn.init.constant_(self[0].weight, 1.0) | |
| nn.init.constant_(self[0].bias, 0) | |
class BatchNorm1d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm1d``; ``name`` is keyword-only."""

    def __init__(self, in_size: int, *, name: str = ""):
        super(BatchNorm1d, self).__init__(in_size, nn.BatchNorm1d, name)
class BatchNorm2d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm2d``; ``name`` may be positional."""

    def __init__(self, in_size: int, name: str = ""):
        super(BatchNorm2d, self).__init__(in_size, nn.BatchNorm2d, name)
class Conv1d(_ConvBase):
    """1-D convolution block built by ``_ConvBase``.

    Uses ``nn.Conv1d`` for the convolution, this module's ``BatchNorm1d``
    wrapper for batch-norm and ``nn.InstanceNorm1d`` for instance-norm. All
    options other than the channel counts are keyword-only. See ``_ConvBase``
    for how ``bn``/``preact``/``instance_norm`` shape the stage order.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d,
        )
class Conv2d(_ConvBase):
    """2-D convolution block built by ``_ConvBase``.

    Uses ``nn.Conv2d`` for the convolution, this module's ``BatchNorm2d``
    wrapper for batch-norm and ``nn.InstanceNorm2d`` for instance-norm. All
    options other than the channel counts are keyword-only. See ``_ConvBase``
    for how ``bn``/``preact``/``instance_norm`` shape the stage order.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d,
        )
class FC(nn.Sequential):
    """Fully-connected layer with optional batch-norm and activation stages.

    Stage order depends on ``preact``:

    * ``preact=False``: fc -> [bn] -> [activation]
    * ``preact=True``:  [bn] -> [activation] -> fc

    Args:
        in_size: input features of the ``nn.Linear``.
        out_size: output features of the ``nn.Linear``.
        activation: module registered as the activation stage; ``None`` omits it.
        bn: add a ``BatchNorm1d`` stage (sized ``in_size`` when ``preact``,
            else ``out_size``) and build the linear layer without a bias.
        init: optional callable applied to the linear weight.
        preact: place the bn/activation stages before the linear layer.
        name: prefix for every child module name.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # Use the in-place initializer (as _ConvBase does); the un-suffixed
            # `nn.init.constant` is a deprecated alias that emits a warning.
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))
            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))
            if activation is not None:
                self.add_module(name + 'activation', activation)