Commit b759b87
Parent(s): b423caf
update webpage
Files changed:
- Dockerfile +45 -0
- README.md +2 -2
- app.py +64 -0
- build_contents_index.py +75 -0
- data/.DS_Store +0 -0
- data/history/.DS_Store +0 -0
- data/tables.json +0 -0
- databases/.DS_Store +0 -0
- databases/singer/schema.sql +39 -0
- databases/singer/singer.sqlite +0 -0
- requirements.txt +18 -0
- schema_item_filter.py +343 -0
- text2sql.py +182 -0
- utils/bridge_content_encoder.py +261 -0
- utils/classifier_model.py +186 -0
- utils/db_utils.py +255 -0
- utils/translate_utils.py +23 -0
Dockerfile
ADDED
@@ -0,0 +1,45 @@
# Use NVIDIA CUDA base image for GPU support
FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04

# Set the working directory
WORKDIR /app

# Update and install system dependencies (including Java and other tools)
RUN apt-get update && \
    apt-get install -y openjdk-11-jdk git rsync make build-essential libssl-dev zlib1g-dev \
    libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
    libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
    libffi-dev liblzma-dev git-lfs ffmpeg libsm6 libxext6 cmake \
    libgl1-mesa-glx && rm -rf /var/lib/apt/lists/* && git lfs install

# Set JAVA_HOME environment variable
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
ENV PATH="${JAVA_HOME}/bin:${PATH}"

# Install Python version manager (pyenv) and Python 3.10
RUN curl https://pyenv.run | bash
RUN pyenv install 3.10 && pyenv global 3.10 && pyenv rehash

# Install pip and other dependencies
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir datasets transformers langdetect streamlit

# Install PyTorch and CUDA dependencies
RUN pip install --no-cache-dir torch==1.13.1+cu117 torchvision==0.14.1 torchaudio==0.13.1

# Copy requirements.txt and install dependencies
COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code to the container
COPY . /app/

# Expose the port the app will run on
EXPOSE 8501

# Set the environment variable for streamlit
ENV STREAMLIT_SERVER_PORT=8501
ENV STREAMLIT_SERVER_HEADLESS=true

# Command to run the application
CMD ["streamlit", "run", "app.py"]
README.md
CHANGED
@@ -1,7 +1,7 @@
 ---
 title: LangSQL
-emoji:
-colorFrom:
+emoji: 🦕
+colorFrom: blue
 colorTo: gray
 sdk: streamlit
 sdk_version: 1.44.1
app.py
ADDED
@@ -0,0 +1,64 @@
import streamlit as st
from text2sql import ChatBot
from langdetect import detect
from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record
from langdetect.lang_detect_exception import LangDetectException

# Initialize chatbot and other variables
text2sql_bot = ChatBot()
baidu_api_token = None

# Define database schemas for demonstration
db_schemas = {
    "singer": """
CREATE TABLE "singer" (
    "Singer_ID" int,
    "Name" text,
    "Birth_Year" real,
    "Net_Worth_Millions" real,
    "Citizenship" text,
    PRIMARY KEY ("Singer_ID")
);

CREATE TABLE "song" (
    "Song_ID" int,
    "Title" text,
    "Singer_ID" int,
    "Sales" real,
    "Highest_Position" real,
    PRIMARY KEY ("Song_ID"),
    FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);
""",
    # Add other schemas as needed
}

# Streamlit UI
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")

# Sidebar for selecting a database
selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))

# Display the selected schema
st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# User input section
question = st.text_input("Enter your question:")
db_id = selected_db  # Use selected database for DB ID

if question:
    add_a_record(question, db_id)

    try:
        if baidu_api_token is not None and detect(question) != "en":
            print("Before translation:", question)
            question = translate_zh_to_en(question, baidu_api_token)
            print("After translation:", question)
    except LangDetectException as e:
        print("Language detection error:", str(e))

    predicted_sql = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Predicted SQL query:** {predicted_sql}")
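For reference, a minimal sketch (not part of this commit) of calling the bot directly, without the Streamlit UI. It assumes the model weights, the `sic_ckpts` classifier checkpoint, and the bundled `databases/singer/singer.sqlite` are available locally; the question string is only an example.

# Hypothetical usage sketch: query the text-to-SQL bot from a plain Python shell.
from text2sql import ChatBot

bot = ChatBot()  # loads seeklhy/codes-7b-merged and the schema-item classifier
predicted_sql = bot.get_response("List the names of all singers from France.", "singer")
print(predicted_sql)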
build_contents_index.py
ADDED
@@ -0,0 +1,75 @@
from utils.db_utils import get_cursor_from_path, execute_sql_long_time_limitation
import json
import os, shutil

def remove_contents_of_a_folder(index_path):
    # if index_path does not exist, then create it
    os.makedirs(index_path, exist_ok = True)
    # remove files in index_path
    for filename in os.listdir(index_path):
        file_path = os.path.join(index_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))

def build_content_index(db_path, index_path):
    '''
    Create a BM25 index for all contents in a database
    '''
    cursor = get_cursor_from_path(db_path)
    results = execute_sql_long_time_limitation(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0] for result in results]

    all_column_contents = []
    for table_name in table_names:
        # skip SQLite system table: sqlite_sequence
        if table_name == "sqlite_sequence":
            continue
        results = execute_sql_long_time_limitation(cursor, "SELECT name FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
        column_names_in_one_table = [result[0] for result in results]
        for column_name in column_names_in_one_table:
            try:
                print("SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL;".format(column_name, table_name, column_name))
                results = execute_sql_long_time_limitation(cursor, "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL;".format(column_name, table_name, column_name))
                column_contents = [str(result[0]).strip() for result in results]

                for c_id, column_content in enumerate(column_contents):
                    # remove empty and extremely-long contents
                    if len(column_content) != 0 and len(column_content) <= 25:
                        all_column_contents.append(
                            {
                                "id": "{}-**-{}-**-{}".format(table_name, column_name, c_id).lower(),
                                "contents": column_content
                            }
                        )
            except Exception as e:
                print(str(e))

    with open("./data/temp_db_index/contents.json", "w") as f:
        f.write(json.dumps(all_column_contents, indent = 2, ensure_ascii = True))

    # Building a BM25 Index (Direct Java Implementation), see https://github.com/castorini/pyserini/blob/master/docs/usage-index.md
    cmd = "python -m pyserini.index.lucene --collection JsonCollection --input ./data/temp_db_index --index {} --generator DefaultLuceneDocumentGenerator --threads 16 --storePositions --storeDocvectors --storeRaw".format(index_path)

    d = os.system(cmd)
    print(d)
    os.remove("./data/temp_db_index/contents.json")

if __name__ == "__main__":
    os.makedirs('./data/temp_db_index', exist_ok = True)

    print("build content index for databases...")
    remove_contents_of_a_folder("db_contents_index")
    # build content index for Bank_Financials's training set databases
    for db_id in os.listdir("databases"):
        print(db_id)
        build_content_index(
            os.path.join("databases", db_id, db_id + ".sqlite"),
            os.path.join("db_contents_index", db_id)
        )

    os.rmdir('./data/temp_db_index')
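As a rough illustration (not part of the commit), the resulting Lucene index could be queried with pyserini's BM25 searcher; the query string and `k` are arbitrary, and this assumes the `singer` index was built by the script above.

# Hypothetical sketch: BM25 lookup against the index built by build_contents_index.py.
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher("db_contents_index/singer")
hits = searcher.search("France", k=5)
for hit in hits:
    # doc ids follow the "table-**-column-**-rowid" pattern used above
    print(hit.docid, hit.score)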
data/.DS_Store
ADDED
Binary file (6.15 kB)
data/history/.DS_Store
ADDED
Binary file (6.15 kB)
data/tables.json
ADDED
The diff for this file is too large to render.
databases/.DS_Store
ADDED
Binary file (6.15 kB)
databases/singer/schema.sql
ADDED
@@ -0,0 +1,39 @@
PRAGMA foreign_keys = ON;

CREATE TABLE "singer" (
    "Singer_ID" int,
    "Name" text,
    "Birth_Year" real,
    "Net_Worth_Millions" real,
    "Citizenship" text,
    PRIMARY KEY ("Singer_ID")
);

CREATE TABLE "song" (
    "Song_ID" int,
    "Title" text,
    "Singer_ID" int,
    "Sales" real,
    "Highest_Position" real,
    PRIMARY KEY ("Song_ID"),
    FOREIGN KEY ("Singer_ID") REFERENCES `singer`("Singer_ID")
);

INSERT INTO "singer" VALUES (1,"Liliane Bettencourt","1944","30.0","France");
INSERT INTO "singer" VALUES (2,"Christy Walton","1948","28.8","United States");
INSERT INTO "singer" VALUES (3,"Alice Walton","1949","26.3","United States");
INSERT INTO "singer" VALUES (4,"Iris Fontbona","1942","17.4","Chile");
INSERT INTO "singer" VALUES (5,"Jacqueline Mars","1940","17.8","United States");
INSERT INTO "singer" VALUES (6,"Gina Rinehart","1953","17","Australia");
INSERT INTO "singer" VALUES (7,"Susanne Klatten","1962","14.3","Germany");
INSERT INTO "singer" VALUES (8,"Abigail Johnson","1961","12.7","United States");

INSERT INTO "song" VALUES ("1","Do They Know It's Christmas",1,"1094000","1");
INSERT INTO "song" VALUES ("2","F**k It (I Don't Want You Back)",1,"552407","1");
INSERT INTO "song" VALUES ("3","Cha Cha Slide",2,"351421","1");
INSERT INTO "song" VALUES ("4","Call on Me",4,"335000","1");
INSERT INTO "song" VALUES ("5","Yeah",2,"300000","1");
INSERT INTO "song" VALUES ("6","All This Time",6,"292000","1");
INSERT INTO "song" VALUES ("7","Left Outside Alone",5,"275000","3");
INSERT INTO "song" VALUES ("8","Mysterious Girl",7,"261000","1");
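A small sanity-check sketch (not part of the commit): loading this dump into an in-memory SQLite database and joining the two tables; the file path assumes the repository root as the working directory.

# Hypothetical check: count songs per singer from the dump above.
import sqlite3

conn = sqlite3.connect(":memory:")
with open("databases/singer/schema.sql") as f:
    conn.executescript(f.read())

rows = conn.execute(
    'SELECT s."Name", COUNT(*) AS n_songs '
    'FROM "singer" s JOIN "song" so ON s."Singer_ID" = so."Singer_ID" '
    'GROUP BY s."Singer_ID" ORDER BY n_songs DESC'
).fetchall()
print(rows)  # e.g. [('Liliane Bettencourt', 2), ...]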
databases/singer/singer.sqlite
ADDED
Binary file (20.5 kB)
requirements.txt
ADDED
@@ -0,0 +1,18 @@
streamlit
text2sql
langdetect
faiss-cpu
func_timeout
nltk
numpy
pandas
rapidfuzz
tqdm
transformers
chardet
sqlparse
accelerate
bitsandbytes
sql_metadata
datasets
whoosh
schema_item_filter.py
ADDED
@@ -0,0 +1,343 @@
import numpy as np
import random
import torch

from tqdm import tqdm
from transformers import AutoTokenizer
from utils.classifier_model import SchemaItemClassifier
from transformers.trainer_utils import set_seed

def prepare_inputs_and_labels(sample, tokenizer):
    table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
    column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
    column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]]

    # `column_name_word_indices` and `table_name_word_indices` record the word indices of each column and table in `input_words`, whose element is an integer
    column_name_word_indices, table_name_word_indices = [], []

    input_words = [sample["text"]]
    for table_id, table_name in enumerate(table_names):
        input_words.append("|")
        input_words.append(table_name)
        table_name_word_indices.append(len(input_words) - 1)
        input_words.append(":")

        for column_name in column_names[table_id]:
            input_words.append(column_name)
            column_name_word_indices.append(len(input_words) - 1)
            input_words.append(",")

    # remove the last ","
    input_words = input_words[:-1]

    tokenized_inputs = tokenizer(
        input_words,
        return_tensors="pt",
        is_split_into_words = True,
        padding = "max_length",
        max_length = 512,
        truncation = True
    )

    # after tokenizing, one table name or column name may be splitted into multiple tokens (i.e., sub-words)
    # `column_name_token_indices` and `table_name_token_indices` records the token indices of each column and table in `input_ids`, whose element is a list of integer
    column_name_token_indices, table_name_token_indices = [], []
    word_indices = tokenized_inputs.word_ids(batch_index = 0)

    # obtain token indices of each column in `input_ids`
    for column_name_word_index in column_name_word_indices:
        column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index])

    # obtain token indices of each table in `input_ids`
    for table_name_word_index in table_name_word_indices:
        table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index])

    encoder_input_ids = tokenized_inputs["input_ids"]
    encoder_input_attention_mask = tokenized_inputs["attention_mask"]

    # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True)))

    if torch.cuda.is_available():
        encoder_input_ids = encoder_input_ids.cuda()
        encoder_input_attention_mask = encoder_input_attention_mask.cuda()

    return encoder_input_ids, encoder_input_attention_mask, \
        column_name_token_indices, table_name_token_indices, column_num_in_each_table

def get_schema(tables_and_columns):
    schema_items = []
    table_names = list(dict.fromkeys([t for t, c in tables_and_columns]))
    for table_name in table_names:
        schema_items.append(
            {
                "table_name": table_name,
                "column_names": [c for t, c in tables_and_columns if t == table_name]
            }
        )

    return {"schema_items": schema_items}

def get_sequence_length(text, tables_and_columns, tokenizer):
    table_names = [t for t, c in tables_and_columns]
    # deduplicate `table_names` while preserving order
    table_names = list(dict.fromkeys(table_names))

    column_names = []
    for table_name in table_names:
        column_names.append([c for t, c in tables_and_columns if t == table_name])

    input_words = [text]
    for table_id, table_name in enumerate(table_names):
        input_words.append("|")
        input_words.append(table_name)
        input_words.append(":")
        for column_name in column_names[table_id]:
            input_words.append(column_name)
            input_words.append(",")
    # remove the last ","
    input_words = input_words[:-1]

    tokenized_inputs = tokenizer(input_words, is_split_into_words = True)

    return len(tokenized_inputs["input_ids"])

# handle extremely long schema sequences
def split_sample(sample, tokenizer):
    text = sample["text"]

    table_names = []
    column_names = []
    for table in sample["schema"]["schema_items"]:
        table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
            if table["table_comment"] != "" else table["table_name"])
        column_names.append([column_name + " ( " + column_comment + " ) " \
            if column_comment != "" else column_name \
                for column_name, column_comment in zip(table["column_names"], table["column_comments"])])

    splitted_samples = []
    recorded_tables_and_columns = []

    for table_idx, table_name in enumerate(table_names):
        for column_name in column_names[table_idx]:
            if get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500:
                recorded_tables_and_columns.append([table_name, column_name])
            else:
                splitted_samples.append(
                    {
                        "text": text,
                        "schema": get_schema(recorded_tables_and_columns)
                    }
                )
                recorded_tables_and_columns = [[table_name, column_name]]

    splitted_samples.append(
        {
            "text": text,
            "schema": get_schema(recorded_tables_and_columns)
        }
    )

    return splitted_samples

def merge_pred_results(sample, pred_results):
    # table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
    # column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
    table_names = []
    column_names = []
    for table in sample["schema"]["schema_items"]:
        table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
            if table["table_comment"] != "" else table["table_name"])
        column_names.append([column_name + " ( " + column_comment + " ) " \
            if column_comment != "" else column_name \
                for column_name, column_comment in zip(table["column_names"], table["column_comments"])])

    merged_results = []
    for table_id, table_name in enumerate(table_names):
        table_prob = 0
        column_probs = []
        for result_dict in pred_results:
            if table_name in result_dict:
                if table_prob < result_dict[table_name]["table_prob"]:
                    table_prob = result_dict[table_name]["table_prob"]
                column_probs += result_dict[table_name]["column_probs"]

        merged_results.append(
            {
                "table_name": table_name,
                "table_prob": table_prob,
                "column_names": column_names[table_id],
                "column_probs": column_probs
            }
        )

    return merged_results

def filter_schema(data, sic, num_top_k_tables = 5, num_top_k_columns = 5):
    filtered_schema = dict()
    filtered_matched_contents = dict()
    filtered_schema["schema_items"] = []
    filtered_schema["foreign_keys"] = []

    table_names = [table["table_name"] for table in data["schema"]["schema_items"]]
    table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]]
    column_names = [table["column_names"] for table in data["schema"]["schema_items"]]
    column_types = [table["column_types"] for table in data["schema"]["schema_items"]]
    column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]]
    column_contents = [table["column_contents"] for table in data["schema"]["schema_items"]]
    pk_indicators = [table["pk_indicators"] for table in data["schema"]["schema_items"]]

    # predict scores for each table and column
    pred_results = sic.predict(data)
    # retain top_k1 tables for each database and top_k2 columns for each retained table
    table_probs = [pred_result["table_prob"] for pred_result in pred_results]
    table_indices = np.argsort(-np.array(table_probs), kind="stable")[:num_top_k_tables].tolist()

    for table_idx in table_indices:
        column_probs = pred_results[table_idx]["column_probs"]
        column_indices = np.argsort(-np.array(column_probs), kind="stable")[:num_top_k_columns].tolist()

        filtered_schema["schema_items"].append(
            {
                "table_name": table_names[table_idx],
                "table_comment": table_comments[table_idx],
                "column_names": [column_names[table_idx][column_idx] for column_idx in column_indices],
                "column_types": [column_types[table_idx][column_idx] for column_idx in column_indices],
                "column_comments": [column_comments[table_idx][column_idx] for column_idx in column_indices],
                "column_contents": [column_contents[table_idx][column_idx] for column_idx in column_indices],
                "pk_indicators": [pk_indicators[table_idx][column_idx] for column_idx in column_indices]
            }
        )

        # extract matched contents of retained columns
        for column_name in [column_names[table_idx][column_idx] for column_idx in column_indices]:
            tc_name = "{}.{}".format(table_names[table_idx], column_name)
            if tc_name in data["matched_contents"]:
                filtered_matched_contents[tc_name] = data["matched_contents"][tc_name]

    # extract foreign keys among retained tables
    filtered_table_names = [table_names[table_idx] for table_idx in table_indices]
    for foreign_key in data["schema"]["foreign_keys"]:
        source_table, source_column, target_table, target_column = foreign_key
        if source_table in filtered_table_names and target_table in filtered_table_names:
            filtered_schema["foreign_keys"].append(foreign_key)

    # replace the old schema with the filtered schema
    data["schema"] = filtered_schema
    # replace the old matched contents with the filtered matched contents
    data["matched_contents"] = filtered_matched_contents

    return data

def lista_contains_listb(lista, listb):
    for b in listb:
        if b not in lista:
            return 0

    return 1

class SchemaItemClassifierInference():
    def __init__(self, model_save_path):
        set_seed(42)
        # load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
        # initialize model
        self.model = SchemaItemClassifier(model_save_path, "test")
        # load fine-tuned params
        self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

    def predict_one(self, sample):
        encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\
            table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer)

        with torch.no_grad():
            model_outputs = self.model(
                encoder_input_ids,
                encoder_input_attention_mask,
                [column_name_token_indices],
                [table_name_token_indices],
                [column_num_in_each_table]
            )

        table_logits = model_outputs["batch_table_name_cls_logits"][0]
        table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist()

        column_logits = model_outputs["batch_column_info_cls_logits"][0]
        column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist()

        splitted_column_pred_probs = []
        # split predicted column probs into each table
        for table_id, column_num in enumerate(column_num_in_each_table):
            splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num])
        column_pred_probs = splitted_column_pred_probs

        result_dict = dict()
        for table_idx, table in enumerate(sample["schema"]["schema_items"]):
            result_dict[table["table_name"]] = {
                "table_name": table["table_name"],
                "table_prob": table_pred_probs[table_idx],
                "column_names": table["column_names"],
                "column_probs": column_pred_probs[table_idx],
            }

        return result_dict

    def predict(self, test_sample):
        splitted_samples = split_sample(test_sample, self.tokenizer)
        pred_results = []
        for splitted_sample in splitted_samples:
            pred_results.append(self.predict_one(splitted_sample))

        return merge_pred_results(test_sample, pred_results)

    def evaluate_coverage(self, dataset):
        max_k = 100
        total_num_for_table_coverage, total_num_for_column_coverage = 0, 0
        table_coverage_results = [0]*max_k
        column_coverage_results = [0]*max_k

        for data in dataset:
            indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1]
            pred_results = sic.predict(data)
            # print(pred_results)
            table_probs = [res["table_prob"] for res in pred_results]
            for k in range(max_k):
                indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist()
                if lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables):
                    table_coverage_results[k] += 1
            total_num_for_table_coverage += 1

            for table_idx in range(len(data["table_labels"])):
                indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1]
                if len(indices_of_used_columns) == 0:
                    continue
                column_probs = pred_results[table_idx]["column_probs"]
                for k in range(max_k):
                    indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist()
                    if lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns):
                        column_coverage_results[k] += 1

                total_num_for_column_coverage += 1

                indices_of_top_10_columns = np.argsort(-np.array(column_probs), kind="stable")[:10].tolist()
                if lista_contains_listb(indices_of_top_10_columns, indices_of_used_columns) == 0:
                    print(pred_results[table_idx])
                    print(data["column_labels"][table_idx])
                    print(data["question"])

        print(total_num_for_table_coverage)
        print(table_coverage_results)
        print(total_num_for_column_coverage)
        print(column_coverage_results)

if __name__ == "__main__":
    dataset_name = "bird_with_evidence"
    # dataset_name = "bird"
    # dataset_name = "spider"
    sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name))
    import json
    dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name)))

    sic.evaluate_coverage(dataset)
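For orientation, a sketch (inferred from the field accesses above, not defined anywhere in the commit) of the sample dictionary that `filter_schema` and `SchemaItemClassifierInference.predict` expect; the concrete values are illustrative only.

# Hypothetical input shape for filter_schema(data, sic, ...).
data = {
    "text": "List the names of all singers from France.",
    "schema": {
        "schema_items": [
            {
                "table_name": "singer",
                "table_comment": "",
                "column_names": ["singer_id", "name", "birth_year", "net_worth_millions", "citizenship"],
                "column_types": ["int", "text", "real", "real", "text"],
                "column_comments": ["", "", "", "", ""],
                "column_contents": [[], ["Liliane Bettencourt"], [], [], ["France"]],
                "pk_indicators": [1, 0, 0, 0, 0],
            }
        ],
        # each foreign key is [source_table, source_column, target_table, target_column]
        "foreign_keys": [],
    },
    "matched_contents": {"singer.citizenship": ["France"]},
}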
text2sql.py
ADDED
@@ -0,0 +1,182 @@
import os
import json
import torch
import copy
import re
import sqlparse
import sqlite3

from tqdm import tqdm
from utils.db_utils import get_db_schema
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh.index import create_in
from whoosh.fields import Schema, TEXT
from whoosh.qparser import QueryParser
from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
from schema_item_filter import SchemaItemClassifierInference, filter_schema

def remove_similar_comments(names, comments):
    '''
    Remove table (or column) comments that have a high degree of similarity with their names
    '''
    new_comments = []
    for name, comment in zip(names, comments):
        if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
            new_comments.append("")
        else:
            new_comments.append(comment)

    return new_comments

def load_db_comments(table_json_path):
    additional_db_info = json.load(open(table_json_path))
    db_comments = dict()
    for db_info in additional_db_info:
        comment_dict = dict()

        column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]

        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
        table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]

        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": dict()
            }
            for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
                if t_idx == table_idx:
                    comment_dict[table_name]["column_comments"][column_name] = column_comment

        db_comments[db_info["db_id"]] = comment_dict

    return db_comments

def get_db_id2schema(db_path, tables_json):
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)

    return db_id2schema

def get_db_id2ddl(db_path):
    db_ids = os.listdir(db_path)
    db_id2ddl = dict()

    for db_id in db_ids:
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        ddl = []

        for table in tables:
            table_name = table[0]
            table_ddl = table[1]
            table_ddl.replace("\t", " ")
            while "  " in table_ddl:
                table_ddl = table_ddl.replace("  ", " ")

            table_ddl = re.sub(r'--.*', '', table_ddl)
            table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True)
            table_ddl = table_ddl.replace(", ", ",\n    ")

            if table_ddl.endswith(";"):
                table_ddl = table_ddl[:-1]
            table_ddl = table_ddl[:-1] + "\n);"
            table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n    ", table_ddl)

            ddl.append(table_ddl)
        db_id2ddl[db_id] = "\n\n".join(ddl)

    return db_id2ddl

class ChatBot():
    def __init__(self) -> None:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-7b-merged"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
        self.max_length = 4096
        self.max_new_tokens = 256
        self.max_prefix_length = self.max_length - self.max_new_tokens

        self.sic = SchemaItemClassifierInference("sic_ckpts/sic_bird")

        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            schema = Schema(content=TEXT(stored=True))
            index_dir = os.path.join("db_contents_index", db_id)
            if not os.path.exists(index_dir):
                os.makedirs(index_dir)
            ix = create_in(index_dir, schema)
            writer = ix.writer()
            with open(os.path.join(index_dir, f"{db_id}.json"), "r") as file:
                data = json.load(file)
                for item in data:
                    writer.add_document(content=item['content'])
            writer.commit()
            self.db_id2content_searcher[db_id] = ix

        self.db_ids = sorted(os.listdir("databases"))
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
        data = {
            "text": question,
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
        }
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        print(prefix_seq)

        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq , truncation = False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            print("the current input sequence exceeds the max_tokens, we will truncate it.")
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
        attention_mask = [1] * len(input_ids)

        inputs = {
            "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
            "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
        }
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens = self.max_new_tokens,
                num_beams = 4,
                num_return_sequences = 4
            )

        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
        final_generated_sql = None
        for generated_sql in generated_sqls:
            execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
            if execution_error is None:
                final_generated_sql = generated_sql
                break

        if final_generated_sql is None:
            if generated_sqls[0].strip() != "":
                final_generated_sql = generated_sqls[0].strip()
            else:
                final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")
utils/bridge_content_encoder.py
ADDED
@@ -0,0 +1,261 @@
"""
Copyright (c) 2020, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

Encode DB content.
"""

import difflib
from typing import List, Optional, Tuple
from rapidfuzz import fuzz
import sqlite3
import functools

# fmt: off
_stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again',
              'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but',
              "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is',
              "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being',
              'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be',
              'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor',
              'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because',
              'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have',
              'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours',
              'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his',
              "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what',
              'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan',
              'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here',
              'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself',
              'then', 'did', 'just', "aren't"}
# fmt: on

_commonwords = {"no", "yes", "many"}


def is_number(s: str) -> bool:
    try:
        float(s.replace(",", ""))
        return True
    except:
        return False


def is_stopword(s: str) -> bool:
    return s.strip() in _stopwords


def is_commonword(s: str) -> bool:
    return s.strip() in _commonwords


def is_common_db_term(s: str) -> bool:
    return s.strip() in ["id"]


class Match(object):
    def __init__(self, start: int, size: int) -> None:
        self.start = start
        self.size = size


def is_span_separator(c: str) -> bool:
    return c in "'\"()`,.?! "


def split(s: str) -> List[str]:
    return [c.lower() for c in s.strip()]


def prefix_match(s1: str, s2: str) -> bool:
    i, j = 0, 0
    for i in range(len(s1)):
        if not is_span_separator(s1[i]):
            break
    for j in range(len(s2)):
        if not is_span_separator(s2[j]):
            break
    if i < len(s1) and j < len(s2):
        return s1[i] == s2[j]
    elif i >= len(s1) and j >= len(s2):
        return True
    else:
        return False


def get_effective_match_source(s: str, start: int, end: int) -> Match:
    _start = -1

    for i in range(start, start - 2, -1):
        if i < 0:
            _start = i + 1
            break
        if is_span_separator(s[i]):
            _start = i
            break

    if _start < 0:
        return None

    _end = -1
    for i in range(end - 1, end + 3):
        if i >= len(s):
            _end = i - 1
            break
        if is_span_separator(s[i]):
            _end = i
            break

    if _end < 0:
        return None

    while _start < len(s) and is_span_separator(s[_start]):
        _start += 1
    while _end >= 0 and is_span_separator(s[_end]):
        _end -= 1

    return Match(_start, _end - _start + 1)


def get_matched_entries(
    s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85
) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]:
    if not field_values:
        return None

    if isinstance(s, str):
        n_grams = split(s)
    else:
        n_grams = s

    matched = dict()
    for field_value in field_values:
        if not isinstance(field_value, str):
            continue
        fv_tokens = split(field_value)
        sm = difflib.SequenceMatcher(None, n_grams, fv_tokens)
        match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens))
        if match.size > 0:
            source_match = get_effective_match_source(
                n_grams, match.a, match.a + match.size
            )
            if source_match:  # and source_match.size > 1
                match_str = field_value[match.b : match.b + match.size]
                source_match_str = s[
                    source_match.start : source_match.start + source_match.size
                ]
                c_match_str = match_str.lower().strip()
                c_source_match_str = source_match_str.lower().strip()
                c_field_value = field_value.lower().strip()
                if c_match_str and not is_common_db_term(c_match_str):  # and not is_number(c_match_str)
                    if (
                        is_stopword(c_match_str)
                        or is_stopword(c_source_match_str)
                        or is_stopword(c_field_value)
                    ):
                        continue
                    if c_source_match_str.endswith(c_match_str + "'s"):
                        match_score = 1.0
                    else:
                        if prefix_match(c_field_value, c_source_match_str):
                            match_score = fuzz.ratio(c_field_value, c_source_match_str) / 100
                        else:
                            match_score = 0
                    if (
                        is_commonword(c_match_str)
                        or is_commonword(c_source_match_str)
                        or is_commonword(c_field_value)
                    ) and match_score < 1:
                        continue
                    s_match_score = match_score
                    if match_score >= m_theta and s_match_score >= s_theta:
                        if field_value.isupper() and match_score * s_match_score < 1:
                            continue
                        matched[match_str] = (
                            field_value,
                            source_match_str,
                            match_score,
                            s_match_score,
                            match.size,
                        )

    if not matched:
        return None
    else:
        return sorted(
            matched.items(),
            key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]),
            reverse=True,
        )


@functools.lru_cache(maxsize=1000, typed=False)
def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list:
    fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name)
    try:
        conn = sqlite3.connect(db_path)
        conn.text_factory = bytes
        c = conn.cursor()
        c.execute(fetch_sql)
        picklist = set()
        for x in c.fetchall():
            if isinstance(x[0], str):
                picklist.add(x[0].encode("utf-8"))
            elif isinstance(x[0], bytes):
                try:
                    picklist.add(x[0].decode("utf-8"))
                except UnicodeDecodeError:
                    picklist.add(x[0].decode("latin-1"))
            else:
                picklist.add(x[0])
        picklist = list(picklist)
    except Exception as e:
        picklist = []
    finally:
        conn.close()
    return picklist


def get_database_matches(
    question: str,
    table_name: str,
    column_name: str,
    db_path: str,
    top_k_matches: int = 2,
    match_threshold: float = 0.85,
) -> List[str]:
    picklist = get_column_picklist(
        table_name=table_name, column_name=column_name, db_path=db_path
    )
    # only maintain data in ``str'' type
    picklist = [ele.strip() for ele in picklist if isinstance(ele, str)]
    # picklist is unordered, we sort it to ensure the reproduction stability
    picklist = sorted(picklist)

    matches = []
    if picklist and isinstance(picklist[0], str):
        matched_entries = get_matched_entries(
            s=question,
            field_values=picklist,
            m_theta=match_threshold,
            s_theta=match_threshold,
        )

        if matched_entries:
            num_values_inserted = 0
            for _match_str, (
                field_value,
                _s_match_str,
                match_score,
                s_match_score,
                _match_size,
            ) in matched_entries:
                if "name" in column_name and match_score * s_match_score < 1:
                    continue
                if table_name != "sqlite_sequence":  # Spider database artifact
                    matches.append(field_value.strip())
                    num_values_inserted += 1
                    if num_values_inserted >= top_k_matches:
                        break
    return matches
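A minimal usage sketch (not part of the commit), matching a question against the stored values of one column of the bundled singer database; the question text is only an example and the path assumes the repository root as the working directory.

# Hypothetical call: fuzzy-match question tokens against the singer.Citizenship column.
from utils.bridge_content_encoder import get_database_matches

matches = get_database_matches(
    question="How many singers hold France citizenship?",
    table_name="singer",
    column_name="Citizenship",
    db_path="databases/singer/singer.sqlite",
)
print(matches)  # e.g. ["France"] when the fuzzy score clears the 0.85 threshold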
utils/classifier_model.py
ADDED
@@ -0,0 +1,186 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import AutoConfig, RobertaModel
|
5 |
+
|
6 |
+
class SchemaItemClassifier(nn.Module):
|
7 |
+
def __init__(self, model_name_or_path, mode):
|
8 |
+
super(SchemaItemClassifier, self).__init__()
|
9 |
+
if mode in ["eval", "test"]:
|
10 |
+
# load config
|
11 |
+
config = AutoConfig.from_pretrained(model_name_or_path)
|
12 |
+
# randomly initialize model's parameters according to the config
|
13 |
+
self.plm_encoder = RobertaModel(config)
|
14 |
+
elif mode == "train":
|
15 |
+
self.plm_encoder = RobertaModel.from_pretrained(model_name_or_path)
|
16 |
+
else:
|
17 |
+
raise ValueError()
|
18 |
+
|
19 |
+
self.plm_hidden_size = self.plm_encoder.config.hidden_size
|
20 |
+
|
21 |
+
# column cls head
|
22 |
+
self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
|
23 |
+
self.column_info_cls_head_linear2 = nn.Linear(256, 2)
|
24 |
+
|
25 |
+
# column bi-lstm layer
|
26 |
+
self.column_info_bilstm = nn.LSTM(
|
27 |
+
input_size = self.plm_hidden_size,
|
28 |
+
hidden_size = int(self.plm_hidden_size/2),
|
29 |
+
num_layers = 2,
|
30 |
+
dropout = 0,
|
31 |
+
bidirectional = True
|
32 |
+
)
|
33 |
+
|
34 |
+
# linear layer after column bi-lstm layer
|
35 |
+
self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
|
36 |
+
|
37 |
+
# table cls head
|
38 |
+
self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
|
39 |
+
self.table_name_cls_head_linear2 = nn.Linear(256, 2)
|
40 |
+
|
41 |
+
# table bi-lstm pooling layer
|
42 |
+
self.table_name_bilstm = nn.LSTM(
|
43 |
+
input_size = self.plm_hidden_size,
|
44 |
+
hidden_size = int(self.plm_hidden_size/2),
|
45 |
+
num_layers = 2,
|
46 |
+
dropout = 0,
|
47 |
+
bidirectional = True
|
48 |
+
)
|
49 |
+
# linear layer after table bi-lstm layer
|
50 |
+
self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
|
51 |
+
|
52 |
+
# activation function
|
53 |
+
self.leakyrelu = nn.LeakyReLU()
|
54 |
+
self.tanh = nn.Tanh()
|
55 |
+
|
56 |
+
# table-column cross-attention layer
|
57 |
+
self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)
|
58 |
+
|
59 |
+
# dropout function, p=0.2 means randomly set 20% neurons to 0
|
60 |
+
self.dropout = nn.Dropout(p = 0.2)
|
61 |
+
|
62 |
+
def table_column_cross_attention(
|
63 |
+
self,
|
64 |
+
table_name_embeddings_in_one_db,
|
65 |
+
column_info_embeddings_in_one_db,
|
66 |
+
column_number_in_each_table
|
67 |
+
):
|
68 |
+
table_num = table_name_embeddings_in_one_db.shape[0]
|
69 |
+
table_name_embedding_attn_list = []
|
70 |
+
for table_id in range(table_num):
|
71 |
+
table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
|
72 |
+
column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
|
73 |
+
sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]
|
74 |
+
|
75 |
+
table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
|
76 |
+
table_name_embedding,
|
77 |
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )

            table_name_embedding_attn_list.append(table_name_embedding_attn)

        # residual connection
        table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
        # row-wise L2 norm
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1)

        return table_name_embeddings_in_one_db

    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids = encoder_input_ids,
            attention_mask = encoder_input_attention_mask,
            return_dict = True
        )  # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []

        # handle each data in current batch
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :]  # (seq_length x hidden_size)

            # obtain table ids for each table
            aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
            # obtain column ids for each column
            aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]

            table_name_embedding_list, column_info_embedding_list = [], []

            # obtain table embedding via bi-lstm pooling + a non-linear layer
            for table_name_ids in aligned_table_name_ids:
                table_name_embeddings = sequence_embeddings[table_name_ids, :]

                # BiLSTM pooling
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
                table_name_embedding_list.append(table_name_embedding)
            table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
            # non-linear mlp layer
            table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))

            # obtain column embedding via bi-lstm pooling + a non-linear layer
            for column_info_ids in aligned_column_info_ids:
                column_info_embeddings = sequence_embeddings[column_info_ids, :]

                # BiLSTM pooling
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
                column_info_embedding_list.append(column_info_embedding)
            column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
            # non-linear mlp layer
            column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate table 0-1 logits
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate column 0-1 logits
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits

    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        batch_table_name_cls_logits, batch_column_info_cls_logits \
            = self.table_column_cls(
                encoder_input_ids,
                encoder_attention_mask,
                batch_aligned_column_info_ids,
                batch_aligned_table_name_ids,
                batch_column_number_in_each_table
            )

        return {
            "batch_table_name_cls_logits": batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
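A rough, non-authoritative sketch of how the returned 0-1 logits could be turned into per-table and per-column relevance probabilities; this helper is not part of the commit, and it assumes each per-example logits tensor has two columns (irrelevant / relevant), as the "0-1 logits" comments above suggest.

import torch

def cls_logits_to_probabilities(batch_cls_logits):
    # Sketch only: softmax over the two classes, keep the probability of the
    # "relevant" class (index 1) for every table / column in each example.
    return [torch.softmax(logits, dim=-1)[:, 1].tolist() for logits in batch_cls_logits]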
utils/db_utils.py
ADDED
@@ -0,0 +1,255 @@
import os
import json
import sqlite3

from func_timeout import func_set_timeout, FunctionTimedOut
from utils.bridge_content_encoder import get_matched_entries
from nltk.tokenize import word_tokenize
from nltk import ngrams

def add_a_record(question, db_id):
    conn = sqlite3.connect('data/history/history.sqlite')
    cursor = conn.cursor()
    cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id))

    conn.commit()
    conn.close()

def obtain_n_grams(sequence, max_n):
    tokens = word_tokenize(sequence)
    all_grams = []
    for n in range(1, max_n + 1):
        all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])

    return all_grams

+
# get the database cursor for a sqlite database path
|
27 |
+
def get_cursor_from_path(sqlite_path):
|
28 |
+
try:
|
29 |
+
if not os.path.exists(sqlite_path):
|
30 |
+
print("Openning a new connection %s" % sqlite_path)
|
31 |
+
connection = sqlite3.connect(sqlite_path, check_same_thread = False)
|
32 |
+
except Exception as e:
|
33 |
+
print(sqlite_path)
|
34 |
+
raise e
|
35 |
+
connection.text_factory = lambda b: b.decode(errors="ignore")
|
36 |
+
cursor = connection.cursor()
|
37 |
+
return cursor
|
38 |
+
|
39 |
+
# execute predicted sql with a time limitation
|
40 |
+
@func_set_timeout(15)
|
41 |
+
def execute_sql(cursor, sql):
|
42 |
+
cursor.execute(sql)
|
43 |
+
|
44 |
+
return cursor.fetchall()
|
45 |
+
|
46 |
+
# execute predicted sql with a long time limitation (for buiding content index)
|
47 |
+
@func_set_timeout(2000)
|
48 |
+
def execute_sql_long_time_limitation(cursor, sql):
|
49 |
+
cursor.execute(sql)
|
50 |
+
|
51 |
+
return cursor.fetchall()
|
52 |
+
|
53 |
+
def check_sql_executability(generated_sql, db):
|
54 |
+
if generated_sql.strip() == "":
|
55 |
+
return "Error: empty string"
|
56 |
+
try:
|
57 |
+
cursor = get_cursor_from_path(db)
|
58 |
+
execute_sql(cursor, generated_sql)
|
59 |
+
execution_error = None
|
60 |
+
except FunctionTimedOut as fto:
|
61 |
+
print("SQL execution time out error: {}.".format(fto))
|
62 |
+
execution_error = "SQL execution times out."
|
63 |
+
except Exception as e:
|
64 |
+
print("SQL execution runtime error: {}.".format(e))
|
65 |
+
execution_error = str(e)
|
66 |
+
|
67 |
+
return execution_error
|
68 |
+
|
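A small usage sketch for check_sql_executability: a None return means the query ran within the 15-second budget, otherwise the error text comes back. The database path refers to the file bundled with this Space; the table name is an assumption about that database's contents.

error = check_sql_executability("SELECT count(*) FROM singer;", "databases/singer/singer.sqlite")
print("executable" if error is None else "failed: " + error)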
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def detect_special_char(name):
    for special_char in ['(', '-', ')', ' ', '/']:
        if special_char in name:
            return True

    return False

def add_quotation_mark(s):
    return "`" + s + "`"

def get_column_contents(column_name, table_name, cursor):
    select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(column_name, table_name, column_name)
    results = execute_sql_long_time_limitation(cursor, select_column_sql)
    column_contents = [str(result[0]).strip() for result in results]
    # remove empty and extremely-long contents
    column_contents = [content for content in column_contents if len(content) != 0 and len(content) <= 25]

    return column_contents

def get_matched_contents(question, searcher):
    # coarse-grained matching between the input text and all contents in database
    grams = obtain_n_grams(question, 4)
    hits = []
    for query in grams:
        hits.extend(searcher.search(query, k = 10))

    coarse_matched_contents = dict()
    for i in range(len(hits)):
        matched_result = json.loads(hits[i].raw)
        # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
        tc_name = ".".join(matched_result["id"].split("-**-")[:2])
        if tc_name in coarse_matched_contents.keys():
            if matched_result["contents"] not in coarse_matched_contents[tc_name]:
                coarse_matched_contents[tc_name].append(matched_result["contents"])
        else:
            coarse_matched_contents[tc_name] = [matched_result["contents"]]

    fine_matched_contents = dict()
    for tc_name, contents in coarse_matched_contents.items():
        # fine-grained matching between the question and coarse matched contents
        fm_contents = get_matched_entries(question, contents)

        if fm_contents is None:
            continue
        for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents:
            if match_score < 0.9:
                continue
            if tc_name in fine_matched_contents.keys():
                if len(fine_matched_contents[tc_name]) < 25:
                    fine_matched_contents[tc_name].append(field_value.strip())
            else:
                fine_matched_contents[tc_name] = [field_value.strip()]

    return fine_matched_contents

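The return value of get_matched_contents maps "table_name.column_name" keys to database values that closely match the question. Purely illustrative (the keys and values below are invented, not read from the bundled index):

# Sketch of the return shape of get_matched_contents; entries are hypothetical.
fine_matched_contents = {
    "singer.name": ["Tribal King"],
    "singer.country": ["France"],
}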
def get_db_schema_sequence(schema):
    schema_sequence = "database schema :\n"
    for table in schema["schema_items"]:
        table_name, table_comment = table["table_name"], table["table_comment"]
        if detect_special_char(table_name):
            table_name = add_quotation_mark(table_name)

        # if table_comment != "":
        #     table_name += " ( comment : " + table_comment + " )"

        column_info_list = []
        for column_name, column_type, column_comment, column_content, pk_indicator in \
            zip(table["column_names"], table["column_types"], table["column_comments"], table["column_contents"], table["pk_indicators"]):
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            additional_column_info = []
            # column type
            additional_column_info.append(column_type)
            # pk indicator
            if pk_indicator != 0:
                additional_column_info.append("primary key")
            # column comment
            if column_comment != "":
                additional_column_info.append("comment : " + column_comment)
            # representative column values
            if len(column_content) != 0:
                additional_column_info.append("values : " + " , ".join(column_content))

            column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )")

        schema_sequence += "table " + table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n"

    if len(schema["foreign_keys"]) != 0:
        schema_sequence += "foreign keys :\n"
        for foreign_key in schema["foreign_keys"]:
            for i in range(len(foreign_key)):
                if detect_special_char(foreign_key[i]):
                    foreign_key[i] = add_quotation_mark(foreign_key[i])
            schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3])
    else:
        schema_sequence += "foreign keys : None\n"

    return schema_sequence.strip()

def get_matched_content_sequence(matched_contents):
    content_sequence = ""
    if len(matched_contents) != 0:
        content_sequence += "matched contents :\n"
        for tc_name, contents in matched_contents.items():
            table_name = tc_name.split(".")[0]
            column_name = tc_name.split(".")[1]
            if detect_special_char(table_name):
                table_name = add_quotation_mark(table_name)
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)

            content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
    else:
        content_sequence = "matched contents : None"

    return content_sequence.strip()

def get_db_schema(db_path, db_comments, db_id):
    if db_id in db_comments:
        db_comment = db_comments[db_id]
    else:
        db_comment = None

    cursor = get_cursor_from_path(db_path)

    # obtain table names
    results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0].lower() for result in results]

    schema = dict()
    schema["schema_items"] = []
    foreign_keys = []
    # for each table
    for table_name in table_names:
        # skip SQLite system table: sqlite_sequence
        if table_name == "sqlite_sequence":
            continue
        # obtain column names in the current table
        results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
        column_names_in_one_table = [result[0].lower() for result in results]
        column_types_in_one_table = [result[1].lower() for result in results]
        pk_indicators_in_one_table = [result[2] for result in results]

        column_contents = []
        for column_name in column_names_in_one_table:
            column_contents.append(get_column_contents(column_name, table_name, cursor))

        # obtain foreign keys in the current table
        results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
        for result in results:
            if None not in [result[3], result[2], result[4]]:
                foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()])

        # obtain comments for each schema item
        if db_comment is not None:
            if table_name in db_comment: # record comments for tables and columns
                table_comment = db_comment[table_name]["table_comment"]
                column_comments = [db_comment[table_name]["column_comments"][column_name] \
                    if column_name in db_comment[table_name]["column_comments"] else "" \
                    for column_name in column_names_in_one_table]
            else: # current database has comment information, but the current table does not
                table_comment = ""
                column_comments = ["" for _ in column_names_in_one_table]
        else: # current database has no comment information
            table_comment = ""
            column_comments = ["" for _ in column_names_in_one_table]

        schema["schema_items"].append({
            "table_name": table_name,
            "table_comment": table_comment,
            "column_names": column_names_in_one_table,
            "column_types": column_types_in_one_table,
            "column_comments": column_comments,
            "column_contents": column_contents,
            "pk_indicators": pk_indicators_in_one_table
        })

    schema["foreign_keys"] = foreign_keys

    return schema
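Taken together, get_db_schema and get_db_schema_sequence turn a SQLite file into the schema text consumed by the text-to-SQL pipeline. A minimal sketch using the database bundled with this Space; passing an empty comment dictionary is an assumption (comments are simply omitted when none are provided):

from utils.db_utils import get_db_schema, get_db_schema_sequence

schema = get_db_schema("databases/singer/singer.sqlite", db_comments={}, db_id="singer")
print(get_db_schema_sequence(schema))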
utils/translate_utils.py
ADDED
@@ -0,0 +1,23 @@
import requests
import random
import json

def translate_zh_to_en(question, token):
    url = 'https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1?access_token=' + token

    from_lang = 'auto'
    to_lang = 'en'
    term_ids = ''

    # Build request
    headers = {'Content-Type': 'application/json'}
    payload = {'q': question, 'from': from_lang, 'to': to_lang, 'termIds': term_ids}

    # Send request
    r = requests.post(url, params=payload, headers=headers)
    result = r.json()

    return result["result"]["trans_result"][0]["dst"]

if __name__ == "__main__":
    # a valid Baidu access token is required; the value below is a placeholder
    print(translate_zh_to_en("你好啊!", "YOUR_ACCESS_TOKEN"))
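Because the Baidu endpoint returns an error payload (with no "result" field) when the token is invalid or the quota is exhausted, a defensive wrapper such as the sketch below may be useful. It is not part of the commit, and falling back to the untranslated question is only one possible policy.

import requests

from utils.translate_utils import translate_zh_to_en

def safe_translate(question, token):
    # Sketch only: keep the original question if translation fails for any reason.
    try:
        return translate_zh_to_en(question, token)
    except (KeyError, requests.RequestException):
        return question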