# ReactionT5 / models.py
# Provenance (Hugging Face Hub file view): uploaded by sagawa in commit 08ccc8e ("Upload 42 files", verified).
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModel,
PreTrainedModel,
T5ForConditionalGeneration,
)
class ReactionT5Yield(nn.Module):
    """Scalar regression head on top of an encoder-decoder backbone.

    ``self.model`` must expose ``encoder`` and ``decoder`` sub-modules. The head
    combines two feature vectors:

    * the decoder's hidden state after one decoder step fed with
      ``config.decoder_start_token_id`` (projected by ``fc1`` after
      ``fc_dropout1``), and
    * the encoder's hidden state at sequence position 0 (projected by ``fc2``),

    concatenates them and maps the result through ``fc3`` -> ``fc4`` -> ``fc5``
    to one value per example. No activation functions are applied between the
    ``fc`` layers.

    Args:
        cfg: Run configuration. Must provide ``pretrained_model_name_or_path``,
            ``tokenizer`` (its ``len()`` is used as the vocabulary size for
            ``resize_token_embeddings``) and ``fc_dropout`` (dropout probability).
        config_path: Optional path to a config saved with ``torch.save``. When
            ``None``, the config is fetched with ``AutoConfig.from_pretrained``
            (with ``output_hidden_states=True``).
        pretrained: If true, load backbone weights with
            ``AutoModel.from_pretrained``; otherwise build the backbone from
            ``self.config`` via ``AutoModel.from_config``.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.pretrained_model_name_or_path, output_hidden_states=True
            )
        else:
            # SECURITY: weights_only=False unpickles arbitrary Python objects,
            # which can execute code on load. Only use config files from a
            # trusted source. NOTE(review): the file presumably holds a pickled
            # config object, which weights_only=True would reject — confirm
            # before tightening this.
            self.config = torch.load(config_path, weights_only=False)
        if pretrained:
            # NOTE(review): weights are loaded with the checkpoint's own config,
            # not ``self.config``; confirm the two agree when config_path is set.
            self.model = AutoModel.from_pretrained(
                self.cfg.pretrained_model_name_or_path
            )
        else:
            self.model = AutoModel.from_config(self.config)
        # Match the embedding matrix to the tokenizer's vocabulary size.
        self.model.resize_token_embeddings(len(self.cfg.tokenizer))

        hidden = self.config.hidden_size
        half = hidden // 2
        # Layer construction order is unchanged so that RNG consumption (and
        # therefore the initial weights for a given seed) stays identical.
        self.fc_dropout1 = nn.Dropout(self.cfg.fc_dropout)
        self.fc1 = nn.Linear(hidden, half)
        self.fc_dropout2 = nn.Dropout(self.cfg.fc_dropout)
        self.fc2 = nn.Linear(hidden, half)
        # fc3 consumes the concatenation of the fc1 and fc2 outputs.
        self.fc3 = nn.Linear(half * 2, hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        self.fc5 = nn.Linear(hidden, 1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            self._init_weights(layer)

    def _init_weights(self, module):
        """Re-initialise ``module`` in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std 0.01. Linear biases and the embedding
        padding row are zeroed. ``nn.LayerNorm`` is reset to weight 1 / bias 0.
        Parameters that are absent (``None``) are skipped. Other module types
        are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False (or bias=False);
            # touching them unguarded would raise AttributeError.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def _concat_features(self, inputs):
        """Return ``hstack([fc1(dropout1(decoder_state)), fc2(encoder_state[:, 0])])``.

        Shared by :meth:`forward` and :meth:`generate_embedding`.
        """
        encoder_hidden_states = self.model.encoder(**inputs)[0]
        input_ids = inputs["input_ids"]
        # A single decoder step: every example is fed only the start token.
        decoder_input_ids = torch.full(
            (input_ids.size(0), 1),
            self.config.decoder_start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        # NOTE(review): no encoder_attention_mask is passed, so (assuming the
        # usual HF decoder API) cross-attention is not masked and also attends to
        # padded encoder positions. Kept as-is because existing checkpoints were
        # presumably trained this way; confirm before passing
        # ``encoder_attention_mask=inputs.get("attention_mask")``.
        last_hidden_states = self.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
        )[0]
        hidden = self.config.hidden_size
        decoder_features = self.fc1(
            self.fc_dropout1(last_hidden_states).view(-1, hidden)
        )
        encoder_features = self.fc2(encoder_hidden_states[:, 0, :].view(-1, hidden))
        return torch.hstack((decoder_features, encoder_features))

    def forward(self, inputs):
        """Predict one scalar per example.

        Args:
            inputs: Mapping of keyword arguments for ``self.model.encoder``. It
                must contain ``"input_ids"``, whose batch size and device are
                used to build the decoder start tokens.

        Returns:
            Tensor with last dimension 1. Assuming the decoder returns one
            position per example (it is fed one token), the shape is
            ``(batch, 1)``.
        """
        features = self._concat_features(inputs)
        output = self.fc3(self.fc_dropout2(features))
        output = self.fc4(output)
        return self.fc5(output)

    def generate_embedding(self, inputs):
        """Return the concatenated features that :meth:`forward` feeds to ``fc3``.

        ``fc_dropout1`` is applied inside, so call ``eval()`` first for
        deterministic embeddings. ``fc_dropout2`` is not applied.
        """
        return self._concat_features(inputs)
