Roxanne-WANG committed
Commit 9ac5bfc · 1 Parent(s): 305d669
Files changed (2)
  1. .DS_Store +0 -0
  2. text2sql.py +28 -24
.DS_Store ADDED
Binary file (8.2 kB).
 
text2sql.py CHANGED
@@ -106,64 +106,68 @@ class ChatBot():
     def __init__(self) -> None:
         os.environ["CUDA_VISIBLE_DEVICES"] = "0"
         model_name = "seeklhy/codes-1b"
+
+        # Load tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
+
+        # Set the device for the model (this ensures it's on either GPU or CPU)
+        self.device = self.model.device  # This will get the device the model is loaded on (either CUDA or CPU)
+
+        # Define other parameters
         self.max_length = 4096
         self.max_new_tokens = 256
         self.max_prefix_length = self.max_length - self.max_new_tokens
 
-        # Directly loading the model from Hugging Face
+        # Load the Schema Item Classifier
         self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
+
+        # Initialize searcher for DB content (Whoosh index)
         self.db_id2content_searcher = dict()
         for db_id in os.listdir("db_contents_index"):
             index_dir = os.path.join("db_contents_index", db_id)
-
-            # Open existing Whoosh index directory
             if index.exists_in(index_dir):
                 ix = index.open_dir(index_dir)
-                # keep a searcher around for querying
-                self.db_id2content_searcher[db_id] = ix.searcher()
-            else:
-                raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")
-
-        self.db_ids = sorted(os.listdir("databases"))
-        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
-        self.db_id2ddl = get_db_id2ddl("databases")
+                self.db_id2content_searcher[db_id] = ix
 
     def get_response(self, question, db_id):
+        # Prepare the data for schema filtering
         data = {
             "text": question,
             "schema": copy.deepcopy(self.db_id2schema[db_id]),
             "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
         }
+
+        # Filter schema based on predictions
         data = filter_schema(data, self.sic, 6, 10)
         data["schema_sequence"] = get_db_schema_sequence(data["schema"])
         data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
-
+
+        # Prepare input sequence for the model
         prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
-        print(prefix_seq)
 
-        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq , truncation = False)["input_ids"]
+        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation=False)["input_ids"]
         if len(input_ids) > self.max_prefix_length:
-            print("the current input sequence exceeds the max_tokens, we will truncate it.")
             input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
         attention_mask = [1] * len(input_ids)
 
+        # Move input tensors to the same device as the model
         inputs = {
-            "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
-            "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
+            "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.device),
+            "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.device)
         }
-        input_length = inputs["input_ids"].shape[1]
-
+
+        # Generate SQL query using the model
         with torch.no_grad():
             generate_ids = self.model.generate(
                 **inputs,
-                max_new_tokens = self.max_new_tokens,
-                num_beams = 4,
-                num_return_sequences = 4
+                max_new_tokens=self.max_new_tokens,
+                num_beams=4,
+                num_return_sequences=4
             )
 
-        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
+        # Decode the generated SQL queries
+        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, len(input_ids):], skip_special_tokens=True, clean_up_tokenization_spaces=False)
         final_generated_sql = None
         for generated_sql in generated_sqls:
             execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
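
For readers following the Whoosh lookup in this hunk outside the repo, here is a minimal, self-contained sketch of opening a per-database content index and querying it. The field name "content" and the helper name search_db_contents are illustrative assumptions; the repo's own get_matched_contents helper may use different fields and ranking.

import os

from whoosh import index
from whoosh.qparser import QueryParser


def search_db_contents(index_root, db_id, question, limit=10):
    # Each database gets its own index directory under index_root,
    # mirroring the "db_contents_index/<db_id>" layout in the diff.
    index_dir = os.path.join(index_root, db_id)
    if not index.exists_in(index_dir):
        raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

    ix = index.open_dir(index_dir)
    # A searcher can be opened on demand from the stored index object;
    # the "content" field name is assumed for illustration.
    with ix.searcher() as searcher:
        query = QueryParser("content", schema=ix.schema).parse(question)
        return [dict(hit) for hit in searcher.search(query, limit=limit)]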
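The prompt assembly in get_response tokenizes the schema/content/question prefix without truncation and, if it exceeds the budget, keeps only the most recent tokens and re-attaches the BOS token. A sketch of that step in isolation, mirroring the arithmetic in the diff (with max_length = 4096 and max_new_tokens = 256, max_prefix_length is 3840); the function name build_model_inputs is made up for illustration.

import torch


def build_model_inputs(tokenizer, prefix_seq, max_prefix_length, device):
    # Tokenize the full prefix, then keep only the tail if it overflows the budget,
    # re-attaching the BOS token so the sequence still starts correctly.
    input_ids = [tokenizer.bos_token_id] + tokenizer(prefix_seq, truncation=False)["input_ids"]
    if len(input_ids) > max_prefix_length:
        input_ids = [tokenizer.bos_token_id] + input_ids[-(max_prefix_length - 1):]
    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(device),
        "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(device),
    }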
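The generate-then-verify step decodes all beam candidates and keeps the first SQL that actually runs against the target SQLite file. check_sql_executability is the repo's helper; the sqlite3-based check below is a hedged stand-in for it (it only confirms the query executes, not that the answer is correct).

import sqlite3

import torch


def generate_sql_candidates(model, tokenizer, inputs, max_new_tokens=256):
    # Beam search with several returned sequences yields fallback candidates
    # in case the top-ranked SQL fails to execute.
    with torch.no_grad():
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=4,
            num_return_sequences=4,
        )
    input_length = inputs["input_ids"].shape[1]
    # Strip the prompt tokens so only newly generated text is decoded.
    return tokenizer.batch_decode(
        generate_ids[:, input_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )


def first_executable_sql(generated_sqls, db_path):
    # Stand-in for check_sql_executability: return the first candidate SQLite accepts.
    for sql in generated_sqls:
        try:
            conn = sqlite3.connect(db_path)
            conn.execute(sql).fetchall()
            conn.close()
            return sql
        except Exception:
            continue
    return None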