# RoutePilot / GraphRouter_eval / model / multi_task_graph_router.py
import random
import numpy as np
from .graph_nn import form_data,GNN_prediction
from ..data_processing.utils import savejson,loadjson,savepkl,loadpkl
import pandas as pd
import json
import re
import yaml
from sklearn.preprocessing import MinMaxScaler
import torch
# Run the GNN and all graph tensors on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
class graph_router_prediction:
    """Multi-task graph router: pick the best LLM per query with a GNN.

    Queries and LLMs are treated as nodes of a bipartite graph; every
    (query, LLM) pair is an edge carrying (cost, performance) features.
    A GNN scores the edges and the highest-scoring LLM per query wins.

    The input DataFrames are expected to contain one row per (query, LLM)
    pair, grouped so that each query contributes ``num_llms`` consecutive
    rows — TODO confirm against the data-processing pipeline.
    """

    def __init__(self, router_data_train, router_data_test, llm_path, llm_embedding_path, config):
        """Assemble graph data from the routing tables and run the test pass.

        Args:
            router_data_train: DataFrame of (query, LLM) rows for training.
            router_data_test: DataFrame with the same layout for testing.
            llm_path: path to a JSON file mapping LLM name -> description.
            llm_embedding_path: path to a pickle with LLM description embeddings.
            config: dict with keys 'num_task', 'embedding_dim', 'edge_dim',
                'split_ratio', 'scenario' and 'model_path'.
        """
        self.config = config
        self.router_data_train, self.router_data_test = router_data_train, router_data_test
        self.data_df = pd.concat([self.router_data_train, self.router_data_test], ignore_index=True)
        self.llm_description = loadjson(llm_path)
        self.llm_names = list(self.llm_description.keys())
        self.num_llms = len(self.llm_names)
        # Each query owns num_llms consecutive rows, so row_count / num_llms = queries.
        self.num_query = int(len(self.data_df) / self.num_llms)
        self.num_query_train = int(len(self.router_data_train) / self.num_llms)
        self.num_query_test = int(len(self.router_data_test) / self.num_llms)
        self.num_task = config['num_task']
        self.llm_description_embedding = loadpkl(llm_embedding_path)
        self.prepare_data_for_GNN()
        self.split_data()
        self.form_data = form_data(device)
        self.query_dim = self.query_embedding_list.shape[1]
        self.llm_dim = self.llm_description_embedding.shape[1]
        self.GNN_predict = GNN_prediction(query_feature_dim=self.query_dim,
                                          llm_feature_dim=self.llm_dim,
                                          hidden_features_size=self.config['embedding_dim'],
                                          in_edges_size=self.config['edge_dim'],
                                          config=self.config, device=device)
        self.test_GNN()

    def split_data(self):
        """Build train/val/test edge masks, routing labels and edge features.

        Train/val queries are a shuffled partition of the training queries;
        the test split is always the last ``num_query_test`` queries of
        ``data_df`` (the appended test DataFrame).

        Side effects: sets ``combined_edge``, ``effect_list`` (rescored by
        scenario), ``label``, ``edge_org_id``/``edge_des_id`` and the three
        ``mask_*`` tensors.
        """
        split_ratio = self.config['split_ratio']
        # Sizes are in units of queries; each query expands to num_llms edges.
        train_size = int(self.num_query_train * split_ratio[0])
        val_size = int(self.num_query_train * split_ratio[1])
        test_size = self.num_query_test
        all_query_indices = np.arange(self.num_query_train)
        np.random.shuffle(all_query_indices)
        train_query_indices = all_query_indices[:train_size]
        val_query_indices = all_query_indices[train_size:train_size + val_size]
        # Expand each query index to its num_llms consecutive edge indices.
        train_idx = []
        validate_idx = []
        for query_idx in train_query_indices:
            start_idx = query_idx * self.num_llms
            train_idx.extend(range(start_idx, start_idx + self.num_llms))
        for query_idx in val_query_indices:
            start_idx = query_idx * self.num_llms
            validate_idx.extend(range(start_idx, start_idx + self.num_llms))
        # BUG FIX: the original wrapped the range in a list ([range(...)]),
        # producing a nested (2-D) index; use a flat list of edge indices.
        test_idx = list(range(self.num_llms * self.num_query - self.num_llms * test_size,
                              self.num_llms * self.num_query))
        # Keep the raw (cost, performance) pair as the per-edge feature before
        # effect_list is rescored below.
        self.combined_edge = np.concatenate((self.cost_list.reshape(-1, 1),
                                             self.effect_list.reshape(-1, 1)), axis=1)
        # Rescore effect as a scenario-dependent performance/cost trade-off.
        self.scenario = self.config['scenario']
        if self.scenario == "Performance First":
            self.effect_list = 1.0 * self.effect_list - 0.0 * self.cost_list
        elif self.scenario == "Balance":
            self.effect_list = 0.5 * self.effect_list - 0.5 * self.cost_list
        else:
            self.effect_list = 0.2 * self.effect_list - 0.8 * self.cost_list
        # One-hot label per edge: 1 for the best-trade-off LLM of each query.
        effect_re = self.effect_list.reshape(-1, self.num_llms)
        self.label = np.eye(self.num_llms)[np.argmax(effect_re, axis=1)].reshape(-1, 1)
        # Edge endpoints: query q fans out to LLM nodes 0..num_llms-1
        # (edge_org_id[0] is always 0, so the range starts at 0).
        self.edge_org_id = [num for num in range(self.num_query) for _ in range(self.num_llms)]
        self.edge_des_id = list(range(self.edge_org_id[0], self.edge_org_id[0] + self.num_llms)) * self.num_query
        self.mask_train = torch.zeros(len(self.edge_org_id))
        self.mask_train[train_idx] = 1
        self.mask_validate = torch.zeros(len(self.edge_org_id))
        self.mask_validate[validate_idx] = 1
        self.mask_test = torch.zeros(len(self.edge_org_id))
        self.mask_test[test_idx] = 1

    def check_tensor_values(self):
        """Sanity-print whether the prepared arrays are NaN-free and in [0, 1]."""
        def check_array(name, array):
            array = np.array(array)  # Ensure it's a numpy array
            has_nan = np.isnan(array).any()
            out_of_bounds = ((array < 0) | (array > 1)).any()
            if has_nan or out_of_bounds:
                print(f"[Warning] '{name}' has invalid values:")
                if has_nan:
                    print(f"  - Contains NaN values.")
                if out_of_bounds:
                    min_val = np.min(array)
                    max_val = np.max(array)
                    print(f"  - Values outside [0, 1] range. Min: {min_val}, Max: {max_val}")
            else:
                print(f"[OK] '{name}' is valid (all values in [0, 1] and no NaNs).")

        check_array("query_embedding_list", self.query_embedding_list)
        check_array("task_embedding_list", self.task_embedding_list)
        check_array("effect_list", self.effect_list)
        check_array("cost_list", self.cost_list)

    def prepare_data_for_GNN(self):
        """Parse, clean and normalise node/edge features out of ``data_df``.

        Sets ``query_embedding_list`` / ``task_embedding_list`` as
        (num_query, dim) arrays min-max scaled to [0, 1] per feature, and
        ``effect_list`` / ``cost_list`` as flat per-edge arrays with NaNs
        zeroed. Ends with a validity report via ``check_tensor_values``.
        """
        query_embedding_list_raw = self.data_df['query_embedding'].tolist()
        task_embedding_list_raw = self.data_df['task_description_embedding'].tolist()
        self.query_embedding_list = []
        self.task_embedding_list = []

        def parse_embedding(tensor_str):
            # Embeddings are stored as stringified tensors; recover the floats.
            if pd.isna(tensor_str) or not isinstance(tensor_str, str):
                return []
            tensor_str = tensor_str.replace('tensor(', '').replace(')', '')
            try:
                values = json.loads(tensor_str)
            # BUG FIX: was a bare `except:`; only the parse errors json.loads
            # actually raises should trigger the regex fallback.
            except (json.JSONDecodeError, ValueError):
                # BUG FIX: the original pattern `[-+]?\d*\.\d+|\d+` dropped the
                # sign of negative integers and ignored exponents; this one
                # handles both (e.g. "-3", "1.2000e-05").
                numbers = re.findall(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?', tensor_str)
                values = [float(x) for x in numbers]
            return np.nan_to_num(values, nan=0.0).tolist()

        # One embedding per query: rows repeat per LLM, so step by num_llms.
        for i in range(0, len(query_embedding_list_raw), self.num_llms):
            self.query_embedding_list.append(parse_embedding(query_embedding_list_raw[i]))
        for i in range(0, len(task_embedding_list_raw), self.num_llms):
            self.task_embedding_list.append(parse_embedding(task_embedding_list_raw[i]))
        # Convert to numpy arrays (fails loudly if ragged — lengths printed first).
        print("\nAttempting to convert to numpy arrays...")
        print(f"Query embedding list lengths: {set(len(x) for x in self.query_embedding_list)}")
        print(f"Task embedding list lengths: {set(len(x) for x in self.task_embedding_list)}")
        self.query_embedding_list = np.array(self.query_embedding_list)
        self.task_embedding_list = np.array(self.task_embedding_list)

        def normalize_array(arr):
            # Min-max scale each feature column into [0, 1].
            scaler = MinMaxScaler()
            return scaler.fit_transform(arr)

        self.query_embedding_list = normalize_array(self.query_embedding_list)
        self.task_embedding_list = normalize_array(self.task_embedding_list)
        # Edge features: zero out NaNs and flatten to one value per edge.
        effect_raw = np.nan_to_num(self.data_df['normalized_performance'].tolist(), nan=0.0)
        cost_raw = np.nan_to_num(self.data_df['normalized_cost'].tolist(), nan=0.0)
        self.effect_list = effect_raw.flatten()
        self.cost_list = cost_raw.flatten()
        self.check_tensor_values()

    def test_GNN(self):
        """Run the trained GNN on the test mask and return the chosen LLM.

        Loads model weights from ``config['model_path']`` — inference only,
        no training happens here.
        """
        self.data_for_test = self.form_data.formulation(task_id=self.task_embedding_list,
                                                        query_feature=self.query_embedding_list,
                                                        llm_feature=self.llm_description_embedding,
                                                        org_node=self.edge_org_id,
                                                        des_node=self.edge_des_id,
                                                        edge_feature=self.effect_list, edge_mask=self.mask_test,
                                                        label=self.label, combined_edge=self.combined_edge,
                                                        train_mask=self.mask_train, valide_mask=self.mask_validate,
                                                        test_mask=self.mask_test)
        best_llm = self.GNN_predict.test(data=self.data_for_test, model_path=self.config['model_path'], llm_names=self.llm_names)
        return best_llm