|
import torch
from typing import Callable, List, Optional

|
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Adapted for backward compatibility from the torchvision library:
    https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html
|
    LICENSE:

    From PyTorch:

    Copyright (c) 2016- Facebook, Inc (Adam Paszke)
    Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
    Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
    Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
    Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
    Copyright (c) 2011-2013 NYU (Clement Farabet)
    Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
    Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
    Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

    From Caffe2:

    Copyright (c) 2016-present, Facebook Inc. All rights reserved.

    All contributions by Facebook:
    Copyright (c) 2016 Facebook Inc.

    All contributions by Google:
    Copyright (c) 2015 Google Inc.
    All rights reserved.

    All contributions by Yangqing Jia:
    Copyright (c) 2015 Yangqing Jia
    All rights reserved.

    All contributions by Kakao Brain:
    Copyright 2019-2020 Kakao Brain

    All contributions by Cruise LLC:
    Copyright (c) 2022 Cruise LLC.
    All rights reserved.

    All contributions from Caffe:
    Copyright(c) 2013, 2014, 2015, the respective contributors
    All rights reserved.

    All other contributions:
    Copyright(c) 2015, 2016 the respective contributors
    All rights reserved.

    Caffe2 uses a copyright model similar to Caffe: each contributor holds
    copyright over their contributions to Caffe2. The project versioning records
    all such contribution and copyright details. If a contributor wants to further
    mark their specific copyright on a particular contribution, they should
    indicate their copyright solely in the commit message of the change when it is
    committed.

    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright
       notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
       notice, this list of conditions and the following disclaimer in the
       documentation and/or other materials provided with the distribution.

    3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
       and IDIAP Research Institute nor the names of its contributors may be
       used to endorse or promote products derived from this software without
       specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
    ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
    LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.
|
    Args:
        in_channels (int): Number of channels of the input.
        hidden_channels (List[int]): List of the hidden channel dimensions.
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will
            be stacked on top of each hidden linear layer. If ``None`` this layer won't be used.
            Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the
            operation in-place. Default is ``None``, which uses the respective default values of
            the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layers. Default: ``True``
        dropout (float): The probability for the dropout layers. Default: ``0.0``
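
    Example (a minimal usage sketch; the layer sizes below are illustrative):

        >>> mlp = MLP(in_channels=8, hidden_channels=[16, 4])
        >>> mlp(torch.rand(2, 8)).shape
        torch.Size([2, 4])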
|
""" |
|
|
|
    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # Forward ``inplace`` only when explicitly requested, so the activation and
        # dropout layers otherwise keep their own defaults.
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        # Hidden layers: Linear -> activation -> Dropout for every entry but the last.
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        # Output layer: Linear -> Dropout, with no activation.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
|
|
|
|
|
def batch_idx(x: torch.Tensor, idx: torch.Tensor):
    """Batched indexing: for every batch position, pick one entry along the next dimension of ``x``.

    ``idx.shape`` must equal a leading prefix of ``x.shape`` (the batch dimensions), and ``x``
    must have at least one more dimension than ``idx``. The returned tensor has shape
    ``idx.shape + x.shape[idx.dim() + 1:]``.
    """
    assert (
        idx.shape == x.shape[: len(idx.shape)]
    ), f"Index shape must match a prefix of x.shape. Got idx: {idx.shape} and x: {x.shape}"

    remaining_shape = x.shape[len(idx.shape) + 1 :]
    # Flatten everything after the indexed dimension so a single advanced-indexing lookup
    # suffices (skipped when there is nothing left to flatten).
    if x.dim() > idx.dim() + 1:
        x = x.flatten(start_dim=idx.dim() + 1)
    # One broadcastable arange per batch dimension, followed by ``idx`` for the indexed dimension.
    indices = []
    for i in range(len(idx.shape)):
        raw = torch.arange(idx.shape[i], device=idx.device)
        desired_shape = [1] * len(idx.shape)
        desired_shape[i] = idx.shape[i]
        indices.append(raw.reshape(desired_shape))
    indices.append(idx)
    # Advanced indexing gathers one entry per batch position; restore the trailing shape.
    result = x[tuple(indices)]
    return result.reshape(idx.shape + remaining_shape)
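

# A minimal runnable sketch (not part of the original module): the tensor sizes and
# index values below are illustrative only.
if __name__ == "__main__":
    # Pick one row per batch element from a (batch, items, features) tensor.
    x = torch.rand(4, 10, 3)
    idx = torch.tensor([0, 9, 2, 5])
    print(batch_idx(x, idx).shape)  # torch.Size([4, 3])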
|
|