Spaces:
Running
on
Zero
Running
on
Zero
import torch.nn as nn | |
__all__ = ['SharedMLP'] | |
class SharedMLP(nn.Module):
    """Stack of pointwise (kernel size 1) convolutions, each followed by
    instance normalization and an in-place ReLU.

    Args:
        in_channels: Number of channels of the input feature tensor.
        out_channels: Output channel count for a single layer, or a list/tuple
            of channel counts, one per layer, applied in order (each layer's
            output count becomes the next layer's input count).
        dim: 1 to build ``Conv1d`` + ``InstanceNorm1d`` layers, 2 to build
            ``Conv2d`` + ``InstanceNorm2d`` layers.
        device: Passed as ``device=`` to every conv and norm layer created.

    Raises:
        ValueError: If ``dim`` is not 1 or 2.
    """

    def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
        super().__init__()
        if dim == 1:
            conv = nn.Conv1d
            norm = nn.InstanceNorm1d
        elif dim == 2:
            conv = nn.Conv2d
            # Conv2d emits 4D (N, C, H, W) tensors; InstanceNorm1d only accepts
            # 2D/3D input and would raise on them, so use the 2D variant.
            norm = nn.InstanceNorm2d
        else:
            raise ValueError(f'SharedMLP: dim must be 1 or 2, got {dim!r}')

        # A bare int means a single layer.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]

        layers = []
        for oc in out_channels:
            layers.extend([
                conv(in_channels, oc, 1, device=device),
                norm(oc, device=device),
                nn.ReLU(inplace=True),
            ])
            in_channels = oc  # chain: this layer's output feeds the next
        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the layer stack.

        If ``inputs`` is a list or tuple, only its first element is passed
        through the layers; the remaining elements are returned unchanged,
        and the result is a tuple ``(transformed_first, *rest)``. Otherwise
        the layers are applied to ``inputs`` directly.
        """
        if isinstance(inputs, (list, tuple)):
            return (self.layers(inputs[0]), *inputs[1:])
        return self.layers(inputs)