|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class MLP(nn.Module):
    """MLP head applied to ``x[:, 0, :]`` of the input features.

    When ``aux_data`` is non-empty, a second MLP (``mlp_aux``) is built whose
    output is sliced into one chunk per auxiliary column, and ``forward``
    returns a dict instead of a single tensor.
    """

    # Output width of every supported auxiliary column.
    _AUX_OUTPUT_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head
        Args:
            initial_dim (int): width of the input features
            hidden_dim (list): list of hidden dimensions for the MLP
            final_dim (int): width of the main output
            norm (nn.Module): normalization layer class, called as
                ``norm(4, width)`` (see ``init_layers``)
            activation (nn.Module): activation layer class, called without
                arguments
            aux_data (list): names of auxiliary columns to predict; each name
                must be a key of ``_AUX_OUTPUT_DIMS``
        Raises:
            KeyError: if ``aux_data`` contains an unknown column name
            IndexError: if ``aux_data`` is non-empty and ``hidden_dim`` is empty
        """
        super().__init__()
        # Work on private copies: the arguments may be the shared default
        # lists or a caller-owned list and must never be modified in place.
        hidden_dim = list(hidden_dim)
        self.aux_data = list(aux_data)
        self.aux = len(self.aux_data) > 0
        if self.aux:
            # The auxiliary head uses the same hidden widths except for a
            # 128-wide last hidden layer. It is a separate list so the main
            # head below keeps the requested widths.
            hidden_dim_aux = list(hidden_dim)
            hidden_dim_aux[-1] = 128
            # Map each auxiliary column to its column indices in the
            # concatenated output of ``mlp_aux``.
            self.idx = {}
            final_dim_aux = 0
            for col in self.aux_data:
                width = self._AUX_OUTPUT_DIMS[col]
                self.idx[col] = list(range(final_dim_aux, final_dim_aux + width))
                final_dim_aux += width
            dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
            self.mlp_aux = nn.Sequential(*self.init_layers(dim, norm, activation))
        dim = [initial_dim] + hidden_dim + [final_dim]
        self.mlp = nn.Sequential(*self.init_layers(dim, norm, activation))

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per consecutive
        pair of widths in ``dim``. Every ``Linear`` except the last one is
        followed by ``norm(4, width)`` and ``activation()``.

        Args:
            dim (list): layer widths, input first and output last
            norm: normalization layer class
            activation: activation layer class
        Returns:
            list of modules, ready to be unpacked into ``nn.Sequential``
        """
        layers = [nn.LayerNorm(dim[0])]
        last = len(dim) - 2
        for i, (d_in, d_out) in enumerate(zip(dim[:-1], dim[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                # NOTE(review): the (4, width) call matches the signature of
                # nn.GroupNorm(num_groups, num_channels). With the default
                # nn.InstanceNorm1d the second positional argument is ``eps``
                # - confirm which normalization is intended.
                layers.append(norm(4, d_out))
                layers.append(activation())
        return layers

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features; only ``x[:, 0, :]`` is used
                (presumably the first token of a (batch, tokens, dim)
                tensor - TODO confirm)
        Returns:
            The output of ``mlp`` when no auxiliary data is configured,
            otherwise a dict with that output under ``"gps"`` and one slice of
            the ``mlp_aux`` output per auxiliary column.
        """
        features = x[:, 0, :]
        gps = self.mlp(features)
        if not self.aux:
            return gps
        out = {"gps": gps}
        aux = self.mlp_aux(features)
        for col, columns in self.idx.items():
            out[col] = aux[:, columns]
        return out
|
|
|
|
class MLPResNet(nn.Module):
    """MLP head applied directly to the input features.

    Without auxiliary data, ``forward`` feeds ``x`` to the MLP unchanged.
    When ``aux_data`` is non-empty, a second MLP (``mlp_aux``) is built whose
    output is sliced into one chunk per auxiliary column, and ``forward``
    returns a dict instead of a single tensor.
    """

    # Output width of every supported auxiliary column.
    _AUX_OUTPUT_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head
        Args:
            initial_dim (int): width of the input features
            hidden_dim (list): list of hidden dimensions for the MLP
            final_dim (int): width of the main output
            norm (nn.Module): normalization layer class, called as
                ``norm(4, width)`` (see ``init_layers``)
            activation (nn.Module): activation layer class, called without
                arguments
            aux_data (list): names of auxiliary columns to predict; each name
                must be a key of ``_AUX_OUTPUT_DIMS``
        Raises:
            KeyError: if ``aux_data`` contains an unknown column name
            IndexError: if ``aux_data`` is non-empty and ``hidden_dim`` is empty
        """
        super().__init__()
        # Work on private copies: the arguments may be the shared default
        # lists or a caller-owned list and must never be modified in place.
        hidden_dim = list(hidden_dim)
        self.aux_data = list(aux_data)
        self.aux = len(self.aux_data) > 0
        if self.aux:
            # The auxiliary head uses the same hidden widths except for a
            # 128-wide last hidden layer. It is a separate list so the main
            # head below keeps the requested widths.
            hidden_dim_aux = list(hidden_dim)
            hidden_dim_aux[-1] = 128
            # Map each auxiliary column to its column indices in the
            # concatenated output of ``mlp_aux``.
            self.idx = {}
            final_dim_aux = 0
            for col in self.aux_data:
                width = self._AUX_OUTPUT_DIMS[col]
                self.idx[col] = list(range(final_dim_aux, final_dim_aux + width))
                final_dim_aux += width
            dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
            self.mlp_aux = nn.Sequential(*self.init_layers(dim, norm, activation))
        dim = [initial_dim] + hidden_dim + [final_dim]
        self.mlp = nn.Sequential(*self.init_layers(dim, norm, activation))

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per consecutive
        pair of widths in ``dim``. Every ``Linear`` except the last one is
        followed by ``norm(4, width)`` and ``activation()``.

        Args:
            dim (list): layer widths, input first and output last
            norm: normalization layer class
            activation: activation layer class
        Returns:
            list of modules, ready to be unpacked into ``nn.Sequential``
        """
        layers = [nn.LayerNorm(dim[0])]
        last = len(dim) - 2
        for i, (d_in, d_out) in enumerate(zip(dim[:-1], dim[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                # NOTE(review): the (4, width) call matches the signature of
                # nn.GroupNorm(num_groups, num_channels). With the default
                # nn.InstanceNorm1d the second positional argument is ``eps``
                # - confirm which normalization is intended.
                layers.append(norm(4, d_out))
                layers.append(activation())
        return layers

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features
        Returns:
            ``mlp(x)`` when no auxiliary data is configured, otherwise a dict
            with the main output under ``"gps"`` and one slice of the
            ``mlp_aux`` output per auxiliary column.
        """
        if not self.aux:
            return self.mlp(x)
        # The non-aux path consumes ``x`` as-is, so accept the same 2-D input
        # here; any other rank keeps the ``x[:, 0, :]`` selection.
        features = x if x.dim() == 2 else x[:, 0, :]
        out = {"gps": self.mlp(features)}
        aux = self.mlp_aux(features)
        for col, columns in self.idx.items():
            out[col] = aux[:, columns]
        return out
|
|
|
|
|
|
class MLPCentroid(nn.Module):
    """Two-branch MLP head (``classif`` and ``reg``) applied to ``x[:, 0, :]``.

    The two branch outputs are concatenated along dim 1. When ``aux_data`` is
    non-empty, one extra MLP is built per recognized auxiliary column and
    ``forward`` returns a dict instead of a single tensor.
    """

    # Output width of every supported auxiliary column.
    _AUX_OUTPUT_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }
    # Order in which auxiliary heads are created (and therefore returned).
    _AUX_ORDER = (
        "land_cover",
        "road_index",
        "drive_side",
        "climate",
        "soil",
        "dist_sea",
    )

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head
        Args:
            initial_dim (int): width of the input features
            hidden_dim (list): list of hidden dimensions for the MLP
            final_dim (int): total output width; ``final_dim // 3`` goes to
                the ``classif`` branch and ``2 * final_dim // 3`` to ``reg``
            norm (nn.Module): normalization layer class, called as
                ``norm(4, width)`` (see ``init_layers``)
            activation (nn.Module): activation layer class, called without
                arguments
            aux_data (list): names of auxiliary columns to predict; names
                that are not in ``_AUX_ORDER`` are ignored
        """
        super().__init__()
        # Private copies so the shared default lists are never exposed.
        hidden_dim = list(hidden_dim)
        self.aux_data = list(aux_data)
        self.aux = len(self.aux_data) > 0
        trunk = [initial_dim] + hidden_dim
        self.classif = nn.Sequential(
            *self.init_layers(trunk + [final_dim // 3], norm, activation)
        )
        self.reg = nn.Sequential(
            *self.init_layers(trunk + [2 * final_dim // 3], norm, activation)
        )
        if self.aux:
            self.dim = trunk
            # Only the auxiliary heads live here. The "gps" output is built in
            # ``forward`` from ``classif`` and ``reg``, because this class has
            # no single ``mlp`` module to register under that key.
            self.predictors = {}
            self.init_aux(trunk, norm, activation)

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per consecutive
        pair of widths in ``dim``. Every ``Linear`` except the last one is
        followed by ``norm(4, width)`` and ``activation()``.

        Args:
            dim (list): layer widths, input first and output last
            norm: normalization layer class
            activation: activation layer class
        Returns:
            list of modules, ready to be unpacked into ``nn.Sequential``
        """
        layers = [nn.LayerNorm(dim[0])]
        last = len(dim) - 2
        for i, (d_in, d_out) in enumerate(zip(dim[:-1], dim[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                # NOTE(review): the (4, width) call matches the signature of
                # nn.GroupNorm(num_groups, num_channels). With the default
                # nn.InstanceNorm1d the second positional argument is ``eps``
                # - confirm which normalization is intended.
                layers.append(norm(4, d_out))
                layers.append(activation())
        return layers

    def init_aux(self, dim, norm, activation):
        """Creates one MLP head per requested auxiliary column.

        Each head is stored as an attribute named after the column (which
        registers it as a submodule) and in ``self.predictors``. Widths come
        from ``self.dim``; the ``dim`` argument is accepted but not used.
        Requires ``self.dim`` and ``self.predictors`` to be set beforehand.

        Args:
            dim: unused, kept for signature compatibility
            norm: normalization layer class
            activation: activation layer class
        """
        for col in self._AUX_ORDER:
            if col not in self.aux_data:
                continue
            widths = self.dim + [self._AUX_OUTPUT_DIMS[col]]
            head = nn.Sequential(*self.init_layers(widths, norm, activation))
            setattr(self, col, head)
            self.predictors[col] = head

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features; only ``x[:, 0, :]`` is used
                (presumably the first token of a (batch, tokens, dim)
                tensor - TODO confirm)
        Returns:
            The ``classif`` and ``reg`` outputs concatenated along dim 1 when
            no auxiliary data is configured, otherwise a dict with that tensor
            under ``"gps"`` and one entry per auxiliary head.
        """
        features = x[:, 0, :]
        gps = torch.cat([self.classif(features), self.reg(features)], dim=1)
        if not self.aux:
            return gps
        out = {"gps": gps}
        for col, head in self.predictors.items():
            out[col] = head(features)
        return out
|
|
|
|
|
|
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back exactly what it receives."""

    def __init__(self):
        """
        Initializes an Identity module
        """
        super().__init__()

    def forward(self, x):
        """
        Return same as input
        """
        return x