cmulgy's picture
add demo
41b743c
import torch
import torch.nn.functional as F
from torch_geometric.nn import GeneralConv
from torch_geometric.data import Data
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import f1_score
from torch_geometric.utils import normalize_edge_index
from torch_geometric.utils import degree
class FeatureAlign(nn.Module):
    """Project task, query and LLM embeddings into one shared node-feature matrix.

    Query nodes are represented as ``[task_transform(task_id) | query_transform(query)]``
    (width ``2 * common_dim``). LLM nodes are ``llm_transform(llm)`` (also width
    ``2 * common_dim``). The result stacks the query rows first, then the LLM rows.
    """

    def __init__(self, query_feature_dim, llm_feature_dim, common_dim):
        super().__init__()
        doubled_dim = common_dim * 2
        self.query_transform = nn.Linear(query_feature_dim, common_dim)
        self.llm_transform = nn.Linear(llm_feature_dim, doubled_dim)
        # NOTE(review): the task embedding is read with llm_feature_dim input
        # width -- presumably task embeddings share the LLM embedding size; confirm.
        self.task_transform = nn.Linear(llm_feature_dim, common_dim)

    def forward(self, task_id, query_features, llm_features):
        """Return the stacked node features: query nodes first, then LLM nodes.

        ``task_id`` and ``query_features`` must have the same number of rows,
        because their projections are joined column-wise.
        """
        task_part = self.task_transform(task_id)
        query_part = self.query_transform(query_features)
        query_nodes = torch.cat((task_part, query_part), dim=1)
        llm_nodes = self.llm_transform(llm_features)
        return torch.cat((query_nodes, llm_nodes), dim=0)
class EncoderDecoderNet(torch.nn.Module):
    """Two-layer GeneralConv encoder with a dot-product style edge decoder.

    Node features come from ``FeatureAlign`` (query nodes first, then LLM
    nodes). Message passing runs over the "visible" edges. Each scored edge
    ``(src, dst)`` gets ``sigmoid(mean(x_ini[src] * x[dst]))``, where ``x_ini``
    are the aligned input features and ``x`` the encoder output.
    """

    def __init__(self, query_feature_dim, llm_feature_dim, hidden_features, in_edges):
        super().__init__()
        self.in_edges = in_edges
        self.model_align = FeatureAlign(query_feature_dim, llm_feature_dim, hidden_features)
        self.encoder_conv_1 = GeneralConv(in_channels=hidden_features * 2, out_channels=hidden_features * 2,
                                          in_edge_channels=in_edges)
        self.encoder_conv_2 = GeneralConv(in_channels=hidden_features * 2, out_channels=hidden_features * 2,
                                          in_edge_channels=in_edges)
        self.edge_mlp = nn.Linear(in_edges, in_edges)
        self.bn1 = nn.BatchNorm1d(hidden_features * 2)
        self.bn2 = nn.BatchNorm1d(hidden_features * 2)

    def forward(self, task_id, query_features, llm_features, edge_index, edge_mask=None,
                edge_can_see=None, edge_weight=None):
        """Encode the graph and return a probability for each scored edge.

        Args:
            task_id, query_features, llm_features: passed straight to ``FeatureAlign``.
            edge_index: 2-row edge index (row 0 = source, row 1 = destination).
            edge_mask: column selector for the edges to score. ``None`` scores every edge.
            edge_can_see: column selector for the edges used in message passing.
                ``None`` uses every edge.
            edge_weight: per-edge attributes. They are reshaped to ``(-1, in_edges)``,
                filtered by ``edge_can_see`` and passed through ``edge_mlp``.
                ``None`` passes ``edge_attr=None`` to the conv layers.

        Returns:
            1-D tensor with one sigmoid score per scored edge.
        """
        # Bind every intermediate on all paths. Previously these names were only
        # assigned inside the optional-argument branches, so the advertised
        # ``None`` defaults raised NameError.
        if edge_can_see is None:
            edge_index_mask = edge_index
        else:
            edge_index_mask = edge_index[:, edge_can_see]
        edge_index_predict = edge_index if edge_mask is None else edge_index[:, edge_mask]

        edge_weight_mask = None
        if edge_weight is not None:
            if edge_can_see is not None:
                edge_weight = edge_weight[edge_can_see]
            # edge_mlp already returns shape (-1, in_edges); no second reshape needed.
            edge_weight_mask = F.leaky_relu(self.edge_mlp(edge_weight.reshape(-1, self.in_edges)))

        x_ini = self.model_align(task_id, query_features, llm_features)
        x = F.leaky_relu(self.bn1(self.encoder_conv_1(x_ini, edge_index_mask, edge_attr=edge_weight_mask)))
        x = self.bn2(self.encoder_conv_2(x, edge_index_mask, edge_attr=edge_weight_mask))

        src, dst = edge_index_predict[0], edge_index_predict[1]
        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        return torch.sigmoid((x_ini[src] * x[dst]).mean(dim=-1))
