|
import torch |
|
import torch.nn as nn |
|
import torchvision.models as models |
|
import torch.nn.functional as F |
|
from torchvision.models.feature_extraction import get_graph_node_names |
|
from torchvision.models.feature_extraction import create_feature_extractor |
|
from typing import Union |
|
import copy |
|
|
|
class GCNCombiner(nn.Module):
    """Fuse the selected tokens of all layers with one graph-convolution step.

    The per-layer selections are concatenated along the token axis, pooled
    into ``total_num_selects // 64`` graph nodes by a linear layer, mixed by a
    graph convolution whose adjacency is a learned matrix plus a gated
    query/key term, pooled to a single node and classified.
    """

    def __init__(self,
                 total_num_selects: int,
                 num_classes: int,
                 inputs: Union[dict, None] = None,
                 proj_size: Union[int, None] = None,
                 fpn_size: Union[int, None] = None):
        """
        If building the backbone without FPN, set ``fpn_size`` to None and you
        MUST give ``inputs`` and ``proj_size``: each layer is then projected to
        ``proj_size`` channels so that the graph convolution always receives
        the same feature dimension.

        Args:
            total_num_selects: total number of selected tokens over all
                layers, i.e. the length of the concatenated token axis in
                ``forward``.
            num_classes: number of output classes.
            inputs: dict of example tensors, only used to read each layer's
                channel count (dim 1 of a 4-D ``[B, C, H, W]`` tensor, dim 2 of
                a 3-D ``[B, S, C]`` tensor). Needed when ``fpn_size`` is None.
            proj_size: projection width used when ``fpn_size`` is None.
            fpn_size: feature width of the FPN features; when given, features
                are used without projection.

        Raises:
            ValueError: if a tensor in ``inputs`` is neither 3-D nor 4-D.
        """
        super(GCNCombiner, self).__init__()

        assert inputs is not None or fpn_size is not None, \
            "To build GCN combiner, you must give one features dimension."

        self.fpn_size = fpn_size
        if fpn_size is None:
            for name in inputs:
                # Dispatch on the tensor's rank (not on the length of its key).
                ndim = len(inputs[name].size())
                if ndim == 4:
                    in_size = inputs[name].size(1)
                elif ndim == 3:
                    in_size = inputs[name].size(2)
                else:
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
                m = nn.Sequential(
                    nn.Linear(in_size, proj_size),
                    nn.ReLU(),
                    nn.Linear(proj_size, proj_size)
                )
                self.add_module("proj_"+name, m)
            self.proj_size = proj_size
        else:
            self.proj_size = fpn_size

        # Number of graph nodes the token axis is pooled into.
        num_joints = total_num_selects // 64

        self.param_pool0 = nn.Linear(total_num_selects, num_joints)

        # Learned static adjacency, initialised to 0.02 on the diagonal and
        # 0.01 elsewhere.
        A = torch.eye(num_joints) / 100 + 1 / 100
        self.adj1 = nn.Parameter(copy.deepcopy(A))
        self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1)
        self.batch_norm1 = nn.BatchNorm1d(self.proj_size)

        # Query/key projections for the data-dependent adjacency term; alpha1
        # starts at 0, so that term is initially switched off.
        self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
        self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
        self.alpha1 = nn.Parameter(torch.zeros(1))

        # Pools the graph nodes into a single node.
        self.param_pool1 = nn.Linear(num_joints, 1)

        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(self.proj_size, num_classes)

        self.tanh = nn.Tanh()

    def forward(self, x):
        """Classify the concatenation of all selected features.

        Args:
            x: dict name -> tensor ``[B, S_name, C]``. Keys containing
                ``"FPN1_"`` are skipped. The ``S_name`` must sum to
                ``total_num_selects``; ``C`` is the input width of the
                ``proj_<name>`` layer (no FPN) or ``fpn_size``.

        Returns:
            Tensor ``[B, num_classes]``.
        """
        hs = []
        for name in x:
            if "FPN1_" in name:
                continue
            if self.fpn_size is None:
                _tmp = getattr(self, "proj_"+name)(x[name])
            else:
                _tmp = x[name]
            hs.append(_tmp)

        # [B, S, C] -> [B, C, S]
        hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous()

        # Pool tokens into graph nodes: [B, C, num_joints].
        hs = self.param_pool0(hs)

        # Data-dependent adjacency [B, J, J] from pairwise query/key differences.
        q1 = self.conv_q1(hs).mean(1)
        k1 = self.conv_k1(hs).mean(1)
        A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1))
        A1 = self.adj1 + A1 * self.alpha1

        # Graph convolution: per-node transform, then propagation over A1.
        hs = self.conv1(hs)
        hs = torch.matmul(hs, A1)
        hs = self.batch_norm1(hs)

        # Pool nodes to one: [B, C, 1] -> [B, C] -> logits.
        hs = self.param_pool1(hs)
        hs = self.dropout(hs)
        hs = hs.flatten(1)
        hs = self.classifier(hs)

        return hs
