# Multi-level (layer-weighted) downstream classification head for a
# transformer upstream model that returns all hidden states.
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
class MultiLevelDownstreamModel(nn.Module):
    """Classification head over *all* hidden states of an upstream model.

    Pipeline: softmax-weighted sum across layers -> stack of 1x1 Conv1d
    projections -> (optionally length-masked) temporal mean pooling ->
    two-layer MLP classifier.

    Args:
        model_config: upstream model config; must expose ``hidden_size``,
            ``classifier_proj_size``, ``num_hidden_layers``, ``num_labels``,
            and have ``output_hidden_states`` enabled.
        use_conv_output: when True, include the first hidden state
            (the convolutional feature-extractor / input-embedding output)
            in the weighted sum; otherwise only the transformer layers.
    """

    def __init__(
        self,
        model_config,
        use_conv_output: Optional[bool] = True,
    ):
        super().__init__()
        assert model_config.output_hidden_states == True, "The upstream model must return all hidden states"
        self.model_config = model_config
        self.use_conv_output = use_conv_output

        # 1x1 convolutions act as per-frame linear projections
        # (input is (batch, hidden, time) after the transpose in forward()).
        proj = self.model_config.classifier_proj_size
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, proj, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(proj, proj, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(proj, proj, 1, padding=0),
        )

        # One mixing weight per hidden state; softmax-normalized in forward().
        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))

        self.out_layer = nn.Sequential(
            nn.Linear(proj, proj),
            nn.ReLU(),
            nn.Linear(proj, self.model_config.num_labels),
        )

    def forward(self, encoder_hidden_states, length=None):
        """Pool the upstream hidden states into per-utterance logits.

        Args:
            encoder_hidden_states: sequence of ``num_hidden_layers + 1``
                tensors, each of shape (batch, time, hidden_size)
                — assumed; TODO confirm against the upstream model.
            length: optional (batch,) tensor of valid frame counts; when
                given, pooling averages only over the valid frames.

        Returns:
            (batch, num_labels) logits.
        """
        if self.use_conv_output:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)
            num_layers = self.model_config.num_hidden_layers + 1
        else:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)[1:]  # exclude the convolution output
            # BUGFIX: was self.model_config.config.num_hidden_layers, which
            # raised AttributeError — model_config has no .config attribute.
            num_layers = self.model_config.num_hidden_layers

        # Weighted sum over the layer axis, computed on flattened features.
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(num_layers, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # Conv1d expects (batch, channels, time).
        features = self.model_seq(features.transpose(1, 2)).transpose(1, 2)

        if length is not None:
            # Masked mean over valid frames only. Device-agnostic
            # (BUGFIX: was hard-coded .cuda(), which crashed on CPU-only hosts).
            length = length.to(features.device)
            masks = torch.arange(features.size(1), device=features.device).expand(length.size(0), -1) < length.unsqueeze(1)
            masks = masks.float()
            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
        else:
            features = torch.mean(features, dim=1)

        predicted = self.out_layer(features)
        return predicted