# shunt-adapter-testing / custom / t5_encoder_with_projection.py
import torch
from torch import nn
from transformers import T5EncoderModel, T5Config, T5PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple, Union

class T5ProjectionConfig(T5Config):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Encoder hidden size in, target embedding size out. The key must be
        # "project_out_dim" (not "out_dim") so the value survives a
        # to_dict()/from_dict() round trip.
        self.project_in_dim = kwargs.get("project_in_dim", 768)
        self.project_out_dim = kwargs.get("project_out_dim", 4096)

class T5EncoderWithProjection(T5PreTrainedModel):
    config_class = T5ProjectionConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = T5EncoderModel(config)
        # Two-layer MLP that lifts encoder states from project_in_dim to
        # project_out_dim; project_in_dim must match the encoder's d_model
        # for the first Linear to line up.
        self.final_projection = nn.Sequential(
            nn.Linear(config.project_in_dim, config.project_out_dim, bias=False),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(config.project_out_dim, config.project_out_dim, bias=False),
        )
        # Standard Hugging Face post-init (weight initialization, etc.).
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        # Default to tuple output unless the caller explicitly asks for a dict.
        return_dict = return_dict if return_dict is not None else False

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Project the encoder's last hidden state to the target dimension.
        last_hidden_state = self.final_projection(encoder_outputs[0])

        if not return_dict:
            return (last_hidden_state,)
        return BaseModelOutput(last_hidden_state=last_hidden_state)
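

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch, not part of the original file.
    # It assumes a t5-base-sized encoder (d_model=768) so that the default
    # project_in_dim of 768 matches the encoder's hidden size; all sizes
    # and token ids here are illustrative.
    config = T5ProjectionConfig(d_model=768, project_in_dim=768, project_out_dim=4096)
    model = T5EncoderWithProjection(config)  # randomly initialized weights
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    # Expected shape: (batch, seq_len, project_out_dim) -> (1, 8, 4096)
    print(outputs.last_hidden_state.shape)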