Spaces:
Paused
Paused
Commit
·
6cb47ce
1
Parent(s):
079d156
update model weight
Browse files
- schema_item_filter.py +3 -7
- text2sql.py +1 -2
schema_item_filter.py
CHANGED
@@ -18,7 +18,7 @@ def prepare_inputs_and_labels(sample, tokenizer):
|
|
18 |
input_words = [sample["text"]]
|
19 |
for table_id, table_name in enumerate(table_names):
|
20 |
input_words.append("|")
|
21 |
-
input_words.append(table_name)
|
22 |
table_name_word_indices.append(len(input_words) - 1)
|
23 |
input_words.append(":")
|
24 |
|
@@ -238,14 +238,10 @@ def lista_contains_listb(lista, listb):
|
|
238 |
class SchemaItemClassifierInference():
|
239 |
def __init__(self, model_save_path):
|
240 |
set_seed(42)
|
241 |
-
# load tokenizer
|
242 |
self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
|
243 |
-
#
|
244 |
self.model = SchemaItemClassifier(model_save_path, "test")
|
245 |
-
# load fine-tuned params
|
246 |
-
self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
|
247 |
-
if torch.cuda.is_available():
|
248 |
-
self.model = self.model.cuda()
|
249 |
self.model.eval()
|
250 |
|
251 |
def predict_one(self, sample):
|
|
|
18 |
input_words = [sample["text"]]
|
19 |
for table_id, table_name in enumerate(table_names):
|
20 |
input_words.append("|")
|
21 |
+
input_words.append(table_name)
|
22 |
table_name_word_indices.append(len(input_words) - 1)
|
23 |
input_words.append(":")
|
24 |
|
|
|
238 |
class SchemaItemClassifierInference():
|
239 |
def __init__(self, model_save_path):
|
240 |
set_seed(42)
|
241 |
+
# load tokenizer from Hugging Face
|
242 |
self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
|
243 |
+
# load model directly from Hugging Face
|
244 |
self.model = SchemaItemClassifier(model_save_path, "test")
|
|
|
|
|
|
|
|
|
245 |
self.model.eval()
|
246 |
|
247 |
def predict_one(self, sample):
|
text2sql.py
CHANGED
@@ -111,9 +111,8 @@ class ChatBot():
|
|
111 |
self.max_new_tokens = 256
|
112 |
self.max_prefix_length = self.max_length - self.max_new_tokens
|
113 |
|
114 |
-
#
|
115 |
self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
|
116 |
-
|
117 |
self.db_id2content_searcher = dict()
|
118 |
for db_id in os.listdir("db_contents_index"):
|
119 |
schema = Schema(content=TEXT(stored=True))
|
|
|
111 |
self.max_new_tokens = 256
|
112 |
self.max_prefix_length = self.max_length - self.max_new_tokens
|
113 |
|
114 |
+
# Directly loading the model from Hugging Face
|
115 |
self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
|
|
|
116 |
self.db_id2content_searcher = dict()
|
117 |
for db_id in os.listdir("db_contents_index"):
|
118 |
schema = Schema(content=TEXT(stored=True))
|