|
|
|
class WeaklySelector(nn.Module):
    """Keep the ``num_select[name]`` most confident tokens of every layer.

    For each sample, the class with the highest probability under the
    token-averaged logits is used as a pseudo label. Tokens are ranked by
    their own probability for that class and split into a selected and a
    dropped set.
    """

    def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None):
        """
        Args:
            inputs: dict of example tensors from the backbone, only used to
                read their channel dimension. A 3-D tensor must be ordered
                ``[B, S, C]`` (S is the spatial/token axis) and a 4-D tensor
                ``[B, C, H, W]``.
            num_classes: number of classes of the per-token classifiers.
            num_select: dict name -> number of tokens to keep for that layer.
            fpn_size: when None, a linear token classifier
                ``classifier_l_<name>`` is built for every input; otherwise
                token logits must be supplied to ``forward``.

        Raises:
            ValueError: if an input is neither 3-D nor 4-D (checked only when
                ``fpn_size`` is None).
        """
        super(WeaklySelector, self).__init__()

        self.num_select = num_select
        self.fpn_size = fpn_size

        if self.fpn_size is None:
            self.num_classes = num_classes
            for name in inputs:
                fs_size = inputs[name].size()
                if len(fs_size) == 3:
                    in_size = fs_size[2]
                elif len(fs_size) == 4:
                    in_size = fs_size[1]
                else:
                    # Reject other ranks instead of silently reusing the
                    # in_size of the previous layer.
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
                m = nn.Linear(in_size, num_classes)
                self.add_module("classifier_l_"+name, m)

        # thresholds[name][bi]: probability (for the pseudo label) of the
        # highest-ranked dropped token of batch index bi in the last forward.
        self.thresholds = {}
        for name in inputs:
            self.thresholds[name] = []

    def forward(self, x, logits=None):
        """
        Args:
            x: dict of feature maps from the chosen layers, each
                ``[B, S, C]`` or ``[B, C, H, W]``. 4-D maps are rewritten
                in ``x`` to ``[B, H*W, C]``. Keys containing ``"FPN1_"``
                are skipped.
            logits: dict of token logits ``[B, S, num_classes]`` keyed like
                ``x``; required when ``fpn_size`` is set. When ``fpn_size``
                is None, the token logits are computed here and stored in this
                dict (a new dict is created if None is given).

        Side effects:
            Adds ``"select_<name>"`` / ``"drop_<name>"`` logits to ``logits``
            and updates ``self.thresholds``. ``num_select[name]`` must be
            smaller than S, since the first dropped confidence is recorded.

        Returns:
            dict name -> selected features ``[B, num_select[name], C]``.
        """
        # Write into the caller's dict so the token logits and the
        # select/drop logits are visible to the caller (and its losses).
        if self.fpn_size is None and logits is None:
            logits = {}
        selections = {}
        for name in x:
            if "FPN1_" in name:
                continue
            if len(x[name].size()) == 4:
                B, C, H, W = x[name].size()
                x[name] = x[name].view(B, C, H*W).permute(0, 2, 1).contiguous()
            if self.fpn_size is None:
                logits[name] = getattr(self, "classifier_l_"+name)(x[name])

            probs = torch.softmax(logits[name], dim=-1)
            # Image-level class distribution from the token-averaged logits.
            sum_probs = torch.softmax(logits[name].mean(1), dim=-1)
            selections[name] = []
            preds_1 = []
            preds_0 = []
            num_select = self.num_select[name]
            for bi in range(logits[name].size(0)):
                _, max_ids = torch.max(sum_probs[bi], dim=-1)
                confs, ranks = torch.sort(probs[bi, :, max_ids], descending=True)
                sf = x[name][bi][ranks[:num_select]]
                selections[name].append(sf)
                preds_1.append(logits[name][bi][ranks[:num_select]])
                preds_0.append(logits[name][bi][ranks[num_select:]])

                # NOTE(review): stored without detach(), so the last forward's
                # graph stays referenced; consider .detach() if nothing
                # differentiates through thresholds.
                if bi >= len(self.thresholds[name]):
                    self.thresholds[name].append(confs[num_select])
                else:
                    self.thresholds[name][bi] = confs[num_select]

            selections[name] = torch.stack(selections[name])
            preds_1 = torch.stack(preds_1)
            preds_0 = torch.stack(preds_0)

            logits["select_"+name] = preds_1
            logits["drop_"+name] = preds_0

        return selections
|
|
|
|
|
class FPN(nn.Module):
    """Top-down feature pyramid.

    Every layer is projected to ``fpn_size`` channels, then, from the last
    (top) layer downwards, each layer is upsampled and added into the layer
    before it. The merged maps are also exposed under ``"FPN1_<name>"`` keys.
    """

    def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str):
        """
        Args:
            inputs: dict of example tensors from the backbone, ordered from
                the lowest to the top layer (each entry is merged into the
                previous one).
            fpn_size: output width of every projection.
            proj_type: "Conv" (1x1 Conv2d, expects ``[B, C, H, W]``) or
                "Linear" (acts on the last dimension, e.g. ``[B, S, C]``).
            upsample_type: "Bilinear" or "Conv". 'Bilinear' is recommended
                for convolutional networks (e.g. ResNet, EfficientNet) and
                'Conv' for Swin-T. "Conv" requires 3-D inputs and maps the
                token count of a layer to that of the previous layer with a
                1x1 Conv1d. Only these two modes are implemented.

        Raises:
            AssertionError: for an unsupported ``proj_type``/``upsample_type``.
            ValueError: if a "Linear" projection input is not a tensor.
        """
        super(FPN, self).__init__()
        assert proj_type in ["Conv", "Linear"], \
            "FPN projection type {} were not support yet, please choose type 'Conv' or 'Linear'".format(proj_type)
        assert upsample_type in ["Bilinear", "Conv"], \
            "FPN upsample type {} were not support yet, please choose type 'Bilinear' or 'Conv'".format(upsample_type)

        self.fpn_size = fpn_size
        self.upsample_type = upsample_type
        inp_names = [name for name in inputs]

        for i, node_name in enumerate(inputs):
            if proj_type == "Conv":
                in_ch = inputs[node_name].size(1)
                m = nn.Sequential(
                    nn.Conv2d(in_ch, in_ch, 1),
                    nn.ReLU(),
                    nn.Conv2d(in_ch, fpn_size, 1)
                )
            elif proj_type == "Linear":
                in_feat = inputs[node_name]
                if isinstance(in_feat, torch.Tensor):
                    dim = in_feat.size(-1)
                else:
                    # Message (French): "invalid input in FPN: <type> for node_name=...".
                    raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
                m = nn.Sequential(
                    nn.Linear(dim, dim),
                    nn.ReLU(),
                    nn.Linear(dim, fpn_size),
                )
            self.add_module("Proj_"+node_name, m)

            # Learned mapping from this layer's token count to the previous
            # layer's token count (every layer except the first one).
            if upsample_type == "Conv" and i != 0:
                assert len(inputs[node_name].size()) == 3
                in_dim = inputs[node_name].size(1)
                out_dim = inputs[inp_names[i-1]].size(1)
                m = nn.Conv1d(in_dim, out_dim, 1)
                self.add_module("Up_"+node_name, m)

        if upsample_type == "Bilinear":
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str):
        """Return ``Upsample(x1) + x0``.

        With "Bilinear", ``x1`` is upsampled 2x only when its last dimension
        differs from ``x0``'s; with "Conv", the ``Up_<x1_name>`` layer is used.
        """
        if self.upsample_type == "Bilinear":
            if x1.size(-1) != x0.size(-1):
                x1 = self.upsample(x1)
        else:
            x1 = getattr(self, "Up_"+x1_name)(x1)
        return x1 + x0

    def forward(self, x):
        """
        Args:
            x: dict ``{"node_name1": feature1, "node_name2": feature2, ...}``
                ordered from the lowest to the top layer. Keys containing
                ``"FPN1_"`` are ignored.

        Returns:
            The same dict, updated in place: every layer is replaced by its
            merged map, which is also stored under ``"FPN1_<name>"``.
        """
        hs = []
        for name in x:
            if "FPN1_" in name:
                continue
            x[name] = getattr(self, "Proj_"+name)(x[name])
            hs.append(name)

        # The top layer has nothing above it: its output is its projection.
        # (Use the actual top layer name instead of assuming "layer4".)
        if hs:
            x["FPN1_" + hs[-1]] = x[hs[-1]]

        for i in range(len(hs)-1, 0, -1):
            x1_name = hs[i]
            x0_name = hs[i-1]
            x[x0_name] = self.upsample_add(x[x0_name],
                                           x[x1_name],
                                           x1_name)
            x["FPN1_" + x0_name] = x[x0_name]

        return x
