File size: 2,231 Bytes
d3479d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from transformers import T5EncoderModel, T5Config, T5PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from typing import List, Optional, Tuple, Union
from torch import nn, Tensor


class T5ProjectionConfig(T5Config):
    """``T5Config`` extended with the sizes of an output projection head.

    Extra keyword arguments (all kwargs are also forwarded to ``T5Config``):
        project_in_dim: input width of the projection head (default 768).
            The head is meant to consume the encoder's last hidden state, so
            this presumably has to equal ``d_model`` — TODO confirm.
        project_out_dim: output width of the projection head (default 4096).
        out_dim: legacy alias for ``project_out_dim``; used only when
            ``project_out_dim`` is not given.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.project_in_dim = kwargs.get("project_in_dim", 768)
        # The width is serialized under the attribute name "project_out_dim",
        # so that key must take precedence when the config is rebuilt from a
        # saved dict; reading only "out_dim" reset any saved width to 4096.
        self.project_out_dim = kwargs.get(
            "project_out_dim", kwargs.get("out_dim", 4096)
        )


class T5EncoderWithProjection(T5PreTrainedModel):
    """T5 encoder whose last hidden state is passed through a two-layer MLP.

    The MLP maps ``config.project_in_dim`` -> ``config.project_out_dim`` ->
    ``config.project_out_dim`` (bias-free linears with a ReLU in between).
    """

    config_class = T5ProjectionConfig

    def __init__(self, config):
        super().__init__(config)
        # A full T5EncoderModel is wrapped, so its parameters live under the
        # "encoder.encoder..." / "encoder.shared..." prefixes; changing this
        # layout would change the state_dict keys.
        self.encoder = T5EncoderModel(config)

        # Sequential indices determine state_dict keys ("final_projection.0",
        # "final_projection.3"), so the no-op Dropout(0.0) keeps its slot.
        self.final_projection = nn.Sequential(
            nn.Linear(config.project_in_dim, config.project_out_dim, bias=False),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(config.project_out_dim, config.project_out_dim, bias=False),
        )

        # Standard PreTrainedModel finalization (weight init and related
        # bookkeeping) expected at the end of every model's __init__.
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        """Run the T5 encoder and project its last hidden state.

        All arguments except ``return_dict`` are forwarded unchanged to the
        wrapped ``T5EncoderModel``.

        Args:
            return_dict: return a ``BaseModelOutput`` when True; otherwise
                (including ``None``) return a tuple. Note that ``None`` means
                False here, not ``config.use_return_dict``.

        Returns:
            ``last_hidden_state`` is the projected encoder output (last
            dimension ``config.project_out_dim``). ``hidden_states`` and
            ``attentions`` are the encoder's own, unprojected values and are
            only present when the encoder produces them (as requested by
            ``output_hidden_states`` / ``output_attentions``). In tuple form,
            ``None`` entries are dropped, so by default the result is
            ``(last_hidden_state,)``.
        """
        # None deliberately maps to False (tuple output) in this model.
        return_dict = return_dict if return_dict is not None else False

        # Always request a ModelOutput from the encoder so optional fields can
        # be read by name: its tuple form drops None entries, which makes
        # positional indexing of hidden_states/attentions ambiguous.
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        last_hidden_state = self.final_projection(encoder_outputs.last_hidden_state)

        if not return_dict:
            return tuple(
                v
                for v in (
                    last_hidden_state,
                    encoder_outputs.hidden_states,
                    encoder_outputs.attentions,
                )
                if v is not None
            )

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )