# LangSQL/utils/db_utils.py
import os
import json
import sqlite3
from func_timeout import func_set_timeout, FunctionTimedOut
from utils.bridge_content_encoder import get_matched_entries
from nltk.tokenize import word_tokenize
from nltk import ngrams
from whoosh.qparser import QueryParser
def add_a_record(question, db_id):
conn = sqlite3.connect('data/history/history.sqlite')
cursor = conn.cursor()
cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id))
conn.commit()
conn.close()
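# Usage sketch (illustrative): assumes data/history/history.sqlite and its
# `record(question, db_id)` table already exist; this function only appends to it.
#   add_a_record("How many singers do we have?", "concert_singer")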
def obtain_n_grams(sequence, max_n):
    # make sure the NLTK tokenizer models are available (a no-op after the first download)
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
tokens = word_tokenize(sequence)
all_grams = []
for n in range(1, max_n + 1):
all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])
return all_grams
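# Illustrative example of the n-gram expansion (hypothetical input):
#   obtain_n_grams("concert in 2014", 2)
#   -> ['concert', 'in', '2014', 'concert in', 'in 2014']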
# get the database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path):
try:
if not os.path.exists(sqlite_path):
print("Openning a new connection %s" % sqlite_path)
connection = sqlite3.connect(sqlite_path, check_same_thread = False)
except Exception as e:
print(sqlite_path)
raise e
connection.text_factory = lambda b: b.decode(errors="ignore")
cursor = connection.cursor()
return cursor
# execute predicted SQL with a time limitation
@func_set_timeout(15)
def execute_sql(cursor, sql):
cursor.execute(sql)
return cursor.fetchall()
# execute predicted SQL with a long time limitation (for building the content index)
@func_set_timeout(2000)
def execute_sql_long_time_limitation(cursor, sql):
cursor.execute(sql)
return cursor.fetchall()
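# Usage sketch (illustrative; the database path is hypothetical). A call that
# exceeds the limit raises FunctionTimedOut instead of returning:
#   cursor = get_cursor_from_path("data/concert_singer.sqlite")
#   try:
#       rows = execute_sql(cursor, "SELECT name FROM singer LIMIT 3;")
#   except FunctionTimedOut:
#       rows = None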
def check_sql_executability(generated_sql, db):
if generated_sql.strip() == "":
return "Error: empty string"
try:
cursor = get_cursor_from_path(db)
execute_sql(cursor, generated_sql)
execution_error = None
except FunctionTimedOut as fto:
print("SQL execution time out error: {}.".format(fto))
execution_error = "SQL execution times out."
except Exception as e:
print("SQL execution runtime error: {}.".format(e))
execution_error = str(e)
return execution_error
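# Example (illustrative; the path and SQL are hypothetical). The function
# returns None when the SQL executes cleanly, otherwise an error message:
#   error = check_sql_executability("SELECT count(*) FROM singer;", "data/concert_singer.sqlite")
#   if error is not None:
#       print("generated SQL is not executable:", error)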
def is_number(s):
try:
float(s)
return True
except ValueError:
return False
def detect_special_char(name):
for special_char in ['(', '-', ')', ' ', '/']:
if special_char in name:
return True
return False
def add_quotation_mark(s):
return "`" + s + "`"
def get_column_contents(column_name, table_name, cursor):
select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(column_name, table_name, column_name)
results = execute_sql_long_time_limitation(cursor, select_column_sql)
column_contents = [str(result[0]).strip() for result in results]
# remove empty and extremely-long contents
column_contents = [content for content in column_contents if len(content) != 0 and len(content) <= 25]
return column_contents
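# Example (illustrative; assumes a `singer` table with a `name` column):
#   cursor = get_cursor_from_path("data/concert_singer.sqlite")
#   get_column_contents("name", "singer", cursor)
#   -> ['Joe Sharp', 'Timbaland']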
def get_matched_contents(question, searcher):
# Coarse-grained matching between the input text and all contents in database
grams = obtain_n_grams(question, 4)
hits = []
# Parse each n-gram query into a valid Whoosh query object
for query in grams:
query_parser = QueryParser("content", schema=searcher.schema) # 'content' should match the field you are searching
parsed_query = query_parser.parse(query) # Convert the query string into a Whoosh Query object
hits.extend(searcher.search(parsed_query, limit=10)) # Perform the search with the parsed query
coarse_matched_contents = dict()
    for hit in hits:
        # Whoosh stored fields are read via item access; `raw` is assumed to hold the indexed JSON record
        matched_result = json.loads(hit["raw"])
# `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
tc_name = ".".join(matched_result["id"].split("-**-")[:2])
if tc_name in coarse_matched_contents.keys():
if matched_result["contents"] not in coarse_matched_contents[tc_name]:
coarse_matched_contents[tc_name].append(matched_result["contents"])
else:
coarse_matched_contents[tc_name] = [matched_result["contents"]]
fine_matched_contents = dict()
for tc_name, contents in coarse_matched_contents.items():
# Fine-grained matching between the question and coarse matched contents
fm_contents = get_matched_entries(question, contents)
if fm_contents is None:
continue
        for _match_str, (field_value, _s_match_str, match_score, _s_match_score, _match_size) in fm_contents:
if match_score < 0.9:
continue
if tc_name in fine_matched_contents.keys():
if len(fine_matched_contents[tc_name]) < 25:
fine_matched_contents[tc_name].append(field_value.strip())
else:
fine_matched_contents[tc_name] = [field_value.strip()]
return fine_matched_contents
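# Usage sketch (illustrative): assumes a Whoosh index over the database contents
# at a hypothetical path, built with the schema this module expects (a searchable
# `content` field and a stored JSON `raw` field):
#   from whoosh.index import open_dir
#   ix = open_dir("indexes/concert_singer")
#   with ix.searcher() as searcher:
#       matched = get_matched_contents("Which stadium is in Glasgow?", searcher)
#   # -> {"stadium.location": ["Glasgow"]}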
def get_db_schema_sequence(schema):
schema_sequence = "database schema :\n"
for table in schema["schema_items"]:
table_name, table_comment = table["table_name"], table["table_comment"]
if detect_special_char(table_name):
table_name = add_quotation_mark(table_name)
# if table_comment != "":
# table_name += " ( comment : " + table_comment + " )"
column_info_list = []
for column_name, column_type, column_comment, column_content, pk_indicator in \
zip(table["column_names"], table["column_types"], table["column_comments"], table["column_contents"], table["pk_indicators"]):
if detect_special_char(column_name):
column_name = add_quotation_mark(column_name)
additional_column_info = []
# column type
additional_column_info.append(column_type)
# pk indicator
if pk_indicator != 0:
additional_column_info.append("primary key")
# column comment
if column_comment != "":
additional_column_info.append("comment : " + column_comment)
            # representative column values
if len(column_content) != 0:
additional_column_info.append("values : " + " , ".join(column_content))
column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )")
schema_sequence += "table "+ table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n"
if len(schema["foreign_keys"]) != 0:
schema_sequence += "foreign keys :\n"
for foreign_key in schema["foreign_keys"]:
for i in range(len(foreign_key)):
if detect_special_char(foreign_key[i]):
foreign_key[i] = add_quotation_mark(foreign_key[i])
schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3])
else:
schema_sequence += "foreign keys : None\n"
return schema_sequence.strip()
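# The returned sequence is a flat textual rendering of the schema, e.g. (illustrative):
#   database schema :
#   table singer , columns = [ singer.singer_id ( integer | primary key ) , singer.name ( text | values : Joe Sharp , Timbaland ) ]
#   foreign keys :
#   concert.stadium_id = stadium.stadium_id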
def get_matched_content_sequence(matched_contents):
content_sequence = ""
if len(matched_contents) != 0:
content_sequence += "matched contents :\n"
for tc_name, contents in matched_contents.items():
table_name = tc_name.split(".")[0]
column_name = tc_name.split(".")[1]
if detect_special_char(table_name):
table_name = add_quotation_mark(table_name)
if detect_special_char(column_name):
column_name = add_quotation_mark(column_name)
content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
else:
content_sequence = "matched contents : None"
return content_sequence.strip()
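# Example output (illustrative):
#   matched contents :
#   stadium.location ( Glasgow )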
def get_db_schema(db_path, db_comments, db_id):
    # look up optional comment information for this database
    db_comment = db_comments.get(db_id)
cursor = get_cursor_from_path(db_path)
# obtain table names
results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
table_names = [result[0].lower() for result in results]
schema = dict()
schema["schema_items"] = []
foreign_keys = []
# for each table
for table_name in table_names:
# skip SQLite system table: sqlite_sequence
if table_name == "sqlite_sequence":
continue
# obtain column names in the current table
results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
column_names_in_one_table = [result[0].lower() for result in results]
column_types_in_one_table = [result[1].lower() for result in results]
pk_indicators_in_one_table = [result[2] for result in results]
column_contents = []
for column_name in column_names_in_one_table:
column_contents.append(get_column_contents(column_name, table_name, cursor))
# obtain foreign keys in the current table
results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
for result in results:
if None not in [result[3], result[2], result[4]]:
foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()])
# obtain comments for each schema item
if db_comment is not None:
if table_name in db_comment: # record comments for tables and columns
table_comment = db_comment[table_name]["table_comment"]
column_comments = [db_comment[table_name]["column_comments"][column_name] \
if column_name in db_comment[table_name]["column_comments"] else "" \
for column_name in column_names_in_one_table]
else: # current database has comment information, but the current table does not
table_comment = ""
column_comments = ["" for _ in column_names_in_one_table]
else: # current database has no comment information
table_comment = ""
column_comments = ["" for _ in column_names_in_one_table]
schema["schema_items"].append({
"table_name": table_name,
"table_comment": table_comment,
"column_names": column_names_in_one_table,
"column_types": column_types_in_one_table,
"column_comments": column_comments,
"column_contents": column_contents,
"pk_indicators": pk_indicators_in_one_table
})
schema["foreign_keys"] = foreign_keys
return schema
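# End-to-end sketch (illustrative; the paths, comment file, and db_id are hypothetical):
#   db_comments = json.load(open("data/db_comments.json"))
#   schema = get_db_schema("data/concert_singer.sqlite", db_comments, "concert_singer")
#   schema_sequence = get_db_schema_sequence(schema)
#   # schema_sequence can now be placed in a text-to-SQL prompt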