""" | |
Copyright (c) 2020, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
Encode DB content. | |
""" | |
import difflib
import functools
import sqlite3
from typing import List, Optional, Tuple

from rapidfuzz import fuzz

# fmt: off
_stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again',
              'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but',
              "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is',
              "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being',
              'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be',
              'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor',
              'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because',
              'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have',
              'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours',
              'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his',
              "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what',
              'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan',
              'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here',
              'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself',
              'then', 'did', 'just', "aren't"}
# fmt: on

_commonwords = {"no", "yes", "many"}


def is_number(s: str) -> bool:
    try:
        float(s.replace(",", ""))
        return True
    except ValueError:
        return False


def is_stopword(s: str) -> bool:
    return s.strip() in _stopwords


def is_commonword(s: str) -> bool:
    return s.strip() in _commonwords


def is_common_db_term(s: str) -> bool:
    return s.strip() in ["id"]


class Match(object):
    """A contiguous character span: start offset and length."""

    def __init__(self, start: int, size: int) -> None:
        self.start = start
        self.size = size


def is_span_separator(c: str) -> bool:
    return c in "'\"()`,.?! "


def split(s: str) -> List[str]:
    # Matching is done at the character level: split into lowercase characters.
    return [c.lower() for c in s.strip()]


def prefix_match(s1: str, s2: str) -> bool:
    """Return True if the first non-separator characters of s1 and s2 agree."""
    i, j = 0, 0
    for i in range(len(s1)):
        if not is_span_separator(s1[i]):
            break
    for j in range(len(s2)):
        if not is_span_separator(s2[j]):
            break
    if i < len(s1) and j < len(s2):
        return s1[i] == s2[j]
    elif i >= len(s1) and j >= len(s2):
        return True
    else:
        return False
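
# Illustrative example: prefix_match('"france"', 'france') is True, because the
# leading quote in the first string is a span separator, so both strings
# effectively start with "f".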


def get_effective_match_source(s: str, start: int, end: int) -> Optional[Match]:
    # Expand the raw [start, end) match slightly, snap its boundaries to span
    # separators, and trim the separators off; return None if no separator is
    # found within the small search window.
    _start = -1
    for i in range(start, start - 2, -1):
        if i < 0:
            _start = i + 1
            break
        if is_span_separator(s[i]):
            _start = i
            break
    if _start < 0:
        return None
    _end = -1
    for i in range(end - 1, end + 3):
        if i >= len(s):
            _end = i - 1
            break
        if is_span_separator(s[i]):
            _end = i
            break
    if _end < 0:
        return None
    while _start < len(s) and is_span_separator(s[_start]):
        _start += 1
    while _end >= 0 and is_span_separator(s[_end]):
        _end -= 1
    return Match(_start, _end - _start + 1)


def get_matched_entries(
    s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85
) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]:
    """Find field values that fuzzily match spans of s (a string or a pre-split
    character list).

    Returns a list of (match_str, (field_value, source_span, match_score,
    span_score, match_size)) entries sorted by score, or None if nothing matches.
    """
    if not field_values:
        return None

    if isinstance(s, str):
        n_grams = split(s)
    else:
        n_grams = s

    matched = dict()
    for field_value in field_values:
        if not isinstance(field_value, str):
            continue
        fv_tokens = split(field_value)
        sm = difflib.SequenceMatcher(None, n_grams, fv_tokens)
        match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens))
        if match.size > 0:
            source_match = get_effective_match_source(
                n_grams, match.a, match.a + match.size
            )
            if source_match:  # and source_match.size > 1
                match_str = field_value[match.b : match.b + match.size]
                source_match_str = s[
                    source_match.start : source_match.start + source_match.size
                ]
                c_match_str = match_str.lower().strip()
                c_source_match_str = source_match_str.lower().strip()
                c_field_value = field_value.lower().strip()
                if c_match_str and not is_common_db_term(c_match_str):  # and not is_number(c_match_str)
                    if (
                        is_stopword(c_match_str)
                        or is_stopword(c_source_match_str)
                        or is_stopword(c_field_value)
                    ):
                        continue
                    if c_source_match_str.endswith(c_match_str + "'s"):
                        match_score = 1.0
                    else:
                        if prefix_match(c_field_value, c_source_match_str):
                            match_score = (
                                fuzz.ratio(c_field_value, c_source_match_str) / 100
                            )
                        else:
                            match_score = 0
                    if (
                        is_commonword(c_match_str)
                        or is_commonword(c_source_match_str)
                        or is_commonword(c_field_value)
                    ) and match_score < 1:
                        continue
                    s_match_score = match_score
                    if match_score >= m_theta and s_match_score >= s_theta:
                        if field_value.isupper() and match_score * s_match_score < 1:
                            continue
                        matched[match_str] = (
                            field_value,
                            source_match_str,
                            match_score,
                            s_match_score,
                            match.size,
                        )

    if not matched:
        return None
    else:
        return sorted(
            matched.items(),
            key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]),
            reverse=True,
        )
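
# As a hypothetical example (values chosen for illustration),
# get_matched_entries("Which singer is named Tribal King?", ["Tribal King"])
# should return [("Tribal King", ("Tribal King", "Tribal King", 1.0, 1.0, 11))]:
# the 11-character span "Tribal King" in the question matches the field value
# exactly, so both scores are 1.0.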


@functools.lru_cache(maxsize=1000, typed=False)
def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list:
    fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name)
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Read raw bytes so that non-UTF-8 content does not crash the cursor.
        conn.text_factory = bytes
        c = conn.cursor()
        c.execute(fetch_sql)
        picklist = set()
        for x in c.fetchall():
            if isinstance(x[0], str):
                picklist.add(x[0])
            elif isinstance(x[0], bytes):
                try:
                    picklist.add(x[0].decode("utf-8"))
                except UnicodeDecodeError:
                    picklist.add(x[0].decode("latin-1"))
            else:
                picklist.add(x[0])
        picklist = list(picklist)
    except Exception:
        picklist = []
    finally:
        if conn is not None:
            conn.close()
    return picklist


def get_database_matches(
    question: str,
    table_name: str,
    column_name: str,
    db_path: str,
    top_k_matches: int = 2,
    match_threshold: float = 0.85,
) -> List[str]:
    picklist = get_column_picklist(
        table_name=table_name, column_name=column_name, db_path=db_path
    )
    # Only keep values of type str.
    picklist = [ele.strip() for ele in picklist if isinstance(ele, str)]
    # The picklist is unordered; sort it so results are reproducible across runs.
    picklist = sorted(picklist)
    matches = []
    if picklist:
        matched_entries = get_matched_entries(
            s=question,
            field_values=picklist,
            m_theta=match_threshold,
            s_theta=match_threshold,
        )
        if matched_entries:
            num_values_inserted = 0
            for _match_str, (
                field_value,
                _s_match_str,
                match_score,
                s_match_score,
                _match_size,
            ) in matched_entries:
                # For name-like columns, only keep exact (score 1.0) matches.
                if "name" in column_name and match_score * s_match_score < 1:
                    continue
                if table_name != "sqlite_sequence":  # Spider database artifact
                    matches.append(field_value.strip())
                    num_values_inserted += 1
                    if num_values_inserted >= top_k_matches:
                        break
    return matches
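

if __name__ == "__main__":
    # Minimal usage sketch. The database path and schema below are hypothetical
    # (they follow the Spider dataset layout); point them at a real SQLite file
    # before running.
    example_matches = get_database_matches(
        question="How many singers are from France?",
        table_name="singer",
        column_name="Country",
        db_path="database/concert_singer/concert_singer.sqlite",
        top_k_matches=2,
        match_threshold=0.85,
    )
    print(example_matches)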