pepitolechevalier committed on
Commit
14a8a64
·
verified ·
1 Parent(s): 8abdfb5

Upload pim_module.py

Browse files
Files changed (1) hide show
  1. pim_module.py +9 -8
pim_module.py CHANGED
@@ -230,17 +230,18 @@ class FPN(nn.Module):
230
  nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
231
  )
232
  elif proj_type == "Linear":
 
 
 
 
 
 
233
  m = nn.Sequential(
234
- in_feat = inputs[node_name]
235
- if isinstance(in_feat, torch.Tensor):
236
- dim = in_feat.size(-1)
237
- else:
238
- raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
239
- nn.Linear(dim, dim)
240
- # nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
241
  nn.ReLU(),
242
- nn.Linear(inputs[node_name].size(-1), fpn_size),
243
  )
 
244
  self.add_module("Proj_"+node_name, m)
245
 
246
  ### upsample module
 
230
  nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
231
  )
232
  elif proj_type == "Linear":
233
+ in_feat = inputs[node_name]
234
+ if isinstance(in_feat, torch.Tensor):
235
+ dim = in_feat.size(-1)
236
+ else:
237
+ raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
238
+
239
  m = nn.Sequential(
240
+ nn.Linear(dim, dim),
 
 
 
 
 
 
241
  nn.ReLU(),
242
+ nn.Linear(dim, fpn_size),
243
  )
244
+
245
  self.add_module("Proj_"+node_name, m)
246
 
247
  ### upsample module