Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
from torch_geometric.nn import GeneralConv | |
from torch_geometric.data import Data | |
import torch.nn as nn | |
from torch.optim import AdamW | |
from sklearn.metrics import f1_score | |
from torch_geometric.utils import normalize_edge_index | |
from torch_geometric.utils import degree | |
class FeatureAlign(nn.Module):
    """Project task, query and LLM features into one shared node-feature space.

    Query nodes are the column-wise concatenation of a projected task
    embedding and a projected query embedding (each ``common_dim`` wide).
    LLM nodes are projected straight to ``2 * common_dim``. The two groups
    are stacked row-wise: query rows first, LLM rows after.
    """

    def __init__(self, query_feature_dim, llm_feature_dim, common_dim):
        super().__init__()
        # Attribute names and creation order are significant: they fix the
        # state_dict keys used by checkpoints and the order of random init.
        self.query_transform = nn.Linear(query_feature_dim, common_dim)
        self.llm_transform = nn.Linear(llm_feature_dim, 2 * common_dim)
        # NOTE(review): task features are projected from llm_feature_dim, so
        # task_id presumably shares the LLM embedding width. Confirm with caller.
        self.task_transform = nn.Linear(llm_feature_dim, common_dim)

    def forward(self, task_id, query_features, llm_features):
        """Return query rows followed by LLM rows, each ``2 * common_dim`` wide.

        For 2-D inputs the result has shape
        ``(len(query_features) + len(llm_features), 2 * common_dim)``.
        ``task_id`` must have the same number of rows as ``query_features``
        because the two projections are concatenated along dim 1.
        """
        projected_task = self.task_transform(task_id)
        projected_query = self.query_transform(query_features)
        query_rows = torch.cat((projected_task, projected_query), dim=1)
        llm_rows = self.llm_transform(llm_features)
        return torch.cat((query_rows, llm_rows), dim=0)
class EncoderDecoderNet(torch.nn.Module):
    """Graph encoder/decoder that scores query -> LLM edges.

    Node features come from ``FeatureAlign``: query rows first, then LLM rows,
    each ``2 * hidden_features`` wide. Two ``GeneralConv`` layers pass messages
    over the *visible* edges. Each edge selected for prediction is then scored
    as ``sigmoid(mean(x_ini[src] * x[dst]))``.
    """

    def __init__(self, query_feature_dim, llm_feature_dim, hidden_features, in_edges):
        super(EncoderDecoderNet, self).__init__()
        self.in_edges = in_edges
        node_dim = hidden_features * 2
        # Submodule names and creation order are kept so that existing
        # checkpoints (state_dict keys) still load unchanged.
        self.model_align = FeatureAlign(query_feature_dim, llm_feature_dim, hidden_features)
        self.encoder_conv_1 = GeneralConv(in_channels=node_dim, out_channels=node_dim,
                                          in_edge_channels=in_edges)
        self.encoder_conv_2 = GeneralConv(in_channels=node_dim, out_channels=node_dim,
                                          in_edge_channels=in_edges)
        self.edge_mlp = nn.Linear(in_edges, in_edges)
        self.bn1 = nn.BatchNorm1d(node_dim)
        self.bn2 = nn.BatchNorm1d(node_dim)

    def forward(self, task_id, query_features, llm_features, edge_index, edge_mask=None,
                edge_can_see=None, edge_weight=None):
        """Score the edges selected by ``edge_mask``.

        Args:
            task_id, query_features, llm_features: forwarded to ``FeatureAlign``.
            edge_index: ``(2, E)`` long tensor of (source, destination) node ids.
            edge_mask: boolean mask over the E edges choosing which edges to
                score. ``None`` scores every edge.
            edge_can_see: boolean mask over the E edges the encoder may pass
                messages along. ``None`` uses every edge. If ``edge_mask`` is
                given but this is not, the scored edges are also visible.
            edge_weight: per-edge features, reshaped to ``(-1, in_edges)``.
                ``None`` runs the convolutions without edge features.

        Returns:
            1-D tensor of scores in (0, 1), one per selected edge, in edge order.
        """
        # With no masks, use the whole graph. Previously the masked tensors
        # were left unbound and these documented defaults raised.
        num_edges = edge_index.size(1)
        if edge_mask is None:
            edge_mask = torch.ones(num_edges, dtype=torch.bool, device=edge_index.device)
        if edge_can_see is None:
            edge_can_see = torch.ones(num_edges, dtype=torch.bool, device=edge_index.device)

        visible_edge_index = edge_index[:, edge_can_see]
        predict_edge_index = edge_index[:, edge_mask]

        visible_edge_attr = None
        if edge_weight is not None:
            visible_edge_attr = edge_weight[edge_can_see].reshape(-1, self.in_edges)
            visible_edge_attr = F.leaky_relu(self.edge_mlp(visible_edge_attr))

        x_ini = self.model_align(task_id, query_features, llm_features)
        x = self.encoder_conv_1(x_ini, visible_edge_index, edge_attr=visible_edge_attr)
        x = F.leaky_relu(self.bn1(x))
        x = self.encoder_conv_2(x, visible_edge_index, edge_attr=visible_edge_attr)
        x = self.bn2(x)

        # Decoder: the source side uses the aligned (pre-encoder) features and
        # the destination side uses the encoded features.
        src, dst = predict_edge_index
        return torch.sigmoid((x_ini[src] * x[dst]).mean(dim=-1))
