LangSQL / utils /bridge_content_encoder.py
Roxanne-WANG's picture
update webpage
b759b87
"""
Copyright (c) 2020, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Encode DB content.
"""
import difflib
from typing import List, Optional, Tuple
from rapidfuzz import fuzz
import sqlite3
import functools
# Words that must never be reported as a database-content match: is_stopword()
# tests membership in this set, and get_matched_entries() skips any candidate
# whose matched text, source span or field value is one of these words.
# NOTE(review): the entries look like NLTK's English stop-word list - confirm.
# The "fmt: off"/"fmt: on" markers ask the auto-formatter to leave the literal
# as it is.
# fmt: off
_stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again',
'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but',
"shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is',
"wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being',
'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be',
'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor',
'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because',
'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have',
'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours',
'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his',
"wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what',
'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan',
'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here',
'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself',
'then', 'did', 'just', "aren't"}
# fmt: on
# Words that are only accepted as a match when the match score is a perfect
# 1.0; get_matched_entries() drops any lower-scoring candidate that involves one.
_commonwords = {"no", "yes", "many"}
def is_number(s: str) -> bool:
    """Return True if *s* parses as a float once commas are removed.

    Commas are dropped before parsing so that values written with thousands
    separators, such as ``"1,234.5"``, count as numbers.

    Args:
        s: Text to test.

    Returns:
        True when ``float`` accepts the comma-free text, False otherwise
        (including when *s* is not a string at all).
    """
    try:
        float(s.replace(",", ""))
    except (AttributeError, TypeError, ValueError):
        # Only conversion failures mean "not a number"; anything else (for
        # example KeyboardInterrupt) must propagate instead of being swallowed.
        return False
    return True
def is_stopword(s: str) -> bool:
    """Return True when *s*, ignoring surrounding whitespace, is in ``_stopwords``."""
    token = s.strip()
    return token in _stopwords
def is_commonword(s: str) -> bool:
    """Return True when *s*, ignoring surrounding whitespace, is in ``_commonwords``."""
    token = s.strip()
    return token in _commonwords
def is_common_db_term(s: str) -> bool:
    """Return True when *s*, ignoring surrounding whitespace, is exactly ``"id"``.

    The comparison is case-sensitive.
    """
    term = s.strip()
    return term == "id"
class Match:
    """A span inside a sequence, described by its first index and its length."""

    def __init__(self, start: int, size: int) -> None:
        # start: index of the first element of the span.
        # size: number of elements in the span.
        self.start, self.size = start, size
def is_span_separator(c: str) -> bool:
    """Return True when *c* is a quote, parenthesis, backtick, punctuation mark or space.

    The check is a substring test against the separator string, so *c* is
    expected to be a single character.
    """
    separators = "'\"()`,.?! "
    return c in separators
def split(s: str) -> List[str]:
    """Return the characters of *s*, stripped of surrounding whitespace, in lower case.

    Lower-casing is applied one character at a time, so the list has exactly
    one entry per character of the stripped string.
    """
    trimmed = s.strip()
    return list(map(str.lower, trimmed))
def prefix_match(s1: str, s2: str) -> bool:
    """Tell whether *s1* and *s2* begin with the same character, ignoring separators.

    Leading span separators (as defined by ``is_span_separator``) are skipped
    in both strings before the comparison.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        True when the first non-separator characters are equal, or when
        neither string contains a non-separator character at all (both are
        empty or made only of separators). False otherwise.
    """
    # Previously the scan index stopped at ``len - 1`` when no non-separator
    # character existed, so a trailing separator was compared as if it were
    # content. Looking the characters up directly avoids that.
    first1 = next((ch for ch in s1 if not is_span_separator(ch)), None)
    first2 = next((ch for ch in s2 if not is_span_separator(ch)), None)
    if first1 is None or first2 is None:
        # At least one side has no content: they match only if both are empty.
        return first1 is None and first2 is None
    return first1 == first2
