|
"""Library implementing linear transformation. |
|
|
|
Authors |
|
* Mirco Ravanelli 2020 |
|
* Davide Borra 2021 |
|
""" |
|
|
|
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

|
class Linear(torch.nn.Module):
    """Computes a linear transformation y = wx + b.

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e., the dimensionality of the
        output).
    input_shape : tuple
        It is the shape of the input tensor; the input size is inferred from
        its last dimension. Alternative to ``input_size``.
    input_size : int
        Size of the last dimension of the input tensor. Alternative to
        ``input_shape``.
    bias : bool
        If True, the additive bias b is adopted.
    max_norm : float
        Maximum L2 norm enforced on each output neuron's weight vector at
        every forward pass (see the example below).
    combine_dims : bool
        If True and the input is 4D, combine the 3rd and 4th dimensions of
        the input (see the example below).

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
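
    When ``max_norm`` is set, the weights are renormalized before each
    forward pass; the threshold of 0.1 below is an arbitrary illustrative
    value:

    >>> lin_t = Linear(input_size=40, n_neurons=100, max_norm=0.1)
    >>> output = lin_t(torch.rand(10, 50, 40))
    >>> bool((lin_t.w.weight.norm(dim=1) <= 0.1 + 1e-6).all())
    True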
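
    With ``combine_dims=True``, a 4D input has its 3rd and 4th dimensions
    flattened into one before the transformation; the shapes below are
    arbitrary illustrative values:

    >>> inputs = torch.rand(10, 50, 2, 20)
    >>> lin_t = Linear(input_shape=(10, 50, 2, 20), n_neurons=30, combine_dims=True)
    >>> lin_t(inputs).shape
    torch.Size([10, 50, 30])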
|
""" |
|
|
|
    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        max_norm=None,
        combine_dims=False,
    ):
        super().__init__()
        self.max_norm = max_norm
        self.combine_dims = combine_dims

        if input_shape is None and input_size is None:
            raise ValueError("Expected one of input_shape or input_size")

        # Infer the input size from the last dimension of input_shape
        # (or from the flattened 3rd and 4th dims when combine_dims is set)
        if input_size is None:
            input_size = input_shape[-1]
            if len(input_shape) == 4 and self.combine_dims:
                input_size = input_shape[2] * input_shape[3]

        # Weights and bias are initialized with the PyTorch defaults
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Returns the linear transformation of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.

        Returns
        -------
        wx : torch.Tensor
            The linearly transformed outputs.
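
        Example
        -------
        Only the last dimension is transformed; leading dimensions (e.g.,
        batch and time) pass through unchanged. The shapes below are
        arbitrary illustrative values:

        >>> lin_t = Linear(input_size=40, n_neurons=100)
        >>> lin_t(torch.rand(8, 60, 40)).shape
        torch.Size([8, 60, 100])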
|
""" |
|
        # Combine the 3rd and 4th dimensions of 4D inputs when requested
        if x.ndim == 4 and self.combine_dims:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        # Constrain the L2 norm of each output neuron's weight vector
        # to at most max_norm before applying the transformation
        if self.max_norm is not None:
            self.w.weight.data = torch.renorm(
                self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
            )

        wx = self.w(x)

        return wx