import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    T5ForConditionalGeneration,
)


class ReactionT5Yield(nn.Module):
    """T5-based regression model for reaction yield prediction.

    The encoder consumes the tokenized reaction; the decoder is run for a
    single step from the decoder start token. The resulting decoder state and
    the encoder's first-token state are projected, concatenated, and passed
    through an MLP head that emits one scalar per example.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.pretrained_model_name_or_path, output_hidden_states=True
            )
        else:
            self.config = torch.load(config_path, weights_only=False)
        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.pretrained_model_name_or_path
            )
        else:
            self.model = AutoModel.from_config(self.config)
        # Match the embedding table to the (possibly extended) tokenizer.
        self.model.resize_token_embeddings(len(self.cfg.tokenizer))

        # Regression head: two hidden_size -> hidden_size // 2 projections
        # whose concatenation is refined by three further linear layers.
        self.fc_dropout1 = nn.Dropout(self.cfg.fc_dropout)
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc_dropout2 = nn.Dropout(self.cfg.fc_dropout)
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc3 = nn.Linear(
            self.config.hidden_size // 2 * 2, self.config.hidden_size
        )
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)

        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        # Run the decoder for a single step from the decoder start token.
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(
            self.fc_dropout1(last_hidden_states).view(-1, self.config.hidden_size)
        )
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        output = self.fc3(self.fc_dropout2(torch.hstack((output1, output2))))
        output = self.fc4(output)
        output = self.fc5(output)
        return output

    def generate_embedding(self, inputs):
        """Return the concatenated projected decoder/encoder features that
        feed the regression head."""
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(
            self.fc_dropout1(last_hidden_states).view(-1, self.config.hidden_size)
        )
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        return torch.hstack((output1, output2))
class ReactionT5Yield2(PreTrainedModel):
    """Hugging Face PreTrainedModel variant of ReactionT5Yield.

    Wrapping the model as a PreTrainedModel makes it usable with the
    save_pretrained/from_pretrained API. Unlike ReactionT5Yield, the head
    applies no dropout and the forward pass scales the prediction by 100.
    """

    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = T5ForConditionalGeneration.from_pretrained(
            self.config._name_or_path
        )
        self.model.resize_token_embeddings(self.config.vocab_size)
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc3 = nn.Linear(
            self.config.hidden_size // 2 * 2, self.config.hidden_size
        )
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)

        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        # Run the decoder for a single step from the decoder start token.
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        output = self.fc3(torch.hstack((output1, output2)))
        output = self.fc4(output)
        output = self.fc5(output)
        # Scale to percentage yield.
        return output * 100

    def generate_embedding(self, inputs):
        """Return the concatenated projected decoder/encoder features that
        feed the regression head."""
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        return torch.hstack((output1, output2))
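

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a
    # simple cfg namespace exposing the three attributes ReactionT5Yield
    # reads (pretrained_model_name_or_path, tokenizer, fc_dropout) and uses
    # "t5-small" purely as a placeholder checkpoint; the actual ReactionT5
    # checkpoint name and input format may differ.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    cfg = SimpleNamespace(
        pretrained_model_name_or_path="t5-small",
        tokenizer=tokenizer,
        fc_dropout=0.1,
    )
    model = ReactionT5Yield(cfg, pretrained=True)
    model.eval()

    # Illustrative reaction-style input; the real tokenization scheme is
    # defined by the training pipeline, not by this module.
    inputs = tokenizer("REACTANT:CC(=O)O REAGENT:O PRODUCT:CCO", return_tensors="pt")
    with torch.no_grad():
        pred = model(inputs)                    # (batch, 1) raw regression output
        emb = model.generate_embedding(inputs)  # (batch, hidden_size) features
    print(pred.shape, emb.shape)

    # The PreTrainedModel variant is driven by a config whose _name_or_path
    # points at the backbone checkpoint to load.
    config = AutoConfig.from_pretrained("t5-small")
    model2 = ReactionT5Yield2(config).eval()
    with torch.no_grad():
        pred_pct = model2(inputs)               # (batch, 1), scaled by 100
    print(pred_pct.shape)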