def get_effective_match_source(s: str, start: int, end: int) -> Optional[Match]:
    """Snap the raw match span ``[start, end)`` of *s* to nearby span boundaries.

    A boundary is a span separator (see ``is_span_separator``) or an end of
    the sequence. The span is rejected when no boundary lies close enough to
    either of its edges.

    Args:
        s: Sequence that is indexed element by element (the caller in this
            file passes the character list produced by ``split``).
        start: Index of the first element of the raw match.
        end: Index one past the last element of the raw match.

    Returns:
        A ``Match`` covering the span between the two boundaries with the
        separators themselves trimmed off, or ``None`` when a boundary is
        missing on either side. The reported size is computed as
        ``_end - _start + 1`` and is not clamped.
    """
    # Left edge: accept a separator at ``start`` or ``start - 1``, or running
    # off the beginning of the sequence.
    _start = -1
    for i in range(start, start - 2, -1):
        if i < 0:
            _start = i + 1
            break
        if is_span_separator(s[i]):
            _start = i
            break
    if _start < 0:
        # No boundary on the left: the match starts inside a longer token.
        return None
    # Right edge: accept a separator at ``end - 1`` .. ``end + 2``, or running
    # off the end of the sequence, so up to two trailing elements are tolerated.
    _end = -1
    for i in range(end - 1, end + 3):
        if i >= len(s):
            _end = i - 1
            break
        if is_span_separator(s[i]):
            _end = i
            break
    if _end < 0:
        # No boundary on the right within reach.
        return None
    # Drop the separators themselves from both ends of the span.
    while _start < len(s) and is_span_separator(s[_start]):
        _start += 1
    while _end >= 0 and is_span_separator(s[_end]):
        _end -= 1
    return Match(_start, _end - _start + 1)
def get_matched_entries(
    s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85
) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]:
    """Find the field values that fuzzily match a span of the source text *s*.

    For every string in *field_values* the longest common character run with
    *s* is located (case-insensitively), snapped to separator boundaries in
    *s*, filtered (database terms, stop words, common words) and scored.

    Args:
        s: Source text. A pre-tokenised sequence of strings is also accepted
            and is used as the token list directly.
        field_values: Candidate values; entries that are not ``str`` are
            ignored.
        m_theta: Minimum match score for a candidate to be kept.
        s_theta: Minimum source-side score for a candidate to be kept (the
            source-side score is currently the same number as the match
            score).

    Returns:
        ``None`` when *field_values* is empty or nothing qualifies. Otherwise
        a list of ``(match_str, (field_value, source_match_str, match_score,
        s_match_score, match_size))`` pairs sorted from best to worst: by
        match score, then source score, then match size.
    """
    if not field_values:
        return None

    if isinstance(s, str):
        # split() strips the text before tokenising, so every index computed
        # below refers to the stripped text; slice that, not the raw input,
        # otherwise leading whitespace shifts the extracted span.
        source_text = s.strip()
        n_grams = split(s)
    else:
        source_text = s
        n_grams = s

    matched = dict()
    for field_value in field_values:
        if not isinstance(field_value, str):
            continue
        # Same alignment concern as for the source text.
        fv_text = field_value.strip()
        fv_tokens = split(field_value)
        sm = difflib.SequenceMatcher(None, n_grams, fv_tokens)
        match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens))
        if match.size <= 0:
            continue
        source_match = get_effective_match_source(
            n_grams, match.a, match.a + match.size
        )
        if source_match is None:
            continue

        match_str = fv_text[match.b : match.b + match.size]
        source_span = source_text[
            source_match.start : source_match.start + source_match.size
        ]
        # A token sequence yields a sequence slice; glue it back into text so
        # the string operations below work for that input form too.
        if isinstance(source_span, str):
            source_match_str = source_span
        else:
            source_match_str = "".join(source_span)

        c_match_str = match_str.lower().strip()
        c_source_match_str = source_match_str.lower().strip()
        c_field_value = field_value.lower().strip()
        if not c_match_str or is_common_db_term(c_match_str):
            continue
        if (
            is_stopword(c_match_str)
            or is_stopword(c_source_match_str)
            or is_stopword(c_field_value)
        ):
            continue

        if c_source_match_str.endswith(c_match_str + "'s"):
            # A possessive form of the value counts as an exact hit.
            match_score = 1.0
        elif prefix_match(c_field_value, c_source_match_str):
            # The ratio is divided by 100 to put it on the 0..1 threshold scale.
            match_score = fuzz.ratio(c_field_value, c_source_match_str) / 100
        else:
            match_score = 0

        # Common words are only trusted on a perfect score.
        if (
            is_commonword(c_match_str)
            or is_commonword(c_source_match_str)
            or is_commonword(c_field_value)
        ) and match_score < 1:
            continue

        s_match_score = match_score
        if match_score < m_theta or s_match_score < s_theta:
            continue
        # All-uppercase values are only accepted on a perfect score.
        if field_value.isupper() and match_score * s_match_score < 1:
            continue
        matched[match_str] = (
            field_value,
            source_match_str,
            match_score,
            s_match_score,
            match.size,
        )

    if not matched:
        return None
    return sorted(
        matched.items(),
        key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]),
        reverse=True,
    )
