# Plugin model: FPN (down + up) + weakly supervised token selector + GCN combiner.
import copy
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
class GCNCombiner(nn.Module):
    """Fuse the selected tokens of every layer with a one-layer adaptive
    graph convolution, then predict class logits from the pooled graph."""

    def __init__(self,
                 total_num_selects: int,
                 num_classes: int,
                 inputs: Union[dict, None] = None,
                 proj_size: Union[int, None] = None,
                 fpn_size: Union[int, None] = None):
        """
        If building the backbone without an FPN, set fpn_size to None and you
        MUST give 'inputs' and 'proj_size'; these constrain the dimension of
        the graph convolutional network input.
        """
        super(GCNCombiner, self).__init__()
        assert inputs is not None or fpn_size is not None, \
            "To build GCN combiner, you must give one features dimension."
        ### auto-proj: project each layer's channel dim to a common proj_size
        self.fpn_size = fpn_size
        if fpn_size is None:
            for name in inputs:
                feat_size = inputs[name].size()
                # fix: original tested len(name) — the length of the KEY STRING —
                # so any real key (e.g. "layer1", len 6) hit the error branch.
                if len(feat_size) == 4:
                    in_size = feat_size[1]   # [B, C, H, W] -> C
                elif len(feat_size) == 3:
                    in_size = feat_size[2]   # [B, S, C] -> C
                else:
                    # fix: was the nonexistent name 'ValusError'
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
                m = nn.Sequential(
                    nn.Linear(in_size, proj_size),
                    nn.ReLU(),
                    nn.Linear(proj_size, proj_size)
                )
                self.add_module("proj_"+name, m)
            self.proj_size = proj_size
        else:
            self.proj_size = fpn_size
        ### build one layer structure (with adaptive module)
        num_joints = total_num_selects // 64
        self.param_pool0 = nn.Linear(total_num_selects, num_joints)
        # near-uniform initial adjacency; stored as a learnable parameter
        A = torch.eye(num_joints) / 100 + 1 / 100
        self.adj1 = nn.Parameter(copy.deepcopy(A))
        self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1)
        self.batch_norm1 = nn.BatchNorm1d(self.proj_size)
        self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
        self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
        # alpha starts at 0 so training begins from the fixed adjacency only
        self.alpha1 = nn.Parameter(torch.zeros(1))
        ### merge information
        self.param_pool1 = nn.Linear(num_joints, 1)
        #### class predict
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(self.proj_size, num_classes)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """
        x: dict of selected features, each [B, S_i, C]; the S_i must sum to
        total_num_selects. Returns class logits [B, num_classes].
        """
        hs = []
        names = []
        for name in x:
            if "FPN1_" in name:
                continue
            if self.fpn_size is None:
                _tmp = getattr(self, "proj_"+name)(x[name])
            else:
                _tmp = x[name]
            hs.append(_tmp)
            names.append([name, _tmp.size()])
        hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous() # B, S', C --> B, C, S
        hs = self.param_pool0(hs)  # pool tokens down to the graph joints
        ### adaptive adjacency: pairwise query/key differences, squashed by tanh
        q1 = self.conv_q1(hs).mean(1)
        k1 = self.conv_k1(hs).mean(1)
        A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1))
        A1 = self.adj1 + A1 * self.alpha1
        ### graph convolution
        hs = self.conv1(hs)
        hs = torch.matmul(hs, A1)
        hs = self.batch_norm1(hs)
        ### predict
        hs = self.param_pool1(hs)
        hs = self.dropout(hs)
        hs = hs.flatten(1)
        hs = self.classifier(hs)
        return hs
