import torch
from torch import nn
from typing import Optional, Tuple, Union

from transformers import T5EncoderModel, T5Config, T5PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput


class T5ProjectionConfig(T5Config):
    """T5 config extended with the input/output dimensions of the projection head."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.project_in_dim = kwargs.get("project_in_dim", 768)
        self.project_out_dim = kwargs.get("project_out_dim", 4096)


class T5EncoderWithProjection(T5PreTrainedModel):
    """T5 encoder followed by a two-layer MLP that projects the encoder hidden
    states from `project_in_dim` (e.g. 768) up to `project_out_dim` (e.g. 4096)."""

    config_class = T5ProjectionConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = T5EncoderModel(config)
        self.final_projection = nn.Sequential(
            nn.Linear(config.project_in_dim, config.project_out_dim, bias=False),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(config.project_out_dim, config.project_out_dim, bias=False),
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        return_dict = return_dict if return_dict is not None else False

        # Run the underlying T5 encoder.
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the last hidden state to the target dimension.
        last_hidden_state = self.final_projection(encoder_outputs[0])

        if not return_dict:
            return (last_hidden_state,)

        return BaseModelOutput(last_hidden_state=last_hidden_state)
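

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file):
    # builds a randomly initialised encoder whose hidden size (d_model) matches
    # project_in_dim, then projects a dummy batch up to project_out_dim.
    # The 768 -> 4096 dimensions and sequence length 16 are illustrative only.
    config = T5ProjectionConfig(d_model=768, project_in_dim=768, project_out_dim=4096)
    model = T5EncoderWithProjection(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

    # Expected: torch.Size([1, 16, 4096])
    print(outputs.last_hidden_state.shape)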