Roxanne-WANG committed
Commit 036a85e · 1 Parent(s): ebacf16

update final code

Files changed (1):
  text2sql.py (+24 -28)
text2sql.py CHANGED
@@ -106,68 +106,64 @@ class ChatBot():
     def __init__(self) -> None:
         os.environ["CUDA_VISIBLE_DEVICES"] = "0"
         model_name = "seeklhy/codes-1b"
-
-        # Load tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
-
-        # Set the device for the model (this ensures it's on either GPU or CPU)
-        self.device = self.model.device # This will get the device the model is loaded on (either CUDA or CPU)
-
-        # Define other parameters
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
         self.max_length = 4096
         self.max_new_tokens = 256
         self.max_prefix_length = self.max_length - self.max_new_tokens
 
-        # Load the Schema Item Classifier
+        # Directly loading the model from Hugging Face
         self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
-
-        # Initialize searcher for DB content (Whoosh index)
         self.db_id2content_searcher = dict()
         for db_id in os.listdir("db_contents_index"):
             index_dir = os.path.join("db_contents_index", db_id)
+
+            # Open existing Whoosh index directory
             if index.exists_in(index_dir):
                 ix = index.open_dir(index_dir)
-                self.db_id2content_searcher[db_id] = ix
+                # keep a searcher around for querying
+                self.db_id2content_searcher[db_id] = ix.searcher()
+            else:
+                raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")
+
+        self.db_ids = sorted(os.listdir("databases"))
+        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
+        self.db_id2ddl = get_db_id2ddl("databases")
 
     def get_response(self, question, db_id):
-        # Prepare the data for schema filtering
         data = {
             "text": question,
             "schema": copy.deepcopy(self.db_id2schema[db_id]),
             "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
         }
-
-        # Filter schema based on predictions
         data = filter_schema(data, self.sic, 6, 10)
         data["schema_sequence"] = get_db_schema_sequence(data["schema"])
         data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
-
-        # Prepare input sequence for the model
+
         prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
+        print(prefix_seq)
 
-        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation=False)["input_ids"]
+        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq , truncation = False)["input_ids"]
         if len(input_ids) > self.max_prefix_length:
+            print("the current input sequence exceeds the max_tokens, we will truncate it.")
             input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
         attention_mask = [1] * len(input_ids)
 
-        # Move input tensors to the same device as the model
         inputs = {
-            "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.device),
-            "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.device)
+            "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
+            "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
         }
-
-        # Generate SQL query using the model
+        input_length = inputs["input_ids"].shape[1]
+
         with torch.no_grad():
             generate_ids = self.model.generate(
                 **inputs,
-                max_new_tokens=self.max_new_tokens,
-                num_beams=4,
-                num_return_sequences=4
+                max_new_tokens = self.max_new_tokens,
+                num_beams = 4,
+                num_return_sequences = 4
             )
 
-        # Decode the generated SQL queries
-        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, len(input_ids):], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
         final_generated_sql = None
         for generated_sql in generated_sqls:
             execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
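
Note on the Whoosh change above: __init__ now stores ix.searcher() rather than the bare index object, so get_matched_contents can run queries against it directly. A minimal sketch of how such a searcher is typically queried with Whoosh (the index path and the "content" field name are illustrative assumptions, not taken from this repository):

    from whoosh import index
    from whoosh.qparser import QueryParser

    # Open one database's on-disk index and keep a searcher around,
    # mirroring what __init__ does above.
    ix = index.open_dir("db_contents_index/concert_singer")
    searcher = ix.searcher()

    # Parse a free-text query against an assumed "content" field and
    # fetch the top matching cell values.
    parser = QueryParser("content", schema=ix.schema)
    results = searcher.search(parser.parse("United States"), limit=10)
    for hit in results:
        print(hit.fields())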
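
The truncation branch keeps the BOS token plus the most recent max_prefix_length - 1 tokens, so when the serialized schema and matched contents overflow the budget, the oldest schema tokens are dropped while the question at the tail of prefix_seq survives. A toy illustration with made-up token ids and a tiny limit:

    bos_token_id = 1
    max_prefix_length = 6

    # Pretend tokenization of "<schema> <content> <question>"
    input_ids = [bos_token_id] + [101, 102, 103, 104, 105, 106, 107]

    if len(input_ids) > max_prefix_length:
        # Keep BOS plus the last (max_prefix_length - 1) tokens,
        # so the question tokens at the end are retained.
        input_ids = [bos_token_id] + input_ids[-(max_prefix_length - 1):]

    print(input_ids)  # [1, 103, 104, 105, 106, 107] -- length 6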
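
The hunk ends inside the candidate-selection loop: generate() returns four beam candidates, batch_decode strips the prompt by slicing from input_length, and each decoded SQL is then checked for executability against the database file. A self-contained sketch of that selection idea, using a stand-in executability check (the repo's check_sql_executability helper is defined elsewhere and may behave differently):

    import sqlite3

    def check_sql_executability_sketch(sql, db_path):
        # Stand-in: return None if the SQL runs against the SQLite file, else the error text.
        try:
            conn = sqlite3.connect(db_path)
            conn.execute(sql).fetchall()
            conn.close()
            return None
        except Exception as e:
            return str(e)

    def pick_first_executable(generated_sqls, db_path):
        # Beam candidates come back ordered by score; keep the first one that executes.
        for generated_sql in generated_sqls:
            if check_sql_executability_sketch(generated_sql, db_path) is None:
                return generated_sql.strip()
        # Fall back to the top-scoring candidate if none execute.
        return generated_sqls[0].strip() if generated_sqls else ""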