class form_data:
    """Pack query/LLM features and query->LLM edges into a PyG ``Data`` object."""

    def __init__(self, device):
        # Target device for every tensor built in ``formulation``.
        self.device = device

    def formulation(self, task_id, query_feature, llm_feature, org_node, des_node, edge_feature, label, edge_mask,
                    combined_edge, train_mask, valide_mask, test_mask):
        """Build the graph ``Data`` object consumed by ``EncoderDecoderNet``.

        Args:
            task_id: per-query task features, converted to a float tensor. It must
                have as many rows as ``query_feature``, because ``FeatureAlign``
                joins them column-wise.
            query_feature, llm_feature: node features, converted to float tensors.
            org_node: source (query) node index for each edge.
            des_node: destination LLM index for each edge, local to the LLM list.
                It is shifted here into the global node numbering.
            edge_feature: one scalar per edge; stored as an ``(E, 1)`` ``edge_attr``.
            label: edge labels, converted to a float tensor.
            combined_edge: two extra values per edge. These are reshaped to
                ``(-1, 2)`` and prepended with ``edge_feature``, giving ``(E, 3)``.
            edge_mask, train_mask, valide_mask, test_mask: stored unchanged.

        Returns:
            ``torch_geometric.data.Data`` holding all of the above.
        """
        query_features = torch.tensor(query_feature, dtype=torch.float).to(self.device)
        llm_features = torch.tensor(llm_feature, dtype=torch.float).to(self.device)
        task_id = torch.tensor(task_id, dtype=torch.float).to(self.device)

        num_queries = len(query_features)
        query_indices = list(range(num_queries))
        llm_indices = [num_queries + i for i in range(len(llm_features))]
        # LLM nodes are stacked after *all* query nodes (FeatureAlign concatenates
        # queries first). Offset by the query-node count, as llm_indices does.
        # The previous ``org_node[-1] + 1`` offset was wrong whenever org_node
        # was unsorted or did not end at the last query, and crashed when empty.
        des_node = [num_queries + i for i in des_node]

        edge_index = torch.tensor([org_node, des_node], dtype=torch.long).to(self.device)
        edge_weight = torch.tensor(edge_feature, dtype=torch.float).reshape(-1, 1).to(self.device)
        combined_edge = torch.tensor(combined_edge, dtype=torch.float).reshape(-1, 2).to(self.device)
        combined_edge = torch.cat((edge_weight, combined_edge), dim=-1)

        data = Data(task_id=task_id, query_features=query_features, llm_features=llm_features,
                    edge_index=edge_index, edge_attr=edge_weight, query_indices=query_indices,
                    llm_indices=llm_indices, label=torch.tensor(label, dtype=torch.float).to(self.device),
                    edge_mask=edge_mask, combined_edge=combined_edge,
                    train_mask=train_mask, valide_mask=valide_mask, test_mask=test_mask,
                    org_combine=combined_edge)
        return data
class GNN_prediction:
    """Load a trained ``EncoderDecoderNet`` checkpoint and pick an LLM for a query."""

    def __init__(self, query_feature_dim, llm_feature_dim, hidden_features_size, in_edges_size, config, device):
        self.model = EncoderDecoderNet(query_feature_dim=query_feature_dim, llm_feature_dim=llm_feature_dim,
                                       hidden_features=hidden_features_size, in_edges=in_edges_size).to(device)
        # Must provide 'llm_num': the number of candidate LLM edges per query row.
        self.config = config

    def test(self, data, model_path, llm_names):
        """Score the masked edges and sample one LLM name.

        Args:
            data: graph object with ``edge_mask``, ``train_mask``, ``valide_mask``,
                ``task_id``, ``query_features``, ``llm_features``, ``edge_index``
                and ``combined_edge`` attributes. For example, the output of
                ``form_data.formulation``.
            model_path: path to a saved ``state_dict``.
            llm_names: names indexed by LLM position within a row of scores.

        Returns:
            The sampled LLM name. It is also printed.
        """
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints, or pass weights_only=True on torch versions that support it.
        state_dict = torch.load(model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.eval()

        mask = data.edge_mask.clone().to(torch.bool)
        # Message passing may use train + validation edges.
        edge_can_see = torch.logical_or(data.train_mask, data.valide_mask)
        with torch.no_grad():
            edge_predict = self.model(task_id=data.task_id, query_features=data.query_features,
                                      llm_features=data.llm_features, edge_index=data.edge_index,
                                      edge_mask=mask, edge_can_see=edge_can_see, edge_weight=data.combined_edge)
        edge_predict = edge_predict.reshape(-1, self.config['llm_num'])

        # The LLM is *sampled* from softmax(scores), not chosen by argmax.
        # NOTE(review): EncoderDecoderNet emits sigmoid scores in (0, 1), so this
        # distribution is fairly flat. Confirm that sampling is intended.
        probs = torch.softmax(edge_predict, dim=1)
        # .item() requires exactly one query row (one sampled index).
        max_idx = torch.multinomial(probs, 1).item()
        best_llm = llm_names[max_idx]
        # TODO: map the selected name to the correct API.
        print(best_llm)
        return best_llm