import json
import re

import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.preprocessing import MinMaxScaler

from .graph_nn import form_data, GNN_prediction
from ..data_processing.utils import loadjson, loadpkl

device = "cuda" if torch.cuda.is_available() else "cpu"

class graph_router_prediction:
    def __init__(self, router_data_train, router_data_test, llm_path, llm_embedding_path, config):
        self.config = config
        self.router_data_train, self.router_data_test = router_data_train, router_data_test
        # Each query appears once per candidate LLM, so the concatenated frame
        # holds num_query * num_llms rows.
        self.data_df = pd.concat([self.router_data_train, self.router_data_test], ignore_index=True)
        self.llm_description = loadjson(llm_path)
        self.llm_names = list(self.llm_description.keys())
        self.num_llms = len(self.llm_names)
        self.num_query = int(len(self.data_df) / self.num_llms)
        self.num_query_train = int(len(self.router_data_train) / self.num_llms)
        self.num_query_test = int(len(self.router_data_test) / self.num_llms)
        self.num_task = config['num_task']
        self.llm_description_embedding = loadpkl(llm_embedding_path)
        self.prepare_data_for_GNN()
        self.split_data()
        self.form_data = form_data(device)
        self.query_dim = self.query_embedding_list.shape[1]
        self.llm_dim = self.llm_description_embedding.shape[1]
        self.GNN_predict = GNN_prediction(query_feature_dim=self.query_dim,
                                          llm_feature_dim=self.llm_dim,
                                          hidden_features_size=self.config['embedding_dim'],
                                          in_edges_size=self.config['edge_dim'],
                                          config=self.config, device=device)
        self.test_GNN()

    def split_data(self):
        split_ratio = self.config['split_ratio']

        # Sizes for the train/validation splits (drawn from the training
        # queries) and the held-out test split.
        train_size = int(self.num_query_train * split_ratio[0])
        val_size = int(self.num_query_train * split_ratio[1])
        test_size = self.num_query_test

        all_query_indices = np.arange(self.num_query_train)
        np.random.shuffle(all_query_indices)
        train_query_indices = all_query_indices[:train_size]
        val_query_indices = all_query_indices[train_size:train_size + val_size]

        # Expand each query index into the indices of its num_llms edges.
        train_idx = []
        validate_idx = []

        for query_idx in train_query_indices:
            start_idx = query_idx * self.num_llms
            end_idx = start_idx + self.num_llms
            train_idx.extend(range(start_idx, end_idx))

        for query_idx in val_query_indices:
            start_idx = query_idx * self.num_llms
            end_idx = start_idx + self.num_llms
            validate_idx.extend(range(start_idx, end_idx))

        # Test edges belong to the last test_size queries of the concatenated
        # frame. (A plain list is needed here: wrapping the range in a
        # one-element list breaks the tensor indexing below.)
        test_idx = list(range(self.num_llms * self.num_query - self.num_llms * test_size,
                              self.num_llms * self.num_query))

        self.combined_edge = np.concatenate((self.cost_list.reshape(-1, 1),
                                             self.effect_list.reshape(-1, 1)), axis=1)

        # Collapse performance and cost into a single reward according to the
        # configured scenario: effect = w_perf * performance - w_cost * cost.
        self.scenario = self.config['scenario']
        if self.scenario == "Performance First":
            self.effect_list = 1.0 * self.effect_list - 0.0 * self.cost_list
        elif self.scenario == "Balance":
            self.effect_list = 0.5 * self.effect_list - 0.5 * self.cost_list
        else:  # any other scenario: cost-heavy weighting
            self.effect_list = 0.2 * self.effect_list - 0.8 * self.cost_list

        # One-hot label per query: 1 for the LLM with the best reward.
        effect_re = self.effect_list.reshape(-1, self.num_llms)
        self.label = np.eye(self.num_llms)[np.argmax(effect_re, axis=1)].reshape(-1, 1)

        # Bipartite edge list: every query node connects to every LLM node.
        self.edge_org_id = [num for num in range(self.num_query) for _ in range(self.num_llms)]
        self.edge_des_id = list(range(self.num_llms)) * self.num_query

        self.mask_train = torch.zeros(len(self.edge_org_id))
        self.mask_train[train_idx] = 1

        self.mask_validate = torch.zeros(len(self.edge_org_id))
        self.mask_validate[validate_idx] = 1

        self.mask_test = torch.zeros(len(self.edge_org_id))
        self.mask_test[test_idx] = 1
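
    # Worked example of the index expansion above (illustrative numbers only):
    # with num_llms = 3, query index 2 owns edge rows 6, 7 and 8, i.e.
    # range(2 * 3, 2 * 3 + 3). Each mask therefore flags whole per-query
    # blocks of num_llms consecutive edges, never a partial block.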

    def check_tensor_values(self):
        """Sanity-check that model inputs contain no NaNs and stay in [0, 1]."""
        def check_array(name, array):
            array = np.array(array)  # Ensure it's a numpy array
            has_nan = np.isnan(array).any()
            out_of_bounds = ((array < 0) | (array > 1)).any()
            if has_nan or out_of_bounds:
                print(f"[Warning] '{name}' has invalid values:")
                if has_nan:
                    print(" - Contains NaN values.")
                if out_of_bounds:
                    print(f" - Values outside [0, 1] range. "
                          f"Min: {np.min(array)}, Max: {np.max(array)}")
            else:
                print(f"[OK] '{name}' is valid (all values in [0, 1] and no NaNs).")

        check_array("query_embedding_list", self.query_embedding_list)
        check_array("task_embedding_list", self.task_embedding_list)
        check_array("effect_list", self.effect_list)
        check_array("cost_list", self.cost_list)
    
    def prepare_data_for_GNN(self):
        # The frame stores num_llms consecutive rows per query; the embedding
        # columns repeat within each block, so only the first row of each
        # block is parsed below.
        query_embedding_list_raw = self.data_df['query_embedding'].tolist()
        task_embedding_list_raw = self.data_df['task_description_embedding'].tolist()
        self.query_embedding_list = []
        self.task_embedding_list = []

        def parse_embedding(tensor_str):
            """Parse an embedding serialized as a string such as "tensor([...])"."""
            if pd.isna(tensor_str) or not isinstance(tensor_str, str):
                return []

            tensor_str = tensor_str.replace('tensor(', '').replace(')', '')
            try:
                values = json.loads(tensor_str)
            except json.JSONDecodeError:
                # Fall back to pulling numbers out with a regex when the
                # string is not valid JSON.
                numbers = re.findall(r'[-+]?\d*\.\d+|\d+', tensor_str)
                values = [float(x) for x in numbers]
            return np.nan_to_num(values, nan=0.0).tolist()
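
        # For illustration only (hypothetical inputs, not from the dataset):
        #   parse_embedding("tensor([0.12, 0.34])")  -> [0.12, 0.34]  # JSON path
        #   parse_embedding("[0.12 0.34]")           -> [0.12, 0.34]  # regex fallback
        #   parse_embedding(float("nan"))            -> []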

        # Extract and clean query embeddings
        for i in range(0, len(query_embedding_list_raw), self.num_llms):
            embedding = parse_embedding(query_embedding_list_raw[i])
            self.query_embedding_list.append(embedding)

        # Extract and clean task embeddings
        for i in range(0, len(task_embedding_list_raw), self.num_llms):
            embedding = parse_embedding(task_embedding_list_raw[i])
            self.task_embedding_list.append(embedding)

        # Convert to numpy arrays
        print("\nAttempting to convert to numpy arrays...")
        print(f"Query embedding list lengths: {set(len(x) for x in self.query_embedding_list)}")
        print(f"Task embedding list lengths: {set(len(x) for x in self.task_embedding_list)}")
        self.query_embedding_list = np.array(self.query_embedding_list)
        self.task_embedding_list = np.array(self.task_embedding_list)

        # Normalize each embedding dimension to [0, 1].
        def normalize_array(arr):
            scaler = MinMaxScaler()
            return scaler.fit_transform(arr)

        self.query_embedding_list = normalize_array(self.query_embedding_list)
        self.task_embedding_list = normalize_array(self.task_embedding_list)

        # Edge attributes: per-(query, LLM) performance and cost, NaNs zeroed.
        effect_raw = np.nan_to_num(self.data_df['normalized_performance'].tolist(), nan=0.0)
        cost_raw = np.nan_to_num(self.data_df['normalized_cost'].tolist(), nan=0.0)

        self.effect_list = effect_raw.flatten()
        self.cost_list = cost_raw.flatten()

        self.check_tensor_values()


    def test_GNN(self):
        """Assemble the full graph with masks applied and run test-time inference."""
        self.data_for_test = self.form_data.formulation(task_id=self.task_embedding_list,
                                                        query_feature=self.query_embedding_list,
                                                        llm_feature=self.llm_description_embedding,
                                                        org_node=self.edge_org_id,
                                                        des_node=self.edge_des_id,
                                                        edge_feature=self.effect_list,
                                                        edge_mask=self.mask_test,
                                                        label=self.label,
                                                        combined_edge=self.combined_edge,
                                                        train_mask=self.mask_train,
                                                        valide_mask=self.mask_validate,
                                                        test_mask=self.mask_test)
        best_llm = self.GNN_predict.test(data=self.data_for_test,
                                         model_path=self.config['model_path'],
                                         llm_names=self.llm_names)
        return best_llm
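

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the file paths and
# the exact config layout are assumptions, not part of this module. The config
# must at least provide the keys read above: 'num_task', 'embedding_dim',
# 'edge_dim', 'split_ratio', 'scenario' and 'model_path'.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("configs/config.yaml") as f:           # hypothetical path
        config = yaml.safe_load(f)

    train_df = pd.read_csv("data/router_train.csv")  # hypothetical path
    test_df = pd.read_csv("data/router_test.csv")    # hypothetical path

    # Constructing the router prepares the graph and immediately runs
    # test-time inference via test_GNN().
    router = graph_router_prediction(router_data_train=train_df,
                                     router_data_test=test_df,
                                     llm_path="data/llm_description.json",        # hypothetical
                                     llm_embedding_path="data/llm_embedding.pkl",  # hypothetical
                                     config=config)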