Spaces:
Runtime error
Runtime error
Upload pim_module.py
Browse files- pim_module.py +9 -8
pim_module.py
CHANGED
@@ -230,17 +230,18 @@ class FPN(nn.Module):
|
|
230 |
nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
|
231 |
)
|
232 |
elif proj_type == "Linear":
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
m = nn.Sequential(
|
234 |
-
|
235 |
-
if isinstance(in_feat, torch.Tensor):
|
236 |
-
dim = in_feat.size(-1)
|
237 |
-
else:
|
238 |
-
raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
|
239 |
-
nn.Linear(dim, dim)
|
240 |
-
# nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
|
241 |
nn.ReLU(),
|
242 |
-
nn.Linear(
|
243 |
)
|
|
|
244 |
self.add_module("Proj_"+node_name, m)
|
245 |
|
246 |
### upsample module
|
|
|
230 |
nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
|
231 |
)
|
232 |
elif proj_type == "Linear":
|
233 |
+
in_feat = inputs[node_name]
|
234 |
+
if isinstance(in_feat, torch.Tensor):
|
235 |
+
dim = in_feat.size(-1)
|
236 |
+
else:
|
237 |
+
raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
|
238 |
+
|
239 |
m = nn.Sequential(
|
240 |
+
nn.Linear(dim, dim),
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
nn.ReLU(),
|
242 |
+
nn.Linear(dim, fpn_size),
|
243 |
)
|
244 |
+
|
245 |
self.add_module("Proj_"+node_name, m)
|
246 |
|
247 |
### upsample module
|