Roxanne-WANG committed
Commit 6cb47ce · 1 Parent(s): 079d156

update model weight

Files changed (2)
  1. schema_item_filter.py +3 -7
  2. text2sql.py +1 -2
schema_item_filter.py CHANGED
@@ -18,7 +18,7 @@ def prepare_inputs_and_labels(sample, tokenizer):
     input_words = [sample["text"]]
     for table_id, table_name in enumerate(table_names):
         input_words.append("|")
-        input_words.append(table_name)
+        input_words.append(table_name)_
         table_name_word_indices.append(len(input_words) - 1)
         input_words.append(":")
 
@@ -238,14 +238,10 @@ def lista_contains_listb(lista, listb):
 class SchemaItemClassifierInference():
     def __init__(self, model_save_path):
         set_seed(42)
-        # load tokenizer
+        # load tokenizer from Hugging Face
         self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
-        # initialize model
+        # load model directly from Hugging Face
         self.model = SchemaItemClassifier(model_save_path, "test")
-        # load fine-tuned params
-        self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
-        if torch.cuda.is_available():
-            self.model = self.model.cuda()
         self.model.eval()
 
     def predict_one(self, sample):
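
For reference, the removed lines above loaded the fine-tuned dense_classifier.pt checkpoint from a local model_save_path and then moved the model to GPU; after this commit that work is expected to happen inside SchemaItemClassifier itself. Below is a minimal sketch of the equivalent loading pattern when the checkpoint is hosted on the Hub (the hf_hub_download call and the helper name are assumptions for illustration; the actual internals of SchemaItemClassifier are not shown in this diff):

# Sketch only: mirrors the removed load_state_dict / .cuda() logic, but fetches the
# checkpoint from the Hugging Face Hub instead of reading it from model_save_path.
import torch
from huggingface_hub import hf_hub_download

def load_dense_classifier(model, repo_id="Roxanne-WANG/LangSQL"):
    # "dense_classifier.pt" is the filename used by the removed code
    ckpt_path = hf_hub_download(repo_id=repo_id, filename="dense_classifier.pt")
    state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict, strict=False)  # strict=False, as in the old code
    if torch.cuda.is_available():
        model = model.cuda()
    return model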
text2sql.py CHANGED
@@ -111,9 +111,8 @@ class ChatBot():
         self.max_new_tokens = 256
         self.max_prefix_length = self.max_length - self.max_new_tokens
 
-        # self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL", token=os.getenv('HF_TOKEN'))
+        # Directly loading the model from Hugging Face
         self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
-
         self.db_id2content_searcher = dict()
         for db_id in os.listdir("db_contents_index"):
             schema = Schema(content=TEXT(stored=True))
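
The surrounding lines build one Whoosh content index entry per database id; the diff context stops at the Schema definition. A hedged sketch of how such a per-database index is typically opened and kept as a searcher (open_dir and searcher are standard Whoosh calls, but the real continuation of this loop is not visible in the diff):

# Sketch only: one plausible continuation of the truncated loop, assuming each
# subdirectory of db_contents_index/ holds a prebuilt Whoosh index for that db_id.
import os
from whoosh.fields import Schema, TEXT
from whoosh.index import open_dir

db_id2content_searcher = dict()
for db_id in os.listdir("db_contents_index"):
    schema = Schema(content=TEXT(stored=True))  # same schema as in the diff
    ix = open_dir(os.path.join("db_contents_index", db_id))
    db_id2content_searcher[db_id] = ix.searcher()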