from torch import ones, split, bfloat16
from torch.nn.functional import relu, sigmoid
from torch.nn import Module, MaxPool1d, Conv1d, Conv2d, Linear, BatchNorm2d, LSTMCell


class ChunkCNN(Module):
    """Convolutional feature extractor applied to one 10-column chunk.

    Maps a chunk of shape (B, 1, H, 10) to a feature tensor of shape (B, L).
    With the layer sizes below, L equals 8 (the LSTMCell input size used by
    ChainsawDetector) for any H in 513..576, e.g. H = 513.
    """

    def __init__(self):
        super().__init__()
        self.pool = MaxPool1d(kernel_size=2, stride=2)
        # The (15, 10) kernel spans the full 10-column chunk width with no
        # width padding, so conv1's output width is 1 and the remaining
        # layers can operate on 1-D signals.
        self.conv1 = Conv2d(in_channels=1, out_channels=4, kernel_size=(15, 10),
                            stride=(4, 1), padding=(1, 0))
        self.conv2 = Conv1d(in_channels=4, out_channels=8, kernel_size=9,
                            stride=2, padding=0)
        self.conv3 = Conv1d(in_channels=8, out_channels=16, kernel_size=2,
                            stride=2, padding=1)
        # 1x1 convolution folding the 16 channels into a single channel.
        self.collapse = Conv1d(in_channels=16, out_channels=1, kernel_size=1,
                               stride=1, padding=0)

    def forward(self, x):
        """Encode one chunk.

        Args:
            x: tensor of shape (B, 1, H, 10).

        Returns:
            Non-negative (ReLU) features of shape (B, L).
        """
        # Remove only the width axis (size 1 after conv1). A bare squeeze()
        # would also remove the batch axis when B == 1.
        x = relu(self.conv1(x).squeeze(-1))
        x = self.pool(x)
        x = relu(self.conv2(x))
        x = self.pool(x)
        x = relu(self.conv3(x))
        # Remove only the single channel axis produced by `collapse`.
        x = relu(self.collapse(x).squeeze(-2))
        return x


class LastLayer(Module):
    """Two-layer MLP head producing one sigmoid probability per sample."""

    def __init__(self, inputsize):
        super().__init__()
        self.dense1 = Linear(inputsize, 3)
        self.dense2 = Linear(3, 1)

    def forward(self, x):
        """Map features of shape (B, inputsize) to probabilities of shape (B,)."""
        x = relu(self.dense1(x))
        # squeeze(-1) removes only the size-1 output axis, so a batch of one
        # yields shape (1,) rather than a 0-d tensor.
        return sigmoid(self.dense2(x)).squeeze(-1)


class ChainsawDetector(Module):
    """Binary classifier: BatchNorm -> per-chunk CNN -> LSTMCell -> MLP head.

    The input is cut into 10-column chunks along its last axis. Each chunk is
    encoded by ChunkCNN and fed to an LSTMCell, and the final hidden state is
    classified by LastLayer.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.nb_lstm = 8
        self.batchnorm = BatchNorm2d(1)
        self.chunkcnn = ChunkCNN()
        self.lstmcell = LSTMCell(self.nb_lstm, self.nb_lstm)
        self.lastlayer = LastLayer(self.nb_lstm)
        # Non-persistent buffer: it follows .to()/.cuda()/dtype casts together
        # with the parameters, but is not stored in state_dict() (matching the
        # previous plain-attribute behavior for checkpoints).
        self.register_buffer(
            "initstate",
            ones((batch_size, self.nb_lstm), dtype=bfloat16),  # default class is 1: environment
            persistent=False,
        )

    def forward(self, x):
        """Classify a batch.

        Args:
            x: 3-D tensor indexed as (B, H, W). Presumably (batch, frequency,
                time) -- TODO confirm. ChunkCNN needs H in 513..576 to produce
                the 8 LSTM features. W should be a multiple of 10: a trailing
                chunk narrower than conv1's 10-wide kernel cannot be convolved.

        Returns:
            bfloat16 tensor of shape (B,) holding hard 0/1 decisions
            (1 where the head's sigmoid output exceeds 0.5).
        """
        # Every row of initstate is identical (built with `ones`), so
        # broadcasting row 0 supports any batch size. Casting to the LSTM
        # weight's dtype/device avoids mixing bfloat16 state with float32
        # weights when the model has not been converted to bfloat16.
        h0 = self.initstate[:1].expand(x.size(0), -1).to(self.lstmcell.weight_hh)
        hx = h0.clone()
        cx = h0.clone()
        x = self.batchnorm(x.unsqueeze(1))  # add the single channel axis
        for chunk in split(x, 10, dim=3):
            hx, cx = self.lstmcell(self.chunkcnn(chunk), (hx, cx))
        prob = self.lastlayer(hx)
        # Final hard decision; no gradient flows through the comparison.
        return (prob > 0.5).bfloat16()