OGSO_IA / pim_module.py
RahmaDev's picture
Upload 856 files
80187e6 verified
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from typing import Union
import copy
class GCNCombiner(nn.Module):
def __init__(self,
total_num_selects: int,
num_classes: int,
inputs: Union[dict, None] = None,
proj_size: Union[int, None] = None,
fpn_size: Union[int, None] = None):
"""
If building backbone without FPN, set fpn_size to None and MUST give
'inputs' and 'proj_size', the reason of these setting is to constrain the
dimension of graph convolutional network input.
"""
super(GCNCombiner, self).__init__()
assert inputs is not None or fpn_size is not None, \
"To build GCN combiner, you must give one features dimension."
### auto-proj
self.fpn_size = fpn_size
if fpn_size is None:
for name in inputs:
if len(name) == 4:
in_size = inputs[name].size(1)
elif len(name) == 3:
in_size = inputs[name].size(2)
else:
raise ValusError("The size of output dimension of previous must be 3 or 4.")
m = nn.Sequential(
nn.Linear(in_size, proj_size),
nn.ReLU(),
nn.Linear(proj_size, proj_size)
)
self.add_module("proj_"+name, m)
self.proj_size = proj_size
else:
self.proj_size = fpn_size
### build one layer structure (with adaptive module)
num_joints = total_num_selects // 64
self.param_pool0 = nn.Linear(total_num_selects, num_joints)
A = torch.eye(num_joints) / 100 + 1 / 100
self.adj1 = nn.Parameter(copy.deepcopy(A))
self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1)
self.batch_norm1 = nn.BatchNorm1d(self.proj_size)
self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
self.alpha1 = nn.Parameter(torch.zeros(1))
### merge information
self.param_pool1 = nn.Linear(num_joints, 1)
#### class predict
self.dropout = nn.Dropout(p=0.1)
self.classifier = nn.Linear(self.proj_size, num_classes)
self.tanh = nn.Tanh()
def forward(self, x):
"""
"""
hs = []
names = []
for name in x:
if "FPN1_" in name:
continue
if self.fpn_size is None:
_tmp = getattr(self, "proj_"+name)(x[name])
else:
_tmp = x[name]
hs.append(_tmp)
names.append([name, _tmp.size()])
hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous() # B, S', C --> B, C, S
# print(hs.size(), names)
hs = self.param_pool0(hs)
### adaptive adjacency
q1 = self.conv_q1(hs).mean(1)
k1 = self.conv_k1(hs).mean(1)
A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1))
A1 = self.adj1 + A1 * self.alpha1
### graph convolution
hs = self.conv1(hs)
hs = torch.matmul(hs, A1)
hs = self.batch_norm1(hs)
### predict
hs = self.param_pool1(hs)
hs = self.dropout(hs)
hs = hs.flatten(1)
hs = self.classifier(hs)
return hs
class WeaklySelector(nn.Module):
def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None):
"""
inputs: dictionary contain torch.Tensors, which comes from backbone
[Tensor1(hidden feature1), Tensor2(hidden feature2)...]
Please note that if len(features.size) equal to 3, the order of dimension must be [B,S,C],
S mean the spatial domain, and if len(features.size) equal to 4, the order must be [B,C,H,W]
"""
super(WeaklySelector, self).__init__()
self.num_select = num_select
self.fpn_size = fpn_size
### build classifier
if self.fpn_size is None:
self.num_classes = num_classes
for name in inputs:
fs_size = inputs[name].size()
if len(fs_size) == 3:
in_size = fs_size[2]
elif len(fs_size) == 4:
in_size = fs_size[1]
m = nn.Linear(in_size, num_classes)
self.add_module("classifier_l_"+name, m)
self.thresholds = {}
for name in inputs:
self.thresholds[name] = []
# def select(self, logits, l_name):
# """
# logits: [B, S, num_classes]
# """
# probs = torch.softmax(logits, dim=-1)
# scores, _ = torch.max(probs, dim=-1)
# _, ids = torch.sort(scores, -1, descending=True)
# sn = self.num_select[l_name]
# s_ids = ids[:, :sn]
# not_s_ids = ids[:, sn:]
# return s_ids.unsqueeze(-1), not_s_ids.unsqueeze(-1)
def forward(self, x, logits=None):
"""
x :
dictionary contain the features maps which
come from your choosen layers.
size must be [B, HxW, C] ([B, S, C]) or [B, C, H, W].
[B,C,H,W] will be transpose to [B, HxW, C] automatically.
"""
if self.fpn_size is None:
logits = {}
selections = {}
for name in x:
# print("[selector]", name, x[name].size())
if "FPN1_" in name:
continue
if len(x[name].size()) == 4:
B, C, H, W = x[name].size()
x[name] = x[name].view(B, C, H*W).permute(0, 2, 1).contiguous()
C = x[name].size(-1)
if self.fpn_size is None:
logits[name] = getattr(self, "classifier_l_"+name)(x[name])
probs = torch.softmax(logits[name], dim=-1)
sum_probs = torch.softmax(logits[name].mean(1), dim=-1)
selections[name] = []
preds_1 = []
preds_0 = []
num_select = self.num_select[name]
for bi in range(logits[name].size(0)):
_, max_ids = torch.max(sum_probs[bi], dim=-1)
confs, ranks = torch.sort(probs[bi, :, max_ids], descending=True)
sf = x[name][bi][ranks[:num_select]]
nf = x[name][bi][ranks[num_select:]] # calculate
selections[name].append(sf) # [num_selected, C]
preds_1.append(logits[name][bi][ranks[:num_select]])
preds_0.append(logits[name][bi][ranks[num_select:]])
if bi >= len(self.thresholds[name]):
self.thresholds[name].append(confs[num_select]) # for initialize
else:
self.thresholds[name][bi] = confs[num_select]
selections[name] = torch.stack(selections[name])
preds_1 = torch.stack(preds_1)
preds_0 = torch.stack(preds_0)
logits["select_"+name] = preds_1
logits["drop_"+name] = preds_0
return selections
class FPN(nn.Module):
def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str):
"""
inputs : dictionary contains torch.Tensor
which comes from backbone output
fpn_size: integer, fpn
proj_type:
in ["Conv", "Linear"]
upsample_type:
in ["Bilinear", "Conv", "Fc"]
for convolution neural network (e.g. ResNet, EfficientNet), recommand 'Bilinear'.
for Vit, "Fc". and Swin-T, "Conv"
"""
super(FPN, self).__init__()
assert proj_type in ["Conv", "Linear"], \
"FPN projection type {} were not support yet, please choose type 'Conv' or 'Linear'".format(proj_type)
assert upsample_type in ["Bilinear", "Conv"], \
"FPN upsample type {} were not support yet, please choose type 'Bilinear' or 'Conv'".format(proj_type)
self.fpn_size = fpn_size
self.upsample_type = upsample_type
inp_names = [name for name in inputs]
for i, node_name in enumerate(inputs):
### projection module
if proj_type == "Conv":
m = nn.Sequential(
nn.Conv2d(inputs[node_name].size(1), inputs[node_name].size(1), 1),
nn.ReLU(),
nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
)
elif proj_type == "Linear":
in_feat = inputs[node_name]
if isinstance(in_feat, torch.Tensor):
dim = in_feat.size(-1)
else:
raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
m = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, fpn_size),
)
self.add_module("Proj_"+node_name, m)
### upsample module
if upsample_type == "Conv" and i != 0:
assert len(inputs[node_name].size()) == 3 # B, S, C
in_dim = inputs[node_name].size(1)
out_dim = inputs[inp_names[i-1]].size(1)
# if in_dim != out_dim:
m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain
# else:
# m = nn.Identity()
self.add_module("Up_"+node_name, m)
if upsample_type == "Bilinear":
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str):
"""
return Upsample(x1) + x1
"""
if self.upsample_type == "Bilinear":
if x1.size(-1) != x0.size(-1):
x1 = self.upsample(x1)
else:
x1 = getattr(self, "Up_"+x1_name)(x1)
return x1 + x0
def forward(self, x):
"""
x : dictionary
{
"node_name1": feature1,
"node_name2": feature2, ...
}
"""
### project to same dimension
hs = []
for i, name in enumerate(x):
if "FPN1_" in name:
continue
x[name] = getattr(self, "Proj_"+name)(x[name])
hs.append(name)
x["FPN1_" + "layer4"] = x["layer4"]
for i in range(len(hs)-1, 0, -1):
x1_name = hs[i]
x0_name = hs[i-1]
x[x0_name] = self.upsample_add(x[x0_name],
x[x1_name],
x1_name)
x["FPN1_" + x0_name] = x[x0_name]
return x
class FPN_UP(nn.Module):
def __init__(self,
inputs: dict,
fpn_size: int):
super(FPN_UP, self).__init__()
inp_names = [name for name in inputs]
for i, node_name in enumerate(inputs):
### projection module
m = nn.Sequential(
nn.Linear(fpn_size, fpn_size),
nn.ReLU(),
nn.Linear(fpn_size, fpn_size),
)
self.add_module("Proj_"+node_name, m)
### upsample module
if i != (len(inputs) - 1):
assert len(inputs[node_name].size()) == 3 # B, S, C
in_dim = inputs[node_name].size(1)
out_dim = inputs[inp_names[i+1]].size(1)
m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain
self.add_module("Down_"+node_name, m)
# print("Down_"+node_name, in_dim, out_dim)
"""
Down_layer1 2304 576
Down_layer2 576 144
Down_layer3 144 144
"""
def downsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x0_name: str):
"""
return Upsample(x1) + x1
"""
# print("[downsample_add] Down_" + x0_name)
x0 = getattr(self, "Down_" + x0_name)(x0)
return x1 + x0
def forward(self, x):
"""
x : dictionary
{
"node_name1": feature1,
"node_name2": feature2, ...
}
"""
### project to same dimension
hs = []
for i, name in enumerate(x):
if "FPN1_" in name:
continue
x[name] = getattr(self, "Proj_"+name)(x[name])
hs.append(name)
# print(hs)
for i in range(0, len(hs) - 1):
x0_name = hs[i]
x1_name = hs[i+1]
# print(x0_name, x1_name)
# print(x[x0_name].size(), x[x1_name].size())
x[x1_name] = self.downsample_add(x[x0_name],
x[x1_name],
x0_name)
return x
class PluginMoodel(nn.Module):
def __init__(self,
backbone: torch.nn.Module,
return_nodes: Union[dict, None],
img_size: int,
use_fpn: bool,
fpn_size: Union[int, None],
proj_type: str,
upsample_type: str,
use_selection: bool,
num_classes: int,
num_selects: dict,
use_combiner: bool,
comb_proj_size: Union[int, None]
):
"""
* backbone:
torch.nn.Module class (recommand pretrained on ImageNet or IG-3.5B-17k(provided by FAIR))
* return_nodes:
e.g.
return_nodes = {
# node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4',
} # you can see the example on https://pytorch.org/vision/main/feature_extraction.html
!!! if using 'Swin-Transformer', please set return_nodes to None
!!! and please set use_fpn to True
* feat_sizes:
tuple or list contain features map size of each layers.
((C, H, W)). e.g. ((1024, 14, 14), (2048, 7, 7))
* use_fpn:
boolean, use features pyramid network or not
* fpn_size:
integer, features pyramid network projection dimension
* num_selects:
num_selects = {
# match user-specified in return_nodes
"layer1": 2048,
"layer2": 512,
"layer3": 128,
"layer4": 32,
}
Note: after selector module (WeaklySelector) , the feature map's size is [B, S', C] which
contained by 'logits' or 'selections' dictionary (S' is selection number, different layer
could be different).
"""
super(PluginMoodel, self).__init__()
### = = = = = Backbone = = = = =
self.return_nodes = return_nodes
if return_nodes is not None:
self.backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
else:
self.backbone = backbone
### get hidden feartues size
rand_in = torch.randn(1, 3, img_size, img_size)
outs = self.backbone(rand_in)
### just original backbone
if not use_fpn and (not use_selection and not use_combiner):
for name in outs:
fs_size = outs[name].size()
if len(fs_size) == 3:
out_size = fs_size.size(-1)
elif len(fs_size) == 4:
out_size = fs_size.size(1)
else:
raise ValusError("The size of output dimension of previous must be 3 or 4.")
self.classifier = nn.Linear(out_size, num_classes)
### = = = = = FPN = = = = =
self.use_fpn = use_fpn
if self.use_fpn:
self.fpn_down = FPN(outs, fpn_size, proj_type, upsample_type)
self.build_fpn_classifier_down(outs, fpn_size, num_classes)
self.fpn_up = FPN_UP(outs, fpn_size)
self.build_fpn_classifier_up(outs, fpn_size, num_classes)
self.fpn_size = fpn_size
### = = = = = Selector = = = = =
self.use_selection = use_selection
if self.use_selection:
w_fpn_size = self.fpn_size if self.use_fpn else None # if not using fpn, build classifier in weakly selector
self.selector = WeaklySelector(outs, num_classes, num_selects, w_fpn_size)
### = = = = = Combiner = = = = =
self.use_combiner = use_combiner
if self.use_combiner:
assert self.use_selection, "Please use selection module before combiner"
if self.use_fpn:
gcn_inputs, gcn_proj_size = None, None
else:
gcn_inputs, gcn_proj_size = outs, comb_proj_size # redundant, fix in future
total_num_selects = sum([num_selects[name] for name in num_selects]) # sum
self.combiner = GCNCombiner(total_num_selects, num_classes, gcn_inputs, gcn_proj_size, self.fpn_size)
def build_fpn_classifier_up(self, inputs: dict, fpn_size: int, num_classes: int):
"""
Teh results of our experiments show that linear classifier in this case may cause some problem.
"""
for name in inputs:
m = nn.Sequential(
nn.Conv1d(fpn_size, fpn_size, 1),
nn.BatchNorm1d(fpn_size),
nn.ReLU(),
nn.Conv1d(fpn_size, num_classes, 1)
)
self.add_module("fpn_classifier_up_"+name, m)
def build_fpn_classifier_down(self, inputs: dict, fpn_size: int, num_classes: int):
"""
Teh results of our experiments show that linear classifier in this case may cause some problem.
"""
for name in inputs:
m = nn.Sequential(
nn.Conv1d(fpn_size, fpn_size, 1),
nn.BatchNorm1d(fpn_size),
nn.ReLU(),
nn.Conv1d(fpn_size, num_classes, 1)
)
self.add_module("fpn_classifier_down_" + name, m)
def forward_backbone(self, x):
return self.backbone(x)
def fpn_predict_down(self, x: dict, logits: dict):
"""
x: [B, C, H, W] or [B, S, C]
[B, C, H, W] --> [B, H*W, C]
"""
for name in x:
if "FPN1_" not in name:
continue
### predict on each features point
if len(x[name].size()) == 4:
B, C, H, W = x[name].size()
logit = x[name].view(B, C, H*W)
elif len(x[name].size()) == 3:
logit = x[name].transpose(1, 2).contiguous()
model_name = name.replace("FPN1_", "")
logits[name] = getattr(self, "fpn_classifier_down_" + model_name)(logit)
logits[name] = logits[name].transpose(1, 2).contiguous() # transpose
def fpn_predict_up(self, x: dict, logits: dict):
"""
x: [B, C, H, W] or [B, S, C]
[B, C, H, W] --> [B, H*W, C]
"""
for name in x:
if "FPN1_" in name:
continue
### predict on each features point
if len(x[name].size()) == 4:
B, C, H, W = x[name].size()
logit = x[name].view(B, C, H*W)
elif len(x[name].size()) == 3:
logit = x[name].transpose(1, 2).contiguous()
model_name = name.replace("FPN1_", "")
logits[name] = getattr(self, "fpn_classifier_up_" + model_name)(logit)
logits[name] = logits[name].transpose(1, 2).contiguous() # transpose
def forward(self, x: torch.Tensor):
logits = {}
x = self.forward_backbone(x)
if self.use_fpn:
x = self.fpn_down(x)
# print([name for name in x])
self.fpn_predict_down(x, logits)
x = self.fpn_up(x)
self.fpn_predict_up(x, logits)
if self.use_selection:
selects = self.selector(x, logits)
if self.use_combiner:
comb_outs = self.combiner(selects)
logits['comb_outs'] = comb_outs
return logits
if self.use_selection or self.fpn:
return logits
### original backbone (only predict final selected layer)
for name in x:
hs = x[name]
if len(hs.size()) == 4:
hs = F.adaptive_avg_pool2d(hs, (1, 1))
hs = hs.flatten(1)
else:
hs = hs.mean(1)
out = self.classifier(hs)
logits['ori_out'] = logits
return