# NOTE: snapshot of a Hugging Face Space (page status at scrape time: runtime error).
import torch
from transformers import T5EncoderModel, T5Config, T5PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from typing import List, Optional, Tuple, Union
from torch import nn, Tensor
class T5ProjectionConfig(T5Config):
    """T5 config extended with the projection head's dimensions.

    Extra keyword arguments:
        project_in_dim (int): input width of the projection head, i.e. the
            encoder hidden size it sits on top of. Default 768.
        project_out_dim (int): output width of the projection head.
            Default 4096. The legacy key ``out_dim`` is still honored as a
            fallback for backward compatibility.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.project_in_dim = kwargs.get("project_in_dim", 768)
        # Bug fix: this previously read the mismatched key "out_dim", so a
        # caller passing project_out_dim=... had it silently overwritten with
        # 4096 (PretrainedConfig.__init__ sets the attribute from kwargs, then
        # this line clobbered it). Read the canonical key first; keep
        # "out_dim" as a fallback so existing callers still work.
        self.project_out_dim = kwargs.get(
            "project_out_dim", kwargs.get("out_dim", 4096)
        )
class T5EncoderWithProjection(T5PreTrainedModel):
    """A T5 encoder whose last hidden state is passed through an MLP head.

    The head maps ``project_in_dim -> project_out_dim -> project_out_dim``
    (two bias-free linear layers with a ReLU in between; the Dropout(0.0)
    is a no-op kept for checkpoint/state-dict compatibility).
    """

    config_class = T5ProjectionConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = T5EncoderModel(config)
        self.final_projection = nn.Sequential(
            nn.Linear(config.project_in_dim, config.project_out_dim, bias=False),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(config.project_out_dim, config.project_out_dim, bias=False),
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        """Encode the input and project the encoder's last hidden state.

        Returns a one-element tuple when ``return_dict`` is falsy (the
        default here, unlike most HF models), otherwise a
        ``BaseModelOutput`` carrying only ``last_hidden_state``.
        """
        # NOTE: deliberately defaults to tuple output, not config.use_return_dict.
        if return_dict is None:
            return_dict = False

        enc_out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # enc_out[0] is the encoder's last hidden state in both output modes.
        projected = self.final_projection(enc_out[0])

        if return_dict:
            return BaseModelOutput(last_hidden_state=projected)
        return (projected,)