class WeaklySelector(nn.Module):
    """Per-layer weak selection: rank each spatial token by the probability of
    the layer's predicted class and keep the top `num_select` tokens."""

    def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None):
        """
        inputs: dictionary of torch.Tensors coming from the backbone
            [Tensor1(hidden feature1), Tensor2(hidden feature2)...]
            If len(features.size()) == 3 the order of dimensions must be [B, S, C]
            (S is the spatial domain); if it is 4, the order must be [B, C, H, W].
        num_select: per-layer number of tokens to keep (keys match `inputs`).
        fpn_size: when not None, logits are supplied by the FPN classifiers and
            no per-layer classifier is built here.
        """
        super(WeaklySelector, self).__init__()
        self.num_select = num_select
        self.fpn_size = fpn_size
        ### build one linear classifier per layer only when no FPN provides logits
        if self.fpn_size is None:
            self.num_classes = num_classes
            for name in inputs:
                fs_size = inputs[name].size()
                if len(fs_size) == 3:
                    in_size = fs_size[2]   # [B, S, C] -> C
                elif len(fs_size) == 4:
                    in_size = fs_size[1]   # [B, C, H, W] -> C
                else:
                    # fix: original fell through here with 'in_size' unbound
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
                m = nn.Linear(in_size, num_classes)
                self.add_module("classifier_l_"+name, m)
        # confidence of the first dropped token, per (layer, batch index);
        # kept for diagnostics only — never read in forward()
        self.thresholds = {}
        for name in inputs:
            self.thresholds[name] = []

    def forward(self, x, logits=None):
        """
        x: dictionary of feature maps from the chosen layers;
           each is [B, HxW, C] ([B, S, C]) or [B, C, H, W].
           [B, C, H, W] is transposed to [B, HxW, C] automatically (in place).
        logits: per-layer logits [B, S, num_classes]; ignored (rebuilt locally)
           when fpn_size is None. "select_"/"drop_" entries are added to it.
        Returns a dict of selected features, each [B, num_select, C].
        """
        if self.fpn_size is None:
            logits = {}
        selections = {}
        for name in x:
            if "FPN1_" in name:
                continue
            if len(x[name].size()) == 4:
                B, C, H, W = x[name].size()
                x[name] = x[name].view(B, C, H*W).permute(0, 2, 1).contiguous()
            C = x[name].size(-1)
            if self.fpn_size is None:
                logits[name] = getattr(self, "classifier_l_"+name)(x[name])
            probs = torch.softmax(logits[name], dim=-1)
            # layer-level class prediction drives the per-token ranking
            sum_probs = torch.softmax(logits[name].mean(1), dim=-1)
            selections[name] = []
            preds_1 = []
            preds_0 = []
            num_select = self.num_select[name]
            for bi in range(logits[name].size(0)):
                _, max_ids = torch.max(sum_probs[bi], dim=-1)
                confs, ranks = torch.sort(probs[bi, :, max_ids], descending=True)
                sf = x[name][bi][ranks[:num_select]]
                nf = x[name][bi][ranks[num_select:]]
                selections[name].append(sf) # [num_selected, C]
                preds_1.append(logits[name][bi][ranks[:num_select]])
                preds_0.append(logits[name][bi][ranks[num_select:]])
                # fix: clamp the index so selecting every token no longer
                # raises IndexError (confs has exactly S entries)
                thr = confs[min(num_select, confs.size(0) - 1)]
                if bi >= len(self.thresholds[name]):
                    self.thresholds[name].append(thr) # for initialize
                else:
                    self.thresholds[name][bi] = thr
            selections[name] = torch.stack(selections[name])
            preds_1 = torch.stack(preds_1)
            preds_0 = torch.stack(preds_0)
            logits["select_"+name] = preds_1
            logits["drop_"+name] = preds_0
        return selections
