Spaces:
Paused
Paused
Commit
·
749f953
1
Parent(s):
4640e16
update utils
Browse files- utils/db_utils.py +47 -5
utils/db_utils.py
CHANGED
@@ -7,6 +7,8 @@ from utils.bridge_content_encoder import get_matched_entries
|
|
7 |
from nltk.tokenize import word_tokenize
|
8 |
from nltk import ngrams
|
9 |
|
|
|
|
|
10 |
def add_a_record(question, db_id):
|
11 |
conn = sqlite3.connect('data/history/history.sqlite')
|
12 |
cursor = conn.cursor()
|
@@ -97,15 +99,19 @@ def get_column_contents(column_name, table_name, cursor):
|
|
97 |
return column_contents
|
98 |
|
99 |
def get_matched_contents(question, searcher):
|
100 |
-
#
|
101 |
grams = obtain_n_grams(question, 4)
|
102 |
hits = []
|
|
|
|
|
103 |
for query in grams:
|
104 |
-
|
|
|
|
|
105 |
|
106 |
coarse_matched_contents = dict()
|
107 |
-
for
|
108 |
-
matched_result = json.loads(
|
109 |
# `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
|
110 |
tc_name = ".".join(matched_result["id"].split("-**-")[:2])
|
111 |
if tc_name in coarse_matched_contents.keys():
|
@@ -116,7 +122,7 @@ def get_matched_contents(question, searcher):
|
|
116 |
|
117 |
fine_matched_contents = dict()
|
118 |
for tc_name, contents in coarse_matched_contents.items():
|
119 |
-
#
|
120 |
fm_contents = get_matched_entries(question, contents)
|
121 |
|
122 |
if fm_contents is None:
|
@@ -132,6 +138,42 @@ def get_matched_contents(question, searcher):
|
|
132 |
|
133 |
return fine_matched_contents
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
def get_db_schema_sequence(schema):
|
136 |
schema_sequence = "database schema :\n"
|
137 |
for table in schema["schema_items"]:
|
|
|
7 |
from nltk.tokenize import word_tokenize
|
8 |
from nltk import ngrams
|
9 |
|
10 |
+
from whoosh.qparser import QueryParser
|
11 |
+
|
12 |
def add_a_record(question, db_id):
|
13 |
conn = sqlite3.connect('data/history/history.sqlite')
|
14 |
cursor = conn.cursor()
|
|
|
99 |
return column_contents
|
100 |
|
101 |
def get_matched_contents(question, searcher):
|
102 |
+
# Coarse-grained matching between the input text and all contents in database
|
103 |
grams = obtain_n_grams(question, 4)
|
104 |
hits = []
|
105 |
+
|
106 |
+
# Parse each n-gram query into a valid Whoosh query object
|
107 |
for query in grams:
|
108 |
+
query_parser = QueryParser("content", schema=searcher.schema) # 'content' should match the field you are searching
|
109 |
+
parsed_query = query_parser.parse(query) # Convert the query string into a Whoosh Query object
|
110 |
+
hits.extend(searcher.search(parsed_query, limit=10)) # Perform the search with the parsed query
|
111 |
|
112 |
coarse_matched_contents = dict()
|
113 |
+
for hit in hits:
|
114 |
+
matched_result = json.loads(hit.raw)
|
115 |
# `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
|
116 |
tc_name = ".".join(matched_result["id"].split("-**-")[:2])
|
117 |
if tc_name in coarse_matched_contents.keys():
|
|
|
122 |
|
123 |
fine_matched_contents = dict()
|
124 |
for tc_name, contents in coarse_matched_contents.items():
|
125 |
+
# Fine-grained matching between the question and coarse matched contents
|
126 |
fm_contents = get_matched_entries(question, contents)
|
127 |
|
128 |
if fm_contents is None:
|
|
|
138 |
|
139 |
return fine_matched_contents
|
140 |
|
141 |
+
# def get_matched_contents(question, searcher):
|
142 |
+
# # coarse-grained matching between the input text and all contents in database
|
143 |
+
# grams = obtain_n_grams(question, 4)
|
144 |
+
# hits = []
|
145 |
+
# for query in grams:
|
146 |
+
# hits.extend(searcher.search(query, limit = 10))
|
147 |
+
|
148 |
+
# coarse_matched_contents = dict()
|
149 |
+
# for i in range(len(hits)):
|
150 |
+
# matched_result = json.loads(hits[i].raw)
|
151 |
+
# # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
|
152 |
+
# tc_name = ".".join(matched_result["id"].split("-**-")[:2])
|
153 |
+
# if tc_name in coarse_matched_contents.keys():
|
154 |
+
# if matched_result["contents"] not in coarse_matched_contents[tc_name]:
|
155 |
+
# coarse_matched_contents[tc_name].append(matched_result["contents"])
|
156 |
+
# else:
|
157 |
+
# coarse_matched_contents[tc_name] = [matched_result["contents"]]
|
158 |
+
|
159 |
+
# fine_matched_contents = dict()
|
160 |
+
# for tc_name, contents in coarse_matched_contents.items():
|
161 |
+
# # fine-grained matching between the question and coarse matched contents
|
162 |
+
# fm_contents = get_matched_entries(question, contents)
|
163 |
+
|
164 |
+
# if fm_contents is None:
|
165 |
+
# continue
|
166 |
+
# for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents:
|
167 |
+
# if match_score < 0.9:
|
168 |
+
# continue
|
169 |
+
# if tc_name in fine_matched_contents.keys():
|
170 |
+
# if len(fine_matched_contents[tc_name]) < 25:
|
171 |
+
# fine_matched_contents[tc_name].append(field_value.strip())
|
172 |
+
# else:
|
173 |
+
# fine_matched_contents[tc_name] = [field_value.strip()]
|
174 |
+
|
175 |
+
# return fine_matched_contents
|
176 |
+
|
177 |
def get_db_schema_sequence(schema):
|
178 |
schema_sequence = "database schema :\n"
|
179 |
for table in schema["schema_items"]:
|