import os
import json
import sqlite3

import nltk
from func_timeout import func_set_timeout, FunctionTimedOut
from nltk import ngrams
from nltk.tokenize import word_tokenize
from whoosh.qparser import QueryParser

from utils.bridge_content_encoder import get_matched_entries

def add_a_record(question, db_id):
    conn = sqlite3.connect("data/history/history.sqlite")
    cursor = conn.cursor()
    cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id))
    conn.commit()
    conn.close()
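
# Example usage (hypothetical values; assumes data/history/history.sqlite
# already exists and contains a record(question, db_id) table):
#   add_a_record("How many singers do we have?", "concert_singer")
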
def obtain_n_grams(sequence, max_n):
    # make sure the punkt tokenizer models are available (download once, not on every call)
    for resource in ["punkt", "punkt_tab"]:
        try:
            nltk.data.find("tokenizers/" + resource)
        except LookupError:
            nltk.download(resource)

    tokens = word_tokenize(sequence)
    all_grams = []
    for n in range(1, max_n + 1):
        all_grams.extend(" ".join(gram) for gram in ngrams(tokens, n))
    return all_grams
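
# A quick sketch of the expected output:
#   obtain_n_grams("list all document drafts", 2)
#   -> ['list', 'all', 'document', 'drafts',
#       'list all', 'all document', 'document drafts']
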
# get the database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path):
    try:
        if not os.path.exists(sqlite_path):
            # sqlite3.connect() silently creates a new file for a missing path
            print("Opening a new connection %s" % sqlite_path)
        connection = sqlite3.connect(sqlite_path, check_same_thread=False)
    except Exception as e:
        print(sqlite_path)
        raise e
    connection.text_factory = lambda b: b.decode(errors="ignore")
    cursor = connection.cursor()
    return cursor
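
# Example (hypothetical path):
#   cursor = get_cursor_from_path("data/databases/concert_singer/concert_singer.sqlite")
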
# execute predicted sql with a time limitation
# (the timeout values below are assumptions; without a @func_set_timeout
# decorator, the FunctionTimedOut handler in check_sql_executability can
# never trigger)
@func_set_timeout(120)
def execute_sql(cursor, sql):
    cursor.execute(sql)
    return cursor.fetchall()

# execute predicted sql with a long time limitation (for building the content index)
@func_set_timeout(2000)
def execute_sql_long_time_limitation(cursor, sql):
    cursor.execute(sql)
    return cursor.fetchall()
def check_sql_executability(generated_sql, db):
    if generated_sql.strip() == "":
        return "Error: empty string"
    try:
        cursor = get_cursor_from_path(db)
        execute_sql(cursor, generated_sql)
        # None indicates that the SQL executed without errors
        execution_error = None
    except FunctionTimedOut as fto:
        print("SQL execution time out error: {}.".format(fto))
        execution_error = "SQL execution times out."
    except Exception as e:
        print("SQL execution runtime error: {}.".format(e))
        execution_error = str(e)
    return execution_error
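
# Example (hypothetical path; returns None when the query is executable):
#   error = check_sql_executability("SELECT count(*) FROM singer;",
#                                   "data/databases/concert_singer/concert_singer.sqlite")
#   if error is not None:
#       print("generated SQL is not executable:", error)
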
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def detect_special_char(name):
    for special_char in ['(', '-', ')', ' ', '/']:
        if special_char in name:
            return True
    return False

def add_quotation_mark(s):
    return "`" + s + "`"
def get_column_contents(column_name, table_name, cursor):
    select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(
        column_name, table_name, column_name)
    results = execute_sql_long_time_limitation(cursor, select_column_sql)
    column_contents = [str(result[0]).strip() for result in results]
    # remove empty and extremely long contents
    column_contents = [content for content in column_contents if 0 < len(content) <= 25]
    return column_contents
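
# Example (hypothetical schema): for a singer.name column this returns at most
# two distinct, non-empty values of up to 25 characters, e.g. ['John', 'Rose'].
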
def get_matched_contents(question, searcher):
    # coarse-grained matching between the input text and all contents in the database
    grams = obtain_n_grams(question, 4)
    hits = []
    # 'content' should match the field name used when building the Whoosh index
    query_parser = QueryParser("content", schema=searcher.schema)
    for query in grams:
        # convert the n-gram string into a Whoosh Query object and search with it
        parsed_query = query_parser.parse(query)
        hits.extend(searcher.search(parsed_query, limit=10))

    coarse_matched_contents = dict()
    for hit in hits:
        # the index is assumed to store the original JSON record in a stored "raw"
        # field; Whoosh hits expose stored fields via item access, not attributes
        matched_result = json.loads(hit["raw"])
        # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
        tc_name = ".".join(matched_result["id"].split("-**-")[:2])
        if tc_name in coarse_matched_contents:
            if matched_result["contents"] not in coarse_matched_contents[tc_name]:
                coarse_matched_contents[tc_name].append(matched_result["contents"])
        else:
            coarse_matched_contents[tc_name] = [matched_result["contents"]]

    fine_matched_contents = dict()
    for tc_name, contents in coarse_matched_contents.items():
        # fine-grained matching between the question and the coarse matched contents
        fm_contents = get_matched_entries(question, contents)
        if fm_contents is None:
            continue
        for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents:
            if match_score < 0.9:
                continue
            if tc_name in fine_matched_contents:
                # keep at most 25 matched values per column
                if len(fine_matched_contents[tc_name]) < 25:
                    fine_matched_contents[tc_name].append(field_value.strip())
            else:
                fine_matched_contents[tc_name] = [field_value.strip()]
    return fine_matched_contents
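
# Example sketch (assumes a Whoosh index built over the database contents, with
# each document storing the original JSON record in a "raw" field; the index
# path is hypothetical):
#   from whoosh.index import open_dir
#   ix = open_dir("data/db_contents_index/concert_singer")
#   with ix.searcher() as searcher:
#       matched_contents = get_matched_contents("concerts in 2014 or 2015", searcher)
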
def get_db_schema_sequence(schema):
    schema_sequence = "database schema :\n"
    for table in schema["schema_items"]:
        table_name, table_comment = table["table_name"], table["table_comment"]
        if detect_special_char(table_name):
            table_name = add_quotation_mark(table_name)
        # if table_comment != "":
        #     table_name += " ( comment : " + table_comment + " )"

        column_info_list = []
        for column_name, column_type, column_comment, column_content, pk_indicator in \
            zip(table["column_names"], table["column_types"], table["column_comments"],
                table["column_contents"], table["pk_indicators"]):
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            additional_column_info = []
            # column type
            additional_column_info.append(column_type)
            # primary key indicator
            if pk_indicator != 0:
                additional_column_info.append("primary key")
            # column comment
            if column_comment != "":
                additional_column_info.append("comment : " + column_comment)
            # representative column values
            if len(column_content) != 0:
                additional_column_info.append("values : " + " , ".join(column_content))
            column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )")
        schema_sequence += "table " + table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n"

    if len(schema["foreign_keys"]) != 0:
        schema_sequence += "foreign keys :\n"
        for foreign_key in schema["foreign_keys"]:
            for i in range(len(foreign_key)):
                if detect_special_char(foreign_key[i]):
                    foreign_key[i] = add_quotation_mark(foreign_key[i])
            schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3])
    else:
        schema_sequence += "foreign keys : None\n"
    return schema_sequence.strip()
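
# A sketch of the serialized format this produces (table and values are illustrative):
#   database schema :
#   table singer , columns = [ singer.singer_id ( integer | primary key ) ,
#     singer.name ( text | values : John , Rose ) ]
#   foreign keys :
#   singer_in_concert.singer_id = singer.singer_id
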
def get_matched_content_sequence(matched_contents):
    content_sequence = ""
    if len(matched_contents) != 0:
        content_sequence += "matched contents :\n"
        for tc_name, contents in matched_contents.items():
            # split only on the first dot so column names containing dots survive
            table_name, column_name = tc_name.split(".", 1)
            if detect_special_char(table_name):
                table_name = add_quotation_mark(table_name)
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
    else:
        content_sequence = "matched contents : None"
    return content_sequence.strip()
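
# Example (hypothetical matched contents):
#   get_matched_content_sequence({"singer.name": ["John", "Rose"]})
#   -> 'matched contents :\nsinger.name ( John , Rose )'
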
def get_db_schema(db_path, db_comments, db_id):
    if db_id in db_comments:
        db_comment = db_comments[db_id]
    else:
        db_comment = None

    cursor = get_cursor_from_path(db_path)

    # obtain table names
    results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0].lower() for result in results]

    schema = dict()
    schema["schema_items"] = []
    foreign_keys = []
    # for each table
    for table_name in table_names:
        # skip SQLite system table: sqlite_sequence
        if table_name == "sqlite_sequence":
            continue
        # obtain column names, types, and primary-key flags in the current table
        results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
        column_names_in_one_table = [result[0].lower() for result in results]
        column_types_in_one_table = [result[1].lower() for result in results]
        pk_indicators_in_one_table = [result[2] for result in results]

        column_contents = []
        for column_name in column_names_in_one_table:
            column_contents.append(get_column_contents(column_name, table_name, cursor))

        # obtain foreign keys in the current table
        results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
        for result in results:
            if None not in [result[3], result[2], result[4]]:
                foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()])

        # obtain comments for each schema item
        if db_comment is not None:
            if table_name in db_comment:  # record comments for tables and columns
                table_comment = db_comment[table_name]["table_comment"]
                column_comments = [db_comment[table_name]["column_comments"][column_name] \
                    if column_name in db_comment[table_name]["column_comments"] else "" \
                    for column_name in column_names_in_one_table]
            else:  # the current database has comment information, but this table does not
                table_comment = ""
                column_comments = ["" for _ in column_names_in_one_table]
        else:  # the current database has no comment information
            table_comment = ""
            column_comments = ["" for _ in column_names_in_one_table]

        schema["schema_items"].append({
            "table_name": table_name,
            "table_comment": table_comment,
            "column_names": column_names_in_one_table,
            "column_types": column_types_in_one_table,
            "column_comments": column_comments,
            "column_contents": column_contents,
            "pk_indicators": pk_indicators_in_one_table
        })

    schema["foreign_keys"] = foreign_keys
    return schema
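
if __name__ == "__main__":
    # Minimal end-to-end sketch (the database path and db_id below are
    # hypothetical; adapt them to your local data layout):
    demo_db_path = "data/databases/concert_singer/concert_singer.sqlite"
    demo_schema = get_db_schema(demo_db_path, db_comments={}, db_id="concert_singer")
    print(get_db_schema_sequence(demo_schema))
    print(get_matched_content_sequence({}))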