@functools.lru_cache(maxsize=1000, typed=False)
def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list:
    """Return the distinct values stored in one column of a SQLite table.

    Results are memoised per ``(table_name, column_name, db_path)``; a cache
    hit returns the same list object again, so callers should not mutate it.
    An empty result caused by a failure is cached as well.

    Args:
        table_name: Table to read.
        column_name: Column to read.
        db_path: Path of the SQLite database file.

    Returns:
        The distinct values in no particular order. Text is returned as
        ``str`` (decoded as UTF-8, falling back to Latin-1); other values are
        returned as SQLite delivers them. Any failure (unopenable database,
        unknown table or column, ...) yields an empty list.
    """
    # Identifiers cannot be bound as parameters; double embedded backticks so
    # that a name containing one cannot break out of the quoting.
    fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(
        str(column_name).replace("`", "``"), str(table_name).replace("`", "``")
    )
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Fetch text as raw bytes so that invalid UTF-8 in the data can be
        # handled below instead of aborting the whole query.
        conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute(fetch_sql)
        values = set()
        for row in cursor.fetchall():
            value = row[0]
            if isinstance(value, bytes):
                try:
                    value = value.decode("utf-8")
                except UnicodeDecodeError:
                    # Latin-1 maps every byte, so this cannot fail.
                    value = value.decode("latin-1")
            # A value that is already str is kept as text; turning it into
            # bytes would make string-only consumers discard it.
            values.add(value)
        return list(values)
    except Exception:
        # Deliberate best effort: an unreadable column just has no picklist.
        return []
    finally:
        # The connection does not exist if connect() itself failed.
        if conn is not None:
            conn.close()
def get_database_matches(
    question: str,
    table_name: str,
    column_name: str,
    db_path: str,
    top_k_matches: int = 2,
    match_threshold: float = 0.85,
) -> List[str]:
    """Return values of one database column that are mentioned in *question*.

    Args:
        question: Natural-language text to search for column values.
        table_name: Table holding the column.
        column_name: Column whose distinct values are the candidates.
        db_path: Path of the SQLite database file.
        top_k_matches: Stop after this many values have been collected.
        match_threshold: Score threshold handed to ``get_matched_entries``
            for both of its thresholds.

    Returns:
        The matching values, best match first; empty when nothing matches.
    """
    raw_values = get_column_picklist(
        table_name=table_name, column_name=column_name, db_path=db_path
    )
    # Keep the textual values only and sort them: the picklist has no defined
    # order, and a fixed order keeps the output reproducible.
    candidates = sorted(
        value.strip() for value in raw_values if isinstance(value, str)
    )
    if not candidates:
        return []

    entries = get_matched_entries(
        s=question,
        field_values=candidates,
        m_theta=match_threshold,
        s_theta=match_threshold,
    )
    if not entries:
        return []

    # Columns with "name" in their name only accept perfect-score matches.
    exact_only = "name" in column_name
    # "sqlite_sequence" is a Spider database artifact; never report its values.
    reportable = table_name != "sqlite_sequence"

    matches = []
    inserted = 0
    for _, (field_value, _, match_score, s_match_score, _) in entries:
        if exact_only and match_score * s_match_score < 1:
            continue
        if reportable:
            matches.append(field_value.strip())
            inserted += 1
            if inserted >= top_k_matches:
                break
    return matches