class form_data:
    """Build a torch_geometric ``Data`` graph of query nodes and LLM nodes.

    Node ids ``0 .. Q-1`` are the query rows of ``query_feature``. LLM ``j``
    gets node id ``Q + j``, which matches the row order in which
    ``FeatureAlign`` stacks query features before LLM features.
    """

    def __init__(self, device):
        self.device = device

    def formulation(self, task_id, query_feature, llm_feature, org_node, des_node, edge_feature,
                    label, edge_mask, combined_edge, train_mask, valide_mask, test_mask):
        """Convert raw arrays into a ``Data`` object on ``self.device``.

        Args:
            task_id, query_feature, llm_feature: array-likes converted to float
                tensors.
            org_node: source node id (query index) of each edge.
            des_node: 0-based LLM index of each edge, shifted here past the
                query nodes.
            edge_feature: one scalar per edge; becomes ``edge_attr`` of shape
                ``(E, 1)``.
            label: converted to a float tensor.
            combined_edge: two extra values per edge, appended to
                ``edge_feature`` to form ``(E, 3)`` edge features.
            edge_mask, train_mask, valide_mask, test_mask: stored on the
                ``Data`` object as given (no conversion).

        Returns:
            torch_geometric ``Data`` with the fields listed in the call below.
        """
        device = self.device
        query_features = torch.tensor(query_feature, dtype=torch.float).to(device)
        llm_features = torch.tensor(llm_feature, dtype=torch.float).to(device)
        task_tensor = torch.tensor(task_id, dtype=torch.float).to(device)

        num_queries = len(query_features)
        query_indices = list(range(num_queries))
        llm_indices = [num_queries + i for i in range(len(llm_features))]

        # Offset LLM ids by the query count, the same offset used for
        # llm_indices. The old offset, org_node[-1] + 1, was wrong whenever
        # the final edge did not start at the last query, and it raised on an
        # empty edge list.
        llm_node_ids = [num_queries + i for i in des_node]
        edge_index = torch.tensor([org_node, llm_node_ids], dtype=torch.long).to(device)

        edge_weight = torch.tensor(edge_feature, dtype=torch.float).reshape(-1, 1).to(device)
        extra_edge = torch.tensor(combined_edge, dtype=torch.float).reshape(-1, 2).to(device)
        all_edge_features = torch.cat((edge_weight, extra_edge), dim=-1)
        label_tensor = torch.tensor(label, dtype=torch.float).to(device)

        return Data(task_id=task_tensor, query_features=query_features,
                    llm_features=llm_features, edge_index=edge_index,
                    edge_attr=edge_weight, query_indices=query_indices,
                    llm_indices=llm_indices, label=label_tensor, edge_mask=edge_mask,
                    combined_edge=all_edge_features, train_mask=train_mask,
                    valide_mask=valide_mask, test_mask=test_mask,
                    org_combine=all_edge_features)
class GNN_prediction:
    """Inference wrapper that loads an ``EncoderDecoderNet`` checkpoint and picks an LLM."""

    def __init__(self, query_feature_dim, llm_feature_dim, hidden_features_size,
                 in_edges_size, config, device):
        self.model = EncoderDecoderNet(query_feature_dim=query_feature_dim,
                                       llm_feature_dim=llm_feature_dim,
                                       hidden_features=hidden_features_size,
                                       in_edges=in_edges_size).to(device)
        # config must provide 'llm_num', the number of candidate LLMs per query.
        self.config = config

    def test(self, data, model_path, llm_names):
        """Load weights from ``model_path``, score the masked edges and choose an LLM.

        Args:
            data: graph object exposing ``task_id``, ``query_features``,
                ``llm_features``, ``edge_index``, ``edge_mask``, ``train_mask``,
                ``valide_mask`` and ``combined_edge``.
            model_path: path to a saved ``state_dict``.
            llm_names: sequence of names indexed by LLM position.

        Returns:
            The sampled entry of ``llm_names``, which is also printed.
        """
        # NOTE(review): torch.load unpickles the file. Only pass trusted
        # checkpoints.
        state_dict = torch.load(model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.eval()

        predict_mask = data.edge_mask.clone().to(torch.bool)
        visible_mask = torch.logical_or(data.train_mask, data.valide_mask)
        with torch.no_grad():
            edge_predict = self.model(task_id=data.task_id,
                                      query_features=data.query_features,
                                      llm_features=data.llm_features,
                                      edge_index=data.edge_index,
                                      edge_mask=predict_mask,
                                      edge_can_see=visible_mask,
                                      edge_weight=data.combined_edge)
        scores = edge_predict.reshape(-1, self.config['llm_num'])

        # The LLM is sampled from softmax(scores) rather than chosen by argmax.
        # NOTE(review): the model's scores come from a sigmoid and lie in
        # (0, 1), so this softmax is fairly flat and the choice is close to
        # uniform. torch.argmax(scores, 1) may have been intended; confirm.
        probs = torch.softmax(scores, dim=1)
        # .item() requires exactly one row, i.e. exactly one query being routed.
        choice = torch.multinomial(probs, 1).item()
        best_llm = llm_names[choice]
        # TODO(review): original note "map correct API", presumably mapping
        # the chosen name to its serving API.
        print(best_llm)
        return best_llm