# Author: Nicolas Denier
# Commit 5ad4868: "ready for submission" (file size 2.26 kB)
from torch import ones, split, bfloat16
from torch.nn.functional import relu, sigmoid
from torch.nn import Module, MaxPool1d, Conv1d, Conv2d, Linear, BatchNorm2d, LSTMCell
class ChunkCNN(Module):
    """Convolutional encoder turning one input chunk into a feature vector.

    The chunk passes through a 2-D convolution whose kernel width (10) spans
    the whole chunk width, then two 1-D convolution + max-pooling stages, and
    finally a 1x1 "collapse" convolution down to a single channel. ReLU
    follows every convolution.

    Input: a tensor accepted by ``Conv2d(in_channels=1, ...)``, i.e.
    ``(batch, 1, height, width)`` or unbatched ``(1, height, width)``.
    conv1 has kernel width 10, stride 1 and no width padding, so the width
    only reduces to 1 when ``width == 10``. TODO confirm that callers always
    pass 10-wide chunks.

    Output: ``(batch, L)`` (``(L,)`` when unbatched), where ``L`` depends on
    ``height``. For example, height 513 gives ``L == 8``.
    """

    def __init__(self):
        super().__init__()
        # A single pooling layer, reused after conv1 and after conv2.
        self.pool = MaxPool1d(kernel_size=2, stride=2)
        # Kernel width 10 covers the full chunk width, collapsing it to 1.
        self.conv1 = Conv2d(in_channels=1, out_channels=4, kernel_size=(15, 10),
                            stride=(4, 1), padding=(1, 0))
        self.conv2 = Conv1d(in_channels=4, out_channels=8, kernel_size=9, stride=2, padding=0)
        self.conv3 = Conv1d(in_channels=8, out_channels=16, kernel_size=2, stride=2, padding=1)
        # 1x1 convolution reducing the 16 channels to a single one.
        self.collapse = Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Drop only the size-1 width axis. A bare .squeeze() would also drop
        # the batch axis when batch == 1 and break the downstream LSTMCell.
        x = relu(self.conv1(x).squeeze(-1))
        x = self.pool(x)
        x = relu(self.conv2(x))
        x = self.pool(x)
        x = relu(self.conv3(x))
        # Drop only the single output channel. It sits at dim -2 in both the
        # batched (B, 1, L) and unbatched (1, L) layouts; batch is never dropped.
        x = relu(self.collapse(x).squeeze(-2))
        return x
class LastLayer(Module):
    """Two-layer dense head producing one sigmoid score per sample.

    Args:
        inputsize: number of input features (size of the last axis of ``x``).

    ``forward`` maps ``(batch, inputsize)`` to ``(batch,)`` scores in
    ``[0, 1]``. An unbatched ``(inputsize,)`` input yields a 0-d tensor.
    """

    def __init__(self, inputsize):
        super().__init__()
        self.dense1 = Linear(inputsize, 3)
        self.dense2 = Linear(3, 1)

    def forward(self, x):
        x = relu(self.dense1(x))
        # Remove only the size-1 output axis. A bare .squeeze() would also
        # collapse the batch axis when batch == 1, returning a 0-d tensor.
        return sigmoid(self.dense2(x)).squeeze(-1)
class ChainsawDetector(Module):
    """Chunked CNN + LSTM binary classifier returning hard 0/1 decisions.

    ``forward`` batch-normalises the input and splits it along its last axis
    into chunks of width 10. Each chunk is encoded by ``ChunkCNN`` and fed in
    order into an ``LSTMCell``. The final hidden state goes through
    ``LastLayer`` and is thresholded at 0.5.

    Args:
        batch_size: batch size used to build the stored initial LSTM state.
            Kept for backward compatibility; ``forward`` adapts the initial
            state to the actual batch size of its input.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        # LSTM input and hidden size. It must equal the ChunkCNN output length,
        # since that output is the LSTMCell input.
        self.nb_lstm = 8
        self.batchnorm = BatchNorm2d(1)
        self.chunkcnn = ChunkCNN()
        self.lstmcell = LSTMCell(self.nb_lstm, self.nb_lstm)
        self.lastlayer = LastLayer(self.nb_lstm)
        # Initial hidden/cell state of all ones: default class is 1 (environment).
        # It is a non-persistent buffer so Module.to()/.cuda() move it with the
        # model, while state_dict keeps exactly the same keys as before.
        self.register_buffer(
            "initstate",
            ones((batch_size, self.nb_lstm), dtype=bfloat16),
            persistent=False,
        )

    def forward(self, x):
        """Classify a batch and return bfloat16 0./1. decisions.

        ``x`` must be rank 3 (the ``x[:, None, :, :]`` indexing requires it),
        presumably ``(batch, height, width)``; TODO confirm with callers.
        """
        # Build the initial state for the actual batch size, matching the LSTM
        # weights' dtype and device. This lets the model run in float32 or
        # bfloat16, on any device, and with a short final batch.
        weight = self.lstmcell.weight_ih
        init = self.initstate[:1].expand(x.shape[0], -1).to(dtype=weight.dtype, device=weight.device)
        hx = init.clone()
        cx = init.clone()
        # Add the single channel axis expected by BatchNorm2d(1) / Conv2d.
        x = self.batchnorm(x[:, None, :, :])
        # NOTE(review): assumes the last axis is a multiple of 10. A shorter
        # trailing chunk cannot pass through conv1's 10-wide kernel.
        for chunk in split(x, 10, dim=3):
            hx, cx = self.lstmcell(self.chunkcnn(chunk), (hx, cx))
        scores = self.lastlayer(hx)
        # Final hard decision. Thresholding is non-differentiable, so this
        # output cannot be used directly as a training loss input.
        return (scores > 0.5).bfloat16()