"""T5 encoder wrapped with a learned projection head over its last hidden state."""
import torch
from torch import nn
from transformers import T5Config, T5EncoderModel, T5PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple, Union


class T5ProjectionConfig(T5Config):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Input width of the projection head (should match the encoder's
        # d_model) and its output width. "out_dim" is accepted as a legacy
        # alias for "project_out_dim".
        self.project_in_dim = kwargs.get("project_in_dim", 768)
        self.project_out_dim = kwargs.get("project_out_dim", kwargs.get("out_dim", 4096))


class T5EncoderWithProjection(T5PreTrainedModel):
    config_class = T5ProjectionConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = T5EncoderModel(config)
        # Two-layer MLP mapping encoder states from project_in_dim to
        # project_out_dim; Dropout(0.0) is a no-op kept as a placeholder.
        self.final_projection = nn.Sequential(
            nn.Linear(config.project_in_dim, config.project_out_dim, bias=False),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(config.project_out_dim, config.project_out_dim, bias=False),
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        # Unlike the usual HF convention (config.use_return_dict), this
        # model defaults to tuple output when return_dict is unset.
        return_dict = return_dict if return_dict is not None else False

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both tuple and dataclass outputs.
        last_hidden_state = self.final_projection(encoder_outputs[0])

        if not return_dict:
            return (last_hidden_state,)
        return BaseModelOutput(last_hidden_state=last_hidden_state)
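

# --- Usage sketch -----------------------------------------------------------
# A minimal smoke test with a randomly initialised model. The config values
# below are illustrative, not taken from any released checkpoint; note that
# project_in_dim must match the encoder's d_model for the projection to apply.
if __name__ == "__main__":
    config = T5ProjectionConfig(d_model=512, project_in_dim=512, project_out_dim=1024)
    model = T5EncoderWithProjection(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))  # dummy token ids
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    print(outputs.last_hidden_state.shape)  # torch.Size([1, 8, 1024])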