class FPN(nn.Module):
    """Top-down feature pyramid: project every backbone feature to fpn_size,
    then fuse each deeper layer into the next shallower one."""

    def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str):
        """
        inputs : dictionary of torch.Tensor backbone outputs, iterated in
                 shallow -> deep order.
        fpn_size: integer, fpn projection dimension.
        proj_type:
            in ["Conv", "Linear"]
        upsample_type:
            in ["Bilinear", "Conv"]
            for convolution neural networks (e.g. ResNet, EfficientNet),
            'Bilinear' is recommended; for Swin-T, "Conv".
        """
        super(FPN, self).__init__()
        assert proj_type in ["Conv", "Linear"], \
            "FPN projection type {} were not support yet, please choose type 'Conv' or 'Linear'".format(proj_type)
        # fix: message formatted proj_type instead of the offending upsample_type
        assert upsample_type in ["Bilinear", "Conv"], \
            "FPN upsample type {} were not support yet, please choose type 'Bilinear' or 'Conv'".format(upsample_type)
        self.fpn_size = fpn_size
        self.upsample_type = upsample_type
        inp_names = [name for name in inputs]
        for i, node_name in enumerate(inputs):
            ### projection module
            if proj_type == "Conv":
                m = nn.Sequential(
                    nn.Conv2d(inputs[node_name].size(1), inputs[node_name].size(1), 1),
                    nn.ReLU(),
                    nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
                )
            elif proj_type == "Linear":
                in_feat = inputs[node_name]
                if isinstance(in_feat, torch.Tensor):
                    dim = in_feat.size(-1)
                else:
                    raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
                m = nn.Sequential(
                    nn.Linear(dim, dim),
                    nn.ReLU(),
                    nn.Linear(dim, fpn_size),
                )
            self.add_module("Proj_"+node_name, m)
            ### upsample module: maps this node's token count to the previous
            ### (shallower) node's token count
            if upsample_type == "Conv" and i != 0:
                assert len(inputs[node_name].size()) == 3 # B, S, C
                in_dim = inputs[node_name].size(1)
                out_dim = inputs[inp_names[i-1]].size(1)
                m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain
                self.add_module("Up_"+node_name, m)
        if upsample_type == "Bilinear":
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str):
        """Upsample x1 to x0's spatial size and return x1 + x0."""
        if self.upsample_type == "Bilinear":
            if x1.size(-1) != x0.size(-1):
                x1 = self.upsample(x1)
        else:
            x1 = getattr(self, "Up_"+x1_name)(x1)
        return x1 + x0

    def forward(self, x):
        """
        x : dictionary
            {
                "node_name1": feature1,
                "node_name2": feature2, ...
            }
        Projects every (non-"FPN1_") entry in place, then fuses deep -> shallow,
        storing a pre-fusion snapshot of each level under "FPN1_<name>".
        """
        ### project to same dimension
        hs = []
        for i, name in enumerate(x):
            if "FPN1_" in name:
                continue
            x[name] = getattr(self, "Proj_"+name)(x[name])
            hs.append(name)
        # snapshot of the deepest projected map
        # fix: was hard-coded to the key "layer4", breaking any other key set
        x["FPN1_" + hs[-1]] = x[hs[-1]]
        for i in range(len(hs)-1, 0, -1):
            x1_name = hs[i]
            x0_name = hs[i-1]
            x[x0_name] = self.upsample_add(x[x0_name],
                                           x[x1_name],
                                           x1_name)
            x["FPN1_" + x0_name] = x[x0_name]
        return x
class FPN_UP(nn.Module):
    """Bottom-up refinement pass: re-project every pyramid level, then fold
    each shallower level down into the next deeper one."""

    def __init__(self,
                 inputs: dict,
                 fpn_size: int):
        super(FPN_UP, self).__init__()
        node_names = list(inputs)
        last_idx = len(node_names) - 1
        for idx, node_name in enumerate(node_names):
            ### per-node projection: two fpn_size -> fpn_size linear layers
            proj = nn.Sequential(
                nn.Linear(fpn_size, fpn_size),
                nn.ReLU(),
                nn.Linear(fpn_size, fpn_size),
            )
            self.add_module("Proj_" + node_name, proj)
            ### spatial down-sampling toward the next (deeper) node
            if idx != last_idx:
                assert len(inputs[node_name].size()) == 3  # B, S, C
                src_tokens = inputs[node_name].size(1)
                dst_tokens = inputs[node_names[idx + 1]].size(1)
                # 1x1 Conv1d remaps the token (spatial) dimension
                self.add_module("Down_" + node_name, nn.Conv1d(src_tokens, dst_tokens, 1))
                # e.g. Down_layer1 2304->576, Down_layer2 576->144, Down_layer3 144->144

    def downsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x0_name: str):
        """Downsample x0 spatially via its "Down_" conv and add it to x1."""
        reduced = getattr(self, "Down_" + x0_name)(x0)
        return x1 + reduced

    def forward(self, x):
        """
        x : dictionary
            {
                "node_name1": feature1,
                "node_name2": feature2, ...
            }
        Entries whose names contain "FPN1_" are left untouched; all others are
        projected in place and then fused shallow -> deep.
        """
        kept = [name for name in x if "FPN1_" not in name]
        ### project to same dimension
        for name in kept:
            x[name] = getattr(self, "Proj_" + name)(x[name])
        ### fuse each level into its deeper neighbour, in order
        for shallow, deep in zip(kept, kept[1:]):
            x[deep] = self.downsample_add(x[shallow], x[deep], shallow)
        return x
