# Source snapshot: commit 5ad4868, 2,262 bytes.
from torch import ones, split, bfloat16
from torch.nn.functional import relu, sigmoid
from torch.nn import Module, MaxPool1d, Conv1d, Conv2d, Linear, BatchNorm2d, LSTMCell
class ChunkCNN(Module):
    """Convolutional encoder mapping one input chunk to a flat feature vector.

    Expected input is ``(batch, 1, H, 10)``. ``conv1`` has a 10-wide kernel and
    no horizontal padding, so a 10-wide chunk collapses to width 1. The chunk is
    presumably a slice of a spectrogram (H = frequency bins, 10 = frames) —
    TODO confirm. An unbatched ``(1, H, 10)`` input is also accepted.

    The output is ``(batch, L)``, where ``L`` is determined by ``H``.
    ``ChainsawDetector`` feeds it to an ``LSTMCell`` of input size 8, so ``H``
    must yield ``L == 8`` (``H == 513`` does, for example).
    """

    def __init__(self):
        super().__init__()
        self.pool = MaxPool1d(kernel_size=2, stride=2)
        self.conv1 = Conv2d(in_channels=1, out_channels=4, kernel_size=(15, 10), stride=(4, 1), padding=(1, 0))
        self.conv2 = Conv1d(in_channels=4, out_channels=8, kernel_size=9, stride=2, padding=0)
        self.conv3 = Conv1d(in_channels=8, out_channels=16, kernel_size=2, stride=2, padding=1)
        # 1x1 conv reducing the 16 channels to a single feature row.
        self.collapse = Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Drop only the (now size-1) width axis. A bare ``squeeze()`` would also
        # drop the batch axis when batch == 1, which silently switches the 1-D
        # layers to unbatched (C, L) mode and breaks the downstream LSTMCell.
        x = relu(self.conv1(x).squeeze(-1))
        x = self.pool(x)
        x = relu(self.conv2(x))
        x = self.pool(x)
        x = relu(self.conv3(x))
        # collapse -> (batch, 1, L): remove the single channel axis only.
        return relu(self.collapse(x).squeeze(-2))
class LastLayer(Module):
    """Two-layer MLP head mapping a feature vector to one probability.

    Args:
        inputsize: size of the last input axis (the LSTM hidden size when used
            by ``ChainsawDetector``).
    """

    def __init__(self, inputsize):
        super().__init__()
        self.dense1 = Linear(inputsize, 3)
        self.dense2 = Linear(3, 1)

    def forward(self, x):
        """Return sigmoid probabilities with the size-1 output axis removed.

        ``(batch, inputsize)`` -> ``(batch,)``. Only the last axis is squeezed,
        so a batch of one stays ``(1,)`` instead of collapsing to a 0-d scalar.
        """
        x = relu(self.dense1(x))
        return sigmoid(self.dense2(x)).squeeze(-1)
class ChainsawDetector(Module):
    """Chunked CNN + LSTM binary classifier.

    Processing steps:
    1. Batch-normalize the input.
    2. Split it into 10-wide chunks along its last axis.
    3. Encode each chunk with ``ChunkCNN`` and feed it step by step to an
       ``LSTMCell``.
    4. Map the final hidden state to a probability with ``LastLayer``.
    5. Threshold at 0.5 to get a hard 0/1 decision.

    Args:
        batch_size: batch size used to shape the stored initial LSTM state.
            ``forward`` broadcasts that state to the actual input batch size,
            so other batch sizes (e.g. a short final batch) also work.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.nb_lstm = 8  # LSTM hidden size; also the LSTM input size
        self.batchnorm = BatchNorm2d(1)
        self.chunkcnn = ChunkCNN()
        self.lstmcell = LSTMCell(self.nb_lstm, self.nb_lstm)
        self.lastlayer = LastLayer(self.nb_lstm)
        # Initial LSTM (h, c) state. It is registered as a buffer so that
        # ``.to(device)`` / ``.to(dtype)`` move and cast it together with the
        # weights; a plain attribute would stay a CPU bfloat16 tensor.
        # ``persistent=False`` keeps it out of ``state_dict``, so existing
        # checkpoints still load with ``strict=True``.
        # default class is 1: environment
        self.register_buffer(
            "initstate",
            ones((batch_size, self.nb_lstm), dtype=bfloat16),
            persistent=False,
        )

    def forward(self, x):
        """Return a hard 0/1 decision per sample.

        Args:
            x: 3-D tensor; a channel axis is inserted before ``BatchNorm2d``.
                Presumably ``(batch, H, W)`` with H frequency bins and W frames
                — TODO confirm. ``W`` should be a multiple of 10: a narrower
                trailing chunk is smaller than ``ChunkCNN``'s 10-wide kernel,
                and the convolution will raise.

        Returns:
            bfloat16 tensor of 1.0 where the predicted probability exceeds 0.5,
            else 0.0, one value per sample (shape as produced by ``LastLayer``).
        """
        # Broadcast the stored initial state to the actual batch size.
        init = self.initstate[:1].expand(x.shape[0], -1)
        hx = init.detach().clone()
        cx = init.detach().clone()
        x = self.batchnorm(x[:, None, :, :])  # add channel axis for BatchNorm2d
        for chunk in split(x, 10, dim=3):
            xi = self.chunkcnn(chunk)
            hx, cx = self.lstmcell(xi, (hx, cx))
        x = self.lastlayer(hx)
        # final decision
        return (x > 0.5).bfloat16()