File size: 2,262 Bytes
5ad4868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from torch import ones, split, bfloat16
from torch.nn.functional import relu, sigmoid
from torch.nn import Module, MaxPool1d, Conv1d, Conv2d, Linear, BatchNorm2d, LSTMCell

class ChunkCNN(Module):
    """Convolutional feature extractor applied to one time chunk of the input.

    Maps a tensor of shape (batch, 1, H, 10) to features of shape (batch, L).
    conv1's kernel spans the full 10-column chunk width, so its output width
    is 1 and that axis is dropped. Two 1-D convolutions with max-pooling then
    shrink the H axis. A final 1x1 convolution merges the 16 channels into
    one, and that channel axis is dropped too.

    With the current layer sizes, an input height H in the range 513-576
    yields L == 8.
    """

    def __init__(self):
        super().__init__()
        # Shared by both pooling stages; halves the sequence length each time.
        self.pool = MaxPool1d(kernel_size=2, stride=2)
        # Kernel width 10 equals the chunk width, so the output width is 1.
        self.conv1 = Conv2d(in_channels=1, out_channels=4, kernel_size=(15, 10),
                            stride=(4, 1), padding=(1, 0))
        self.conv2 = Conv1d(in_channels=4, out_channels=8, kernel_size=9, stride=2, padding=0)
        self.conv3 = Conv1d(in_channels=8, out_channels=16, kernel_size=2, stride=2, padding=1)
        # 1x1 convolution: merges the 16 channels into a single feature channel.
        self.collapse = Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return per-chunk features.

        Args:
            x: tensor of shape (batch, 1, H, 10). The chunk width must be
                exactly 10 so that conv1 produces a width of 1.

        Returns:
            Non-negative (ReLU) features of shape (batch, L).
        """
        # Drop only the width axis (size 1 after conv1). A bare squeeze()
        # would also drop the batch axis when batch == 1.
        x = self.conv1(x).squeeze(-1)
        x = relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = relu(x)
        x = self.pool(x)

        x = self.conv3(x)
        x = relu(x)

        # Drop only the single collapsed channel axis, keeping batch even
        # when batch == 1.
        x = self.collapse(x).squeeze(-2)
        x = relu(x)

        return x
    

class LastLayer(Module):
    """Two-layer classification head: inputsize -> 3 -> 1, sigmoid output."""

    def __init__(self, inputsize):
        """
        Args:
            inputsize: number of input features per sample.
        """
        super().__init__()
        self.dense1 = Linear(inputsize, 3)
        self.dense2 = Linear(3, 1)

    def forward(self, x):
        """Return one sigmoid score per sample.

        Args:
            x: tensor of shape (batch, inputsize), or (inputsize,) unbatched.

        Returns:
            Tensor of shape (batch,) with values in [0, 1]; 0-d for unbatched
            input.
        """
        x = self.dense1(x)
        x = relu(x)
        x = self.dense2(x)
        # Remove only the trailing size-1 output axis. A bare squeeze() would
        # also drop the batch axis when batch == 1.
        x = sigmoid(x).squeeze(-1)
        return x
    

class ChainsawDetector(Module):
    """Binary detector: per-chunk CNN features -> LSTM over chunks -> head.

    The input is split along its last axis into 10-column chunks. ChunkCNN
    turns each chunk into a feature vector that is fed to an LSTM cell. The
    final hidden state goes through LastLayer and is thresholded at 0.5.
    """

    def __init__(self, batch_size):
        """
        Args:
            batch_size: first dimension of ``initstate``. Kept for backward
                compatibility; ``forward`` accepts any batch size.
        """
        super().__init__()
        self.batch_size = batch_size
        self.nb_lstm = 8
        self.batchnorm = BatchNorm2d(1)
        self.chunkcnn = ChunkCNN()
        self.lstmcell = LSTMCell(self.nb_lstm, self.nb_lstm)
        self.lastlayer = LastLayer(self.nb_lstm)
        # Initial LSTM hidden/cell state; default class is 1: environment.
        # Registered as a non-persistent buffer so it follows .to(device)
        # while leaving state_dict (and existing checkpoints) unchanged.
        self.register_buffer(
            "initstate",
            ones((batch_size, self.nb_lstm), dtype=bfloat16),
            persistent=False,
        )

    def forward(self, x):
        """Classify a batch of inputs.

        Args:
            x: 3-D tensor, presumably (batch, freq, time) — TODO confirm. A
                channel axis is inserted at position 1 and the last axis is
                split into 10-column chunks. NOTE: a trailing chunk narrower
                than 10 columns makes ChunkCNN's conv1 raise, so the last-axis
                length should be a multiple of 10.

        Returns:
            The thresholded LastLayer output: 0.0/1.0 values in bfloat16,
            one per sample.
        """
        batch = x.shape[0]
        # Build the initial state for the actual batch size, so a smaller
        # final batch or single-sample inference works. Match the LSTM
        # weights' device and dtype. Row 0 of initstate is the template.
        weight = self.lstmcell.weight_ih
        template = self.initstate[:1].to(device=weight.device, dtype=weight.dtype)
        hx = template.repeat(batch, 1)
        cx = template.repeat(batch, 1)

        x = x[:, None, :, :]
        x = self.batchnorm(x)

        for chunk in split(x, 10, dim=3):
            xi = self.chunkcnn(chunk)
            hx, cx = self.lstmcell(xi, (hx, cx))

        x = self.lastlayer(hx)

        # final decision
        x = (x > 0.5).bfloat16()
        return x