import torch
import torch.nn as nn

from transformers import AutoConfig, RobertaModel
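
# SchemaItemClassifier: runs a RoBERTa encoder over the input sequence, pools an
# embedding for every table name and every column via bi-LSTMs, applies
# table-column cross-attention, and predicts a pair of classification logits for
# each table and each column.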
class SchemaItemClassifier(nn.Module):
    def __init__(self, model_name_or_path, mode):
        super(SchemaItemClassifier, self).__init__()
        if mode in ["eval", "test"]:
            # load only the config and randomly initialize the encoder;
            # fine-tuned weights are expected to be loaded from a checkpoint afterwards
            config = AutoConfig.from_pretrained(model_name_or_path)
            self.plm_encoder = RobertaModel(config)
        elif mode == "train":
            # initialize the encoder from the pre-trained checkpoint
            self.plm_encoder = RobertaModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError("mode must be one of 'train', 'eval', or 'test', but got: {}".format(mode))

        self.plm_hidden_size = self.plm_encoder.config.hidden_size

        # column classification head
        self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.column_info_cls_head_linear2 = nn.Linear(256, 2)

        # column bi-LSTM pooling layer
        self.column_info_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size / 2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after the column bi-LSTM layer
        self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # table classification head
        self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.table_name_cls_head_linear2 = nn.Linear(256, 2)

        # table bi-LSTM pooling layer
        self.table_name_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size / 2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after the table bi-LSTM layer
        self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # activation functions
        self.leakyrelu = nn.LeakyReLU()
        self.tanh = nn.Tanh()

        # table-column cross-attention layer
        self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)

        # dropout layer: randomly zeroes 20% of the activations during training
        self.dropout = nn.Dropout(p = 0.2)

    def table_column_cross_attention(
        self,
        table_name_embeddings_in_one_db,
        column_info_embeddings_in_one_db,
        column_number_in_each_table
    ):
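        # table_name_embeddings_in_one_db: (num_tables, hidden_size) pooled table-name embeddings
        # column_info_embeddings_in_one_db: (num_columns, hidden_size) pooled column embeddings
        # column_number_in_each_table: list of ints, number of columns per table (sums to num_columns)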
        table_num = table_name_embeddings_in_one_db.shape[0]
        table_name_embedding_attn_list = []
        for table_id in range(table_num):
            table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
            column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
                sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id + 1]), :]
            table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
                table_name_embedding,
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )
            table_name_embedding_attn_list.append(table_name_embedding_attn)
        # residual connection
        table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
        # row-wise L2 normalization
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p = 2.0, dim = 1)

        return table_name_embeddings_in_one_db

    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
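        # encoder_input_ids, encoder_input_attention_mask: (batch_size, seq_length) tensors
        # batch_aligned_table_name_ids / batch_aligned_column_info_ids: per-example lists of
        #   token-position lists, one list per table name / column in the input sequence
        # batch_column_number_in_each_table: per-example list with the column count of each table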
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids = encoder_input_ids,
            attention_mask = encoder_input_attention_mask,
            return_dict = True
        )  # encoder_output["last_hidden_state"].shape = (batch_size, seq_length, hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []
        # process each example in the current batch
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :]  # (seq_length, hidden_size)
            # token positions of each table name
            aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
            # token positions of each column
            aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]

            table_name_embedding_list, column_info_embedding_list = [], []

            # obtain table embeddings via bi-LSTM pooling + a non-linear layer
            for table_name_ids in aligned_table_name_ids:
                table_name_embeddings = sequence_embeddings[table_name_ids, :]
                # bi-LSTM pooling: concatenate the final forward and backward hidden states
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
                table_name_embedding_list.append(table_name_embedding)
            table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
            # non-linear MLP layer
            table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))

            # obtain column embeddings via bi-LSTM pooling + a non-linear layer
            for column_info_ids in aligned_column_info_ids:
                column_info_embeddings = sequence_embeddings[column_info_ids, :]
                # bi-LSTM pooling: concatenate the final forward and backward hidden states
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
                column_info_embedding_list.append(column_info_embedding)
            column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
            # non-linear MLP layer
            column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate table 0-1 logits
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate column 0-1 logits
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits

    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        batch_table_name_cls_logits, batch_column_info_cls_logits \
            = self.table_column_cls(
                encoder_input_ids,
                encoder_attention_mask,
                batch_aligned_column_info_ids,
                batch_aligned_table_name_ids,
                batch_column_number_in_each_table
            )

        return {
            "batch_table_name_cls_logits": batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
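

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows the expected
# input format and output shapes. The checkpoint name "roberta-base", the toy
# batch, and all token-position lists below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # "train" mode loads pre-trained RoBERTa weights
    model = SchemaItemClassifier("roberta-base", mode = "train")
    model.eval()

    batch_size, seq_length = 1, 32
    encoder_input_ids = torch.randint(0, model.plm_encoder.config.vocab_size, (batch_size, seq_length))
    encoder_attention_mask = torch.ones(batch_size, seq_length, dtype = torch.long)

    # hypothetical token positions of each table name / column in the input sequence
    batch_aligned_table_name_ids = [[[1, 2], [10, 11, 12]]]    # 2 tables
    batch_aligned_column_info_ids = [[[3, 4], [5], [13, 14]]]  # 3 columns
    batch_column_number_in_each_table = [[2, 1]]               # table 0 has 2 columns, table 1 has 1

    with torch.no_grad():
        outputs = model(
            encoder_input_ids,
            encoder_attention_mask,
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table,
        )

    # one (num_tables, 2) and one (num_columns, 2) logits tensor per example
    print(outputs["batch_table_name_cls_logits"][0].shape)   # torch.Size([2, 2])
    print(outputs["batch_column_info_cls_logits"][0].shape)  # torch.Size([3, 2])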