class PluginMoodel(nn.Module):
    """Backbone wrapper that optionally stacks an FPN (down + up paths), a
    weakly supervised token selector, and a GCN combiner on top of the
    backbone's intermediate feature maps.

    NOTE: the class name ('PluginMoodel') is kept as-is for caller compatibility.
    """

    def __init__(self,
                 backbone: torch.nn.Module,
                 return_nodes: Union[dict, None],
                 img_size: int,
                 use_fpn: bool,
                 fpn_size: Union[int, None],
                 proj_type: str,
                 upsample_type: str,
                 use_selection: bool,
                 num_classes: int,
                 num_selects: dict,
                 use_combiner: bool,
                 comb_proj_size: Union[int, None]
                 ):
        """
        * backbone:
            torch.nn.Module class (recommend pretrained on ImageNet or IG-3.5B-17k).
        * return_nodes:
            e.g.
            return_nodes = {
                # node_name: user-specified key for output dict
                'layer1.2.relu_2': 'layer1',
                'layer2.3.relu_2': 'layer2',
                'layer3.5.relu_2': 'layer3',
                'layer4.2.relu_2': 'layer4',
            } # see https://pytorch.org/vision/main/feature_extraction.html
            !!! if using 'Swin-Transformer', please set return_nodes to None
            !!! and please set use_fpn to True
        * use_fpn:
            boolean, use feature pyramid network or not.
        * fpn_size:
            integer, feature pyramid network projection dimension.
        * num_selects:
            per-layer selection counts; keys match the values of return_nodes:
            {"layer1": 2048, "layer2": 512, "layer3": 128, "layer4": 32}
        Note: after the selector module (WeaklySelector), feature maps are
        [B, S', C] where S' is the (possibly per-layer) selection number.
        """
        super(PluginMoodel, self).__init__()
        ### = = = = = Backbone = = = = =
        self.return_nodes = return_nodes
        if return_nodes is not None:
            self.backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
        else:
            self.backbone = backbone
        ### probe hidden feature sizes with one dummy forward pass
        rand_in = torch.randn(1, 3, img_size, img_size)
        outs = self.backbone(rand_in)
        ### just original backbone: one linear classifier on the last layer
        if not use_fpn and (not use_selection and not use_combiner):
            for name in outs:
                fs_size = outs[name].size()
                if len(fs_size) == 3:
                    out_size = fs_size[-1]  # [B, S, C] -> C (fix: torch.Size is indexed, it has no .size())
                elif len(fs_size) == 4:
                    out_size = fs_size[1]   # [B, C, H, W] -> C
                else:
                    # fix: was the nonexistent name 'ValusError'
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
            # only the LAST iterated layer defines the classifier input size
            self.classifier = nn.Linear(out_size, num_classes)
        ### = = = = = FPN = = = = =
        self.use_fpn = use_fpn
        if self.use_fpn:
            self.fpn_down = FPN(outs, fpn_size, proj_type, upsample_type)
            self.build_fpn_classifier_down(outs, fpn_size, num_classes)
            self.fpn_up = FPN_UP(outs, fpn_size)
            self.build_fpn_classifier_up(outs, fpn_size, num_classes)
        # fix: self.fpn_size was only assigned when use_fpn was True, which
        # crashed the combiner construction below when use_fpn is False
        self.fpn_size = fpn_size if use_fpn else None
        ### = = = = = Selector = = = = =
        self.use_selection = use_selection
        if self.use_selection:
            w_fpn_size = self.fpn_size if self.use_fpn else None # if not using fpn, build classifier in weakly selector
            self.selector = WeaklySelector(outs, num_classes, num_selects, w_fpn_size)
        ### = = = = = Combiner = = = = =
        self.use_combiner = use_combiner
        if self.use_combiner:
            assert self.use_selection, "Please use selection module before combiner"
            if self.use_fpn:
                gcn_inputs, gcn_proj_size = None, None
            else:
                gcn_inputs, gcn_proj_size = outs, comb_proj_size # redundant, fix in future
            total_num_selects = sum([num_selects[name] for name in num_selects]) # sum
            self.combiner = GCNCombiner(total_num_selects, num_classes, gcn_inputs, gcn_proj_size, self.fpn_size)

    def build_fpn_classifier_up(self, inputs: dict, fpn_size: int, num_classes: int):
        """Build one token-wise 1x1-conv classifier per layer for the FPN-up
        path. The results of our experiments show that a linear classifier in
        this case may cause some problem.
        """
        for name in inputs:
            m = nn.Sequential(
                nn.Conv1d(fpn_size, fpn_size, 1),
                nn.BatchNorm1d(fpn_size),
                nn.ReLU(),
                nn.Conv1d(fpn_size, num_classes, 1)
            )
            self.add_module("fpn_classifier_up_"+name, m)

    def build_fpn_classifier_down(self, inputs: dict, fpn_size: int, num_classes: int):
        """Build one token-wise 1x1-conv classifier per layer for the FPN-down
        path. The results of our experiments show that a linear classifier in
        this case may cause some problem.
        """
        for name in inputs:
            m = nn.Sequential(
                nn.Conv1d(fpn_size, fpn_size, 1),
                nn.BatchNorm1d(fpn_size),
                nn.ReLU(),
                nn.Conv1d(fpn_size, num_classes, 1)
            )
            self.add_module("fpn_classifier_down_" + name, m)

    def forward_backbone(self, x):
        """Run the (possibly feature-extracted) backbone; returns a dict."""
        return self.backbone(x)

    def fpn_predict_down(self, x: dict, logits: dict):
        """
        Token-wise predictions for the "FPN1_" (down-path snapshot) entries.
        x: [B, C, H, W] or [B, S, C]
        [B, C, H, W] --> [B, H*W, C]
        """
        for name in x:
            if "FPN1_" not in name:
                continue
            ### predict on each features point
            if len(x[name].size()) == 4:
                B, C, H, W = x[name].size()
                logit = x[name].view(B, C, H*W)
            elif len(x[name].size()) == 3:
                logit = x[name].transpose(1, 2).contiguous()
            model_name = name.replace("FPN1_", "")
            logits[name] = getattr(self, "fpn_classifier_down_" + model_name)(logit)
            logits[name] = logits[name].transpose(1, 2).contiguous() # transpose

    def fpn_predict_up(self, x: dict, logits: dict):
        """
        Token-wise predictions for the refined (non-"FPN1_") entries.
        x: [B, C, H, W] or [B, S, C]
        [B, C, H, W] --> [B, H*W, C]
        """
        for name in x:
            if "FPN1_" in name:
                continue
            ### predict on each features point
            if len(x[name].size()) == 4:
                B, C, H, W = x[name].size()
                logit = x[name].view(B, C, H*W)
            elif len(x[name].size()) == 3:
                logit = x[name].transpose(1, 2).contiguous()
            model_name = name.replace("FPN1_", "")
            logits[name] = getattr(self, "fpn_classifier_up_" + model_name)(logit)
            logits[name] = logits[name].transpose(1, 2).contiguous() # transpose

    def forward(self, x: torch.Tensor):
        """Return a dict of logits; keys depend on the enabled modules
        (per-layer FPN logits, "select_"/"drop_" entries, 'comb_outs', or
        'ori_out' for the plain-backbone configuration)."""
        logits = {}
        x = self.forward_backbone(x)

        if self.use_fpn:
            x = self.fpn_down(x)
            self.fpn_predict_down(x, logits)
            x = self.fpn_up(x)
            self.fpn_predict_up(x, logits)

        if self.use_selection:
            selects = self.selector(x, logits)

        if self.use_combiner:
            comb_outs = self.combiner(selects)
            logits['comb_outs'] = comb_outs
            return logits

        # fix: was 'self.fpn', an attribute that never exists
        if self.use_selection or self.use_fpn:
            return logits

        ### original backbone (only predict final selected layer)
        for name in x:
            hs = x[name]
        if len(hs.size()) == 4:
            hs = F.adaptive_avg_pool2d(hs, (1, 1))
            hs = hs.flatten(1)
        else:
            hs = hs.mean(1)
        out = self.classifier(hs)
        logits['ori_out'] = out  # fix: stored the logits dict into itself
        return logits            # fix: returned None, discarding everything