HenryTsui committed
Commit bfb13af · unverified · 2 parents: fbb556e d15523e

Merge pull request #18 from LucyTuan/MODEL


🔨 [Add] RepNCSPELAN and base modules in module.py

Files changed (1):
  1. yolo/model/module.py (+121, −36)
yolo/model/module.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -142,49 +142,134 @@ class MultiheadDetection(nn.Module):
 #### -- ####
 # RepVGG
 class RepConv(nn.Module):
-    # https://github.com/DingXiaoH/RepVGG
+    """A convolutional block that combines two convolution layers (kernel and point-wise)."""
+
     def __init__(
-        self, in_channels, out_channels, kernel_size=3, padding=None, stride=1, groups=1, act=nn.SiLU(), deploy=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_2_t = 3,
+        *,
+        activation: Optional[str] = "SiLU",
+        **kwargs
     ):
-
         super().__init__()
-        self.deploy = deploy
-        self.conv1 = Conv(in_channels, out_channels, kernel_size, stride=stride, groups=groups, activation=False)
-        self.conv2 = Conv(in_channels, out_channels, 1, stride=stride, groups=groups, activation=False)
-        self.act = act if isinstance(act, nn.Module) else nn.Identity()
+        self.act = get_activation(activation)
+        self.conv1 = Conv(in_channels, out_channels, kernel_size, activation=False, **kwargs)
+        self.conv2 = Conv(in_channels, out_channels, 1, activation=False, **kwargs)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.act(self.conv1(x) + self.conv2(x))
 
-    def forward_fuse(self, x):
-        return self.act(self.conv(x))
-
-    # to be implement
-    # def fuse_convs(self):
-    def fuse_conv_bn(self, conv, bn):
-
-        std = (bn.running_var + bn.eps).sqrt()
-        bias = bn.bias - bn.running_mean * bn.weight / std
-
-        t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        weights = conv.weight * t
-
-        bn = nn.Identity()
-        conv = nn.Conv2d(
-            in_channels=conv.in_channels,
-            out_channels=conv.out_channels,
-            kernel_size=conv.kernel_size,
-            stride=conv.stride,
-            padding=conv.padding,
-            dilation=conv.dilation,
-            groups=conv.groups,
-            bias=True,
-            padding_mode=conv.padding_mode,
-        )
-
-        conv.weight = torch.nn.Parameter(weights)
-        conv.bias = torch.nn.Parameter(bias)
-        return conv
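
The removed forward_fuse/fuse_conv_bn path was the start of RepVGG-style structural reparameterization: fold each branch's BatchNorm into its convolution, then collapse the 3x3 and 1x1 branches into a single conv for inference. For reference, a standalone sketch of the BN-folding step the deleted helper performed (assuming, as with the Conv blocks here, that the convolution carries no bias of its own):

import torch
from torch import nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # BN computes gamma * (x - mean) / sqrt(var + eps) + beta; folding those
    # statistics into the conv weights yields an equivalent single layer.
    std = (bn.running_var + bn.eps).sqrt()
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
    )
    fused.weight.data = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
    fused.bias.data = bn.bias - bn.running_mean * bn.weight / std
    return fused

The remaining step (the unimplemented fuse_convs) would zero-pad the fused 1x1 kernel to 3x3 and sum it with the fused 3x3 kernel, adding the two biases.
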
 
+
+class RepNBottleneck(nn.Module):
+    """A bottleneck block with optional residual connections."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        *,
+        kernel_size: Tuple[int, int] = (3, 3),
+        residual: bool = True,
+        expand: float = 1.0,
+        **kwargs
+    ):
+        super().__init__()
+        neck_channels = int(out_channels * expand)
+        self.conv1 = RepConv(in_channels, neck_channels, kernel_size[0], **kwargs)
+        self.conv2 = Conv(neck_channels, out_channels, kernel_size[1], **kwargs)
+        self.residual = residual
+
+        if residual and (in_channels != out_channels):
+            self.residual = False
+            logging.warning("Residual is turned off since in_channels is not equal to out_channels.")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        y = self.conv2(self.conv1(x))
+        return x + y if self.residual else y
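
A quick usage sketch (hypothetical: it assumes these classes are importable from yolo.model.module and that Conv pads to preserve spatial size, which the residual add requires):

import torch
from yolo.model.module import RepNBottleneck  # assumed import path

x = torch.randn(1, 64, 32, 32)
block = RepNBottleneck(64, 64)   # in_channels == out_channels, residual stays on
print(block(x).shape)            # expected: torch.Size([1, 64, 32, 32])

wide = RepNBottleneck(64, 128)   # channel change: residual auto-disabled, warning logged
print(wide(x).shape)             # expected: torch.Size([1, 128, 32, 32])
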
+
+
+class RepNCSP(nn.Module):
+    """RepNCSP block with convolutions, split, and bottleneck processing."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 1,
+        *,
+        csp_expand: float = 0.5,
+        repeat_num: int = 1,
+        bottleneck_args: Optional[Dict[str, Any]] = None,
+        **kwargs
+    ):
+        super().__init__()
+
+        if bottleneck_args is None:
+            bottleneck_args = {"kernel_size": (3, 3), "residual": True, "expand": 0.5}
+
+        neck_channels = int(out_channels * csp_expand)
+        self.conv1 = Conv(in_channels, neck_channels, kernel_size, **kwargs)
+        self.conv2 = Conv(in_channels, neck_channels, kernel_size, **kwargs)
+        self.conv3 = Conv(2 * neck_channels, out_channels, kernel_size, **kwargs)
+
+        self.bottleneck_block = nn.Sequential(
+            *[RepNBottleneck(neck_channels, neck_channels, **bottleneck_args) for _ in range(repeat_num)]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        input_features = self.conv1(x)
+        split_features = self.conv2(x)
+        bottleneck_output = self.bottleneck_block(input_features)
+        return self.conv3(torch.cat((bottleneck_output, split_features), dim=1))
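
The forward pass follows the CSP pattern: two parallel 1x1 projections of the input, one run through the repeated bottlenecks, then concatenation and a 1x1 merge. A hypothetical shape walk-through under the same import assumption as above:

import torch
from yolo.model.module import RepNCSP  # assumed import path

csp = RepNCSP(64, 64, repeat_num=3)    # csp_expand=0.5 -> two 32-channel branches
x = torch.randn(1, 64, 32, 32)
# conv1/conv2: 64 -> 32 each; bottlenecks keep 32; cat -> 64; conv3: 64 -> 64
print(csp(x).shape)                    # expected: torch.Size([1, 64, 32, 32])
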
+
+
+class RepNCSPELAN(nn.Module):
+    """RepNCSPELAN block combining RepNCSP blocks with ELAN structure."""
+
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        out_channels: int,
+        partition_channels: int,
+        process_channels: int,
+        expand: float,
+        repncsp_args: Optional[Dict[str, Any]] = None,
+        bottleneck_args: Optional[Dict[str, Any]] = None,
+        **kwargs
+    ):
+        super().__init__()
+
+        if repncsp_args is None:
+            repncsp_args = {}
+
+        self.conv1 = Conv(in_channels, partition_channels, 1, **kwargs)
+        self.conv2 = nn.Sequential(
+            RepNCSP(
+                partition_channels // 2,
+                process_channels,
+                csp_expand=expand,
+                bottleneck_args=bottleneck_args,
+                **repncsp_args
+            ),
+            Conv(process_channels, process_channels, 3, padding=1, **kwargs),
+        )
+        self.conv3 = nn.Sequential(
+            RepNCSP(
+                process_channels, process_channels, csp_expand=expand, bottleneck_args=bottleneck_args, **repncsp_args
+            ),
+            Conv(process_channels, process_channels, 3, padding=1, **kwargs),
+        )
+        self.conv4 = Conv(partition_channels + 2 * process_channels, out_channels, 1, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        partition1, partition2 = self.conv1(x).chunk(2, 1)
+        csp_output1 = self.conv2(partition2)
+        csp_output2 = self.conv3(csp_output1)
+        concat = torch.cat([partition1, partition2, csp_output1, csp_output2], dim=1)
+        return self.conv4(concat)
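
The concatenation collects both raw partitions plus both processed stages, so conv4 sees partition_channels + 2 * process_channels input channels. A hypothetical instantiation under the same import assumption (all constructor arguments are keyword-only here):

import torch
from yolo.model.module import RepNCSPELAN  # assumed import path

elan = RepNCSPELAN(
    in_channels=64,
    out_channels=128,
    partition_channels=64,   # chunked into two 32-channel partitions
    process_channels=64,     # width of each RepNCSP + Conv stage
    expand=0.5,
)
x = torch.randn(1, 64, 40, 40)
# concat: 32 + 32 + 64 + 64 = 192 channels -> conv4 (1x1) -> 128
print(elan(x).shape)         # expected: torch.Size([1, 128, 40, 40])
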
 
 
 # ResNet