|
from torch import ones, split, bfloat16 |
|
from torch.nn.functional import relu, sigmoid |
|
from torch.nn import Module, MaxPool1d, Conv1d, Conv2d, Linear, BatchNorm2d, LSTMCell |
|
|
|
class ChunkCNN(Module):
    """Convolutional encoder turning one time-chunk into a feature vector.

    Input is a 4-D tensor ``(batch, 1, height, 10)``. ``conv1``'s kernel is
    10 wide with no width padding, so a 10-wide chunk collapses to a single
    column. Strided convolutions and max-pools then shrink the height axis,
    and ``collapse`` merges the 16 channels into one, giving
    ``(batch, length)``.

    With the layer settings below, input heights 513..576 yield
    ``length == 8``.
    """

    def __init__(self):
        super(ChunkCNN, self).__init__()
        self.pool = MaxPool1d(kernel_size=2, stride=2)
        # Kernel spans the full 10-column chunk width -> output width is 1.
        self.conv1 = Conv2d(in_channels=1, out_channels=4, kernel_size=(15, 10), stride=(4, 1), padding=(1, 0))
        self.conv2 = Conv1d(in_channels=4, out_channels=8, kernel_size=9, stride=2, padding=0)
        self.conv3 = Conv1d(in_channels=8, out_channels=16, kernel_size=2, stride=2, padding=1)
        # 1x1 convolution merging the 16 channels into a single channel.
        self.collapse = Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Encode a batch of chunks.

        Args:
            x: Tensor of shape ``(batch, 1, height, 10)``.

        Returns:
            Non-negative tensor of shape ``(batch, length)``. The batch axis
            is kept even when ``batch == 1``.
        """
        # Drop only the width axis (size 1 after conv1). A bare ``squeeze()``
        # would also drop the batch axis for a batch of one.
        x = self.conv1(x).squeeze(-1)
        x = self.pool(relu(x))

        x = self.conv2(x)
        x = self.pool(relu(x))

        x = relu(self.conv3(x))

        # Drop only the channel axis (size 1 after collapse), never the batch.
        x = self.collapse(x).squeeze(1)
        return relu(x)
|
|
|
|
|
class LastLayer(Module):
    """Two-layer classifier head mapping a feature vector to a probability.

    Args:
        inputsize: Number of input features per sample.
    """

    def __init__(self, inputsize):
        super(LastLayer, self).__init__()
        self.dense1 = Linear(inputsize, 3)
        self.dense2 = Linear(3, 1)

    def forward(self, x):
        """Return one sigmoid probability per sample.

        Args:
            x: Tensor of shape ``(batch, inputsize)``.

        Returns:
            Tensor of shape ``(batch,)`` holding sigmoid outputs. The batch
            axis is kept even when ``batch == 1``.
        """
        x = relu(self.dense1(x))
        x = self.dense2(x)
        # Remove only the trailing size-1 output axis; a bare ``squeeze()``
        # would turn a batch of one into a 0-dim scalar.
        return sigmoid(x).squeeze(-1)
|
|
|
|
|
class ChainsawDetector(Module):
    """Chunk-wise CNN + LSTM binary detector.

    The input gets a channel axis and batch normalisation. It is then split
    along its last axis into 10-wide chunks. Each chunk is encoded by
    ``ChunkCNN`` and fed, in order, through an ``LSTMCell``. The final hidden
    state is scored by ``LastLayer`` and thresholded at 0.5.

    Args:
        batch_size: Batch size used to build ``initstate``. Kept for backward
            compatibility; ``forward`` takes the batch size from its input.
    """

    def __init__(self, batch_size):
        super(ChainsawDetector, self).__init__()
        self.batch_size = batch_size
        # LSTM input/hidden size; must equal ChunkCNN's output length.
        self.nb_lstm = 8
        self.batchnorm = BatchNorm2d(1)
        self.chunkcnn = ChunkCNN()
        self.lstmcell = LSTMCell(self.nb_lstm, self.nb_lstm)
        self.lastlayer = LastLayer(self.nb_lstm)
        # Retained for callers that read it. ``forward`` no longer uses it:
        # as a plain attribute (not a buffer) it is not moved by
        # ``.to()``/``.float()``, and its batch size is fixed.
        self.initstate = ones((batch_size, self.nb_lstm), dtype=bfloat16)

    def forward(self, x):
        """Classify a batch.

        Args:
            x: 3-D tensor (``BatchNorm2d`` needs 4-D after the channel axis
                is inserted at position 1). It is split into 10-wide chunks
                along the last axis.
                NOTE(review): presumably ``(batch, height, time)``. A final
                chunk narrower than 10 is rejected by ``ChunkCNN.conv1``
                (kernel width 10), so ``time`` should be a multiple of 10 —
                confirm with the data pipeline.

        Returns:
            ``bfloat16`` tensor of 0/1 predictions, one per sample.
        """
        x = self.batchnorm(x[:, None, :, :])

        # Initial LSTM state of ones, sized from the actual batch and created
        # with the input's dtype/device so it matches the model wherever it
        # runs (any batch size, float32 or bfloat16, CPU or GPU).
        batch = x.shape[0]
        hx = x.new_ones((batch, self.nb_lstm))
        cx = x.new_ones((batch, self.nb_lstm))

        for chunk in split(x, 10, dim=3):
            hx, cx = self.lstmcell(self.chunkcnn(chunk), (hx, cx))

        x = self.lastlayer(hx)
        # NOTE(review): the hard threshold carries no gradient; training
        # would need the probabilities from ``self.lastlayer`` instead.
        return (x > 0.5).bfloat16()