|
|
|
|
|
class FPN_UP(nn.Module):
    """Bottom-up pathway applied after the top-down FPN.

    Each layer is re-projected with a two-layer MLP on its last dimension.
    Then, walking the layers in input order, each layer is mapped to the next
    layer's token count with a 1x1 ``Conv1d`` and added into that next layer.
    """

    def __init__(self,
                 inputs: dict,
                 fpn_size: int):
        """
        Args:
            inputs: dict of example tensors in bottom-up order. Every entry
                except the last must be 3-D; its dim 1 (token count) and the
                next entry's dim 1 size the ``Down_<name>`` convolution.
            fpn_size: feature width (last dimension) of the features.
        """
        super(FPN_UP, self).__init__()

        node_names = list(inputs)
        last_idx = len(node_names) - 1

        for idx, node in enumerate(node_names):
            self.add_module("Proj_" + node, nn.Sequential(
                nn.Linear(fpn_size, fpn_size),
                nn.ReLU(),
                nn.Linear(fpn_size, fpn_size),
            ))

            # The top layer has no successor to feed into.
            if idx == last_idx:
                continue

            assert len(inputs[node].size()) == 3
            src_tokens = inputs[node].size(1)
            dst_tokens = inputs[node_names[idx + 1]].size(1)
            self.add_module("Down_" + node, nn.Conv1d(src_tokens, dst_tokens, 1))

        # Example token counts noted by the original author (they depend on
        # the backbone and image size):
        #   Down_layer1 2304 -> 576, Down_layer2 576 -> 144,
        #   Down_layer3 144 -> 144

    def downsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x0_name: str):
        """Return ``Down_<x0_name>(x0) + x1``."""
        down = getattr(self, "Down_" + x0_name)
        return x1 + down(x0)

    def forward(self, x):
        """
        Args:
            x: dict ``{"node_name1": feature1, "node_name2": feature2, ...}``
                in bottom-up order, features with last dimension
                ``fpn_size``. Keys containing ``"FPN1_"`` are left untouched.

        Returns:
            The same dict, updated in place.
        """
        order = [key for key in x if "FPN1_" not in key]

        for key in order:
            x[key] = getattr(self, "Proj_" + key)(x[key])

        # Each step reads the lower layer as updated by the previous step.
        for lower, upper in zip(order, order[1:]):
            x[upper] = self.downsample_add(x[lower], x[upper], lower)

        return x
|
|
|
|
|
|
|
|
|
class PluginMoodel(nn.Module):
    """Backbone with optional FPN, weakly-supervised token selector and GCN
    combiner heads; without any of them, a linear classifier on the pooled
    last backbone feature is used."""

    def __init__(self,
                 backbone: torch.nn.Module,
                 return_nodes: Union[dict, None],
                 img_size: int,
                 use_fpn: bool,
                 fpn_size: Union[int, None],
                 proj_type: str,
                 upsample_type: str,
                 use_selection: bool,
                 num_classes: int,
                 num_selects: dict,
                 use_combiner: bool,
                 comb_proj_size: Union[int, None]
                 ):
        """
        * backbone:
            torch.nn.Module (recommended: pretrained on ImageNet or
            IG-3.5B-17k, provided by FAIR).
        * return_nodes:
            e.g.
            return_nodes = {
                # node_name: user-specified key for output dict
                'layer1.2.relu_2': 'layer1',
                'layer2.3.relu_2': 'layer2',
                'layer3.5.relu_2': 'layer3',
                'layer4.2.relu_2': 'layer4',
            } # see the example on https://pytorch.org/vision/main/feature_extraction.html
            When given, the backbone is wrapped with
            ``create_feature_extractor``; when None, the backbone itself must
            return a dict of features.
            !!! if using 'Swin-Transformer', please set return_nodes to None
            !!! and please set use_fpn to True
        * img_size:
            side length of the square dummy image used once to probe the
            backbone's feature sizes.
        * use_fpn:
            boolean, use feature pyramid network or not.
        * fpn_size:
            integer, feature pyramid network projection dimension (ignored
            when use_fpn is False).
        * proj_type / upsample_type:
            forwarded to FPN when use_fpn is True.
        * num_selects:
            num_selects = {
                # match user-specified keys in return_nodes
                "layer1": 2048,
                "layer2": 512,
                "layer3": 128,
                "layer4": 32,
            }
        * comb_proj_size:
            combiner projection width, used when use_fpn is False.

        Note: after the selector module (WeaklySelector), the feature map's
        size is [B, S', C], contained in the 'logits' or 'selections'
        dictionary (S' is the selection number; it may differ per layer).
        """
        super(PluginMoodel, self).__init__()

        self.return_nodes = return_nodes
        if return_nodes is not None:
            self.backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
        else:
            self.backbone = backbone

        # Probe the backbone once with a dummy image to read feature sizes.
        # eval() + no_grad() keep the probe from updating BatchNorm running
        # statistics with random data and from building an autograd graph;
        # the previous train/eval mode is restored afterwards.
        rand_in = torch.randn(1, 3, img_size, img_size)
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            outs = self.backbone(rand_in)
        self.backbone.train(was_training)

        if not use_fpn and (not use_selection and not use_combiner):
            for name in outs:
                fs_size = outs[name].size()
                if len(fs_size) == 3:
                    out_size = fs_size[-1]
                elif len(fs_size) == 4:
                    out_size = fs_size[1]
                else:
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
            # Sized for the last returned feature, which is the one forward
            # classifies.
            self.classifier = nn.Linear(out_size, num_classes)

        self.use_fpn = use_fpn
        # None without FPN, so the selector/combiner build their own
        # projections; the combiner reads this attribute in both modes.
        self.fpn_size = fpn_size if use_fpn else None
        if self.use_fpn:
            self.fpn_down = FPN(outs, fpn_size, proj_type, upsample_type)
            self.build_fpn_classifier_down(outs, fpn_size, num_classes)
            self.fpn_up = FPN_UP(outs, fpn_size)
            self.build_fpn_classifier_up(outs, fpn_size, num_classes)

        self.use_selection = use_selection
        if self.use_selection:
            self.selector = WeaklySelector(outs, num_classes, num_selects, self.fpn_size)

        self.use_combiner = use_combiner
        if self.use_combiner:
            assert self.use_selection, "Please use selection module before combiner"
            if self.use_fpn:
                gcn_inputs, gcn_proj_size = None, None
            else:
                gcn_inputs, gcn_proj_size = outs, comb_proj_size
            total_num_selects = sum(num_selects.values())
            self.combiner = GCNCombiner(total_num_selects, num_classes, gcn_inputs, gcn_proj_size, self.fpn_size)

    def _build_fpn_classifier(self, inputs: dict, fpn_size: int, num_classes: int, prefix: str):
        """Register a Conv1d-BN-ReLU-Conv1d token classifier ``<prefix><name>``
        for every input name."""
        for name in inputs:
            m = nn.Sequential(
                nn.Conv1d(fpn_size, fpn_size, 1),
                nn.BatchNorm1d(fpn_size),
                nn.ReLU(),
                nn.Conv1d(fpn_size, num_classes, 1)
            )
            self.add_module(prefix + name, m)

    def build_fpn_classifier_up(self, inputs: dict, fpn_size: int, num_classes: int):
        """
        Build the ``fpn_classifier_up_<name>`` heads. The results of our
        experiments show that a linear classifier may cause problems here.
        """
        self._build_fpn_classifier(inputs, fpn_size, num_classes, "fpn_classifier_up_")

    def build_fpn_classifier_down(self, inputs: dict, fpn_size: int, num_classes: int):
        """
        Build the ``fpn_classifier_down_<name>`` heads. The results of our
        experiments show that a linear classifier may cause problems here.
        """
        self._build_fpn_classifier(inputs, fpn_size, num_classes, "fpn_classifier_down_")

    def forward_backbone(self, x):
        return self.backbone(x)

    def _fpn_predict(self, x: dict, logits: dict, fpn1_keys: bool, prefix: str):
        """Apply ``<prefix><layer>`` to either the ``"FPN1_"`` keys
        (``fpn1_keys=True``) or the plain keys of ``x``, storing
        ``[B, S, num_classes]`` logits under the same key."""
        for name in x:
            if ("FPN1_" in name) != fpn1_keys:
                continue
            feat = x[name]
            # Conv1d classifiers expect [B, C, S].
            if feat.dim() == 4:
                B, C, H, W = feat.size()
                logit = feat.view(B, C, H*W)
            elif feat.dim() == 3:
                logit = feat.transpose(1, 2).contiguous()
            else:
                raise ValueError(
                    "Expected a 3-D or 4-D feature for {}, got {}-D".format(name, feat.dim()))
            model_name = name.replace("FPN1_", "")
            logits[name] = getattr(self, prefix + model_name)(logit)
            logits[name] = logits[name].transpose(1, 2).contiguous()

    def fpn_predict_down(self, x: dict, logits: dict):
        """
        Token logits for the ``"FPN1_<name>"`` (top-down) features.
        x: [B, C, H, W] --> [B, C, H*W], or [B, S, C] --> [B, C, S].
        """
        self._fpn_predict(x, logits, True, "fpn_classifier_down_")

    def fpn_predict_up(self, x: dict, logits: dict):
        """
        Token logits for the plain ``<name>`` (bottom-up) features.
        x: [B, C, H, W] --> [B, C, H*W], or [B, S, C] --> [B, C, S].
        """
        self._fpn_predict(x, logits, False, "fpn_classifier_up_")

    def forward(self, x: torch.Tensor):
        """Run the backbone and the enabled heads.

        Returns:
            dict of logits: FPN token logits ("FPN1_<name>" and "<name>"),
            selector entries ("select_<name>", "drop_<name>"), "comb_outs"
            when the combiner is used, or "ori_out" for the plain classifier.
        """
        logits = {}

        x = self.forward_backbone(x)

        if self.use_fpn:
            x = self.fpn_down(x)
            self.fpn_predict_down(x, logits)
            x = self.fpn_up(x)
            self.fpn_predict_up(x, logits)

        if self.use_selection:
            selects = self.selector(x, logits)

        if self.use_combiner:
            comb_outs = self.combiner(selects)
            logits['comb_outs'] = comb_outs
            return logits

        if self.use_selection or self.use_fpn:
            return logits

        # Plain backbone: classify the globally pooled last feature, which
        # self.classifier was sized for.
        hs = x[list(x)[-1]]
        if len(hs.size()) == 4:
            hs = F.adaptive_avg_pool2d(hs, (1, 1))
            hs = hs.flatten(1)
        else:
            hs = hs.mean(1)
        logits['ori_out'] = self.classifier(hs)

        return logits
|
|