class ReactionT5Yield2(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper for scalar (yield) regression.

    The backbone is a ``T5ForConditionalGeneration`` loaded from
    ``config._name_or_path``. The head combines:

    * the decoder's hidden state after one decoder step fed with
      ``config.decoder_start_token_id`` (projected by ``fc1``), and
    * the encoder's hidden state at sequence position 0 (projected by ``fc2``),

    concatenates them and maps the result through ``fc3`` -> ``fc4`` -> ``fc5``.
    The final value is multiplied by 100. There is no dropout and there are no
    activation functions between the ``fc`` layers.
    """

    # NOTE(review): AutoConfig is a factory class rather than a PretrainedConfig
    # subclass. It is presumably used so ``from_pretrained`` can resolve whichever
    # concrete config was saved — confirm against the transformers version in use.
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Backbone weights are loaded from the path/name recorded in the config.
        self.model = T5ForConditionalGeneration.from_pretrained(
            self.config._name_or_path
        )
        self.model.resize_token_embeddings(self.config.vocab_size)

        hidden = self.config.hidden_size
        half = hidden // 2
        # Layer construction order is unchanged so that RNG consumption (and
        # therefore the initial weights for a given seed) stays identical.
        self.fc1 = nn.Linear(hidden, half)
        self.fc2 = nn.Linear(hidden, half)
        # fc3 consumes the concatenation of the fc1 and fc2 outputs.
        self.fc3 = nn.Linear(half * 2, hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        self.fc5 = nn.Linear(hidden, 1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            self._init_weights(layer)

    def _init_weights(self, module):
        """Re-initialise ``module`` in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std 0.01. Linear biases and the embedding
        padding row are zeroed. ``nn.LayerNorm`` is reset to weight 1 / bias 0.
        Parameters that are absent (``None``) are skipped. Other module types
        are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False (or bias=False);
            # touching them unguarded would raise AttributeError.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def _concat_features(self, inputs):
        """Return ``hstack([fc1(decoder_state), fc2(encoder_state[:, 0])])``.

        Shared by :meth:`forward` and :meth:`generate_embedding`.
        """
        encoder_hidden_states = self.model.encoder(**inputs)[0]
        input_ids = inputs["input_ids"]
        # A single decoder step: every example is fed only the start token.
        decoder_input_ids = torch.full(
            (input_ids.size(0), 1),
            self.config.decoder_start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        # NOTE(review): no encoder_attention_mask is passed, so (assuming the
        # usual HF decoder API) cross-attention is not masked and also attends to
        # padded encoder positions. Kept as-is because existing checkpoints were
        # presumably trained this way; confirm before passing
        # ``encoder_attention_mask=inputs.get("attention_mask")``.
        last_hidden_states = self.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
        )[0]
        hidden = self.config.hidden_size
        decoder_features = self.fc1(last_hidden_states.view(-1, hidden))
        encoder_features = self.fc2(encoder_hidden_states[:, 0, :].view(-1, hidden))
        return torch.hstack((decoder_features, encoder_features))

    def forward(self, inputs):
        """Predict one scalar per example, multiplied by 100.

        Args:
            inputs: Mapping of keyword arguments for ``self.model.encoder``. It
                must contain ``"input_ids"``, whose batch size and device are
                used to build the decoder start tokens.

        Returns:
            ``fc5(fc4(fc3(features))) * 100``, a tensor with last dimension 1.
            NOTE(review): the x100 presumably converts a [0, 1]-scaled head
            output to a percentage yield — confirm against the training targets.
        """
        features = self._concat_features(inputs)
        output = self.fc3(features)
        output = self.fc4(output)
        output = self.fc5(output)
        return output * 100

    def generate_embedding(self, inputs):
        """Return the concatenated features that :meth:`forward` feeds to ``fc3``."""
        return self._concat_features(inputs)