Roxanne-WANG committed on
Commit b759b87 · 1 Parent(s): b423caf

update webpage

Dockerfile ADDED
@@ -0,0 +1,45 @@
1
+ # Use NVIDIA CUDA base image for GPU support
2
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Update and install system dependencies (including Java and other tools)
8
+ RUN apt-get update && \
9
+ apt-get install -y openjdk-11-jdk git rsync make build-essential libssl-dev zlib1g-dev \
10
+ libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
11
+ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
12
+ libffi-dev liblzma-dev git-lfs ffmpeg libsm6 libxext6 cmake \
13
+ libgl1-mesa-glx && rm -rf /var/lib/apt/lists/* && git lfs install
14
+
15
+ # Set JAVA_HOME environment variable
16
+ ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
17
+ ENV PATH="${JAVA_HOME}/bin:${PATH}"
18
+
19
+ # Install Python version manager (pyenv) and Python 3.10
20
+ RUN curl https://pyenv.run | bash
+ # make pyenv and its shims available on PATH
+ ENV PYENV_ROOT="/root/.pyenv"
+ ENV PATH="${PYENV_ROOT}/bin:${PYENV_ROOT}/shims:${PATH}"
21
+ RUN pyenv install 3.10 && pyenv global 3.10 && pyenv rehash
22
+
23
+ # Install pip and other dependencies
24
+ RUN pip install --no-cache-dir --upgrade pip
25
+ RUN pip install --no-cache-dir datasets transformers langdetect streamlit
26
+
27
+ # Install PyTorch and CUDA dependencies
28
+ RUN pip install --no-cache-dir torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
29
+
30
+ # Copy requirements.txt and install dependencies
31
+ COPY requirements.txt /app/
32
+ RUN pip install --no-cache-dir -r requirements.txt
33
+
34
+ # Copy the application code to the container
35
+ COPY . /app/
36
+
37
+ # Expose the port the app will run on
38
+ EXPOSE 8501
39
+
40
+ # Set the environment variable for streamlit
41
+ ENV STREAMLIT_SERVER_PORT=8501
42
+ ENV STREAMLIT_SERVER_HEADLESS=true
43
+
44
+ # Command to run the application
45
+ CMD ["streamlit", "run", "app.py"]
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: LangSQL
3
- emoji: 🏢
4
- colorFrom: indigo
5
  colorTo: gray
6
  sdk: streamlit
7
  sdk_version: 1.44.1
 
1
  ---
2
  title: LangSQL
3
+ emoji: 🦕
4
+ colorFrom: blue
5
  colorTo: gray
6
  sdk: streamlit
7
  sdk_version: 1.44.1
app.py ADDED
@@ -0,0 +1,64 @@
1
+ import streamlit as st
2
+ from text2sql import ChatBot
3
+ from langdetect import detect
4
+ from utils.translate_utils import translate_zh_to_en
5
+ from utils.db_utils import add_a_record
6
+ from langdetect.lang_detect_exception import LangDetectException
7
+
8
+ # Initialize chatbot and other variables
9
+ text2sql_bot = ChatBot()
10
+ baidu_api_token = None
11
+
12
+ # Define database schemas for demonstration
13
+ db_schemas = {
14
+ "singer": """
15
+ CREATE TABLE "singer" (
16
+ "Singer_ID" int,
17
+ "Name" text,
18
+ "Birth_Year" real,
19
+ "Net_Worth_Millions" real,
20
+ "Citizenship" text,
21
+ PRIMARY KEY ("Singer_ID")
22
+ );
23
+
24
+ CREATE TABLE "song" (
25
+ "Song_ID" int,
26
+ "Title" text,
27
+ "Singer_ID" int,
28
+ "Sales" real,
29
+ "Highest_Position" real,
30
+ PRIMARY KEY ("Song_ID"),
31
+ FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
32
+ );
33
+ """,
34
+ # Add other schemas as needed
35
+ }
36
+
37
+ # Streamlit UI
38
+ st.title("Text-to-SQL Chatbot")
39
+ st.sidebar.header("Select a Database")
40
+
41
+ # Sidebar for selecting a database
42
+ selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
43
+
44
+ # Display the selected schema
45
+ st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)
46
+
47
+ # User input section
48
+ question = st.text_input("Enter your question:")
49
+ db_id = selected_db # Use selected database for DB ID
50
+
51
+ if question:
52
+ add_a_record(question, db_id)
53
+
54
+ try:
55
+ if baidu_api_token is not None and detect(question) != "en":
56
+ print("Before translation:", question)
57
+ question = translate_zh_to_en(question, baidu_api_token)
58
+ print("After translation:", question)
59
+ except LangDetectException as e:
60
+ print("Language detection error:", str(e))
61
+
62
+ predicted_sql = text2sql_bot.get_response(question, db_id)
63
+ st.write(f"**Database:** {db_id}")
64
+ st.write(f"**Predicted SQL query:** {predicted_sql}")
build_contents_index.py ADDED
@@ -0,0 +1,75 @@
1
+ from utils.db_utils import get_cursor_from_path, execute_sql_long_time_limitation
2
+ import json
3
+ import os, shutil
4
+
5
+ def remove_contents_of_a_folder(index_path):
6
+ # if index_path does not exist, then create it
7
+ os.makedirs(index_path, exist_ok = True)
8
+ # remove files in index_path
9
+ for filename in os.listdir(index_path):
10
+ file_path = os.path.join(index_path, filename)
11
+ try:
12
+ if os.path.isfile(file_path) or os.path.islink(file_path):
13
+ os.unlink(file_path)
14
+ elif os.path.isdir(file_path):
15
+ shutil.rmtree(file_path)
16
+ except Exception as e:
17
+ print('Failed to delete %s. Reason: %s' % (file_path, e))
18
+
19
+ def build_content_index(db_path, index_path):
20
+ '''
21
+ Create a BM25 index for all contents in a database
22
+ '''
23
+ cursor = get_cursor_from_path(db_path)
24
+ results = execute_sql_long_time_limitation(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
25
+ table_names = [result[0] for result in results]
26
+
27
+ all_column_contents = []
28
+ for table_name in table_names:
29
+ # skip SQLite system table: sqlite_sequence
30
+ if table_name == "sqlite_sequence":
31
+ continue
32
+ results = execute_sql_long_time_limitation(cursor, "SELECT name FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
33
+ column_names_in_one_table = [result[0] for result in results]
34
+ for column_name in column_names_in_one_table:
35
+ try:
36
+ print("SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL;".format(column_name, table_name, column_name))
37
+ results = execute_sql_long_time_limitation(cursor, "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL;".format(column_name, table_name, column_name))
38
+ column_contents = [str(result[0]).strip() for result in results]
39
+
40
+ for c_id, column_content in enumerate(column_contents):
41
+ # remove empty and extremely-long contents
42
+ if len(column_content) != 0 and len(column_content) <= 25:
43
+ all_column_contents.append(
44
+ {
45
+ "id": "{}-**-{}-**-{}".format(table_name, column_name, c_id).lower(),
46
+ "contents": column_content
47
+ }
48
+ )
49
+ except Exception as e:
50
+ print(str(e))
51
+
52
+ with open("./data/temp_db_index/contents.json", "w") as f:
53
+ f.write(json.dumps(all_column_contents, indent = 2, ensure_ascii = True))
54
+
55
+ # Building a BM25 Index (Direct Java Implementation), see https://github.com/castorini/pyserini/blob/master/docs/usage-index.md
56
+ cmd = "python -m pyserini.index.lucene --collection JsonCollection --input ./data/temp_db_index --index {} --generator DefaultLuceneDocumentGenerator --threads 16 --storePositions --storeDocvectors --storeRaw".format(index_path)
57
+
58
+ d = os.system(cmd)
59
+ print(d)
60
+ os.remove("./data/temp_db_index/contents.json")
61
+
62
+ if __name__ == "__main__":
63
+ os.makedirs('./data/temp_db_index', exist_ok = True)
64
+
65
+ print("build content index for databases...")
66
+ remove_contents_of_a_folder("db_contents_index")
67
+ # build a content index for each database in the `databases` directory
68
+ for db_id in os.listdir("databases"):
69
+ print(db_id)
70
+ build_content_index(
71
+ os.path.join("databases", db_id, db_id + ".sqlite"),
72
+ os.path.join("db_contents_index", db_id)
73
+ )
74
+
75
+ os.rmdir('./data/temp_db_index')
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/history/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/tables.json ADDED
The diff for this file is too large to render. See raw diff
 
databases/.DS_Store ADDED
Binary file (6.15 kB). View file
 
databases/singer/schema.sql ADDED
@@ -0,0 +1,39 @@
1
+ PRAGMA foreign_keys = ON;
2
+
3
+ CREATE TABLE "singer" (
4
+ "Singer_ID" int,
5
+ "Name" text,
6
+ "Birth_Year" real,
7
+ "Net_Worth_Millions" real,
8
+ "Citizenship" text,
9
+ PRIMARY KEY ("Singer_ID")
10
+ );
11
+
12
+ CREATE TABLE "song" (
13
+ "Song_ID" int,
14
+ "Title" text,
15
+ "Singer_ID" int,
16
+ "Sales" real,
17
+ "Highest_Position" real,
18
+ PRIMARY KEY ("Song_ID"),
19
+ FOREIGN KEY ("Singer_ID") REFERENCES `singer`("Singer_ID")
20
+ );
21
+
22
+ INSERT INTO "singer" VALUES (1,"Liliane Bettencourt","1944","30.0","France");
23
+ INSERT INTO "singer" VALUES (2,"Christy Walton","1948","28.8","United States");
24
+ INSERT INTO "singer" VALUES (3,"Alice Walton","1949","26.3","United States");
25
+ INSERT INTO "singer" VALUES (4,"Iris Fontbona","1942","17.4","Chile");
26
+ INSERT INTO "singer" VALUES (5,"Jacqueline Mars","1940","17.8","United States");
27
+ INSERT INTO "singer" VALUES (6,"Gina Rinehart","1953","17","Australia");
28
+ INSERT INTO "singer" VALUES (7,"Susanne Klatten","1962","14.3","Germany");
29
+ INSERT INTO "singer" VALUES (8,"Abigail Johnson","1961","12.7","United States");
30
+
31
+ INSERT INTO "song" VALUES ("1","Do They Know It's Christmas",1,"1094000","1");
32
+ INSERT INTO "song" VALUES ("2","F**k It (I Don't Want You Back)",1,"552407","1");
33
+ INSERT INTO "song" VALUES ("3","Cha Cha Slide",2,"351421","1");
34
+ INSERT INTO "song" VALUES ("4","Call on Me",4,"335000","1");
35
+ INSERT INTO "song" VALUES ("5","Yeah",2,"300000","1");
36
+ INSERT INTO "song" VALUES ("6","All This Time",6,"292000","1");
37
+ INSERT INTO "song" VALUES ("7","Left Outside Alone",5,"275000","3");
38
+ INSERT INTO "song" VALUES ("8","Mysterious Girl",7,"261000","1");
39
+
databases/singer/singer.sqlite ADDED
Binary file (20.5 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ streamlit
2
+ text2sql
3
+ langdetect
4
+ faiss-cpu
5
+ func_timeout
6
+ nltk
7
+ numpy
8
+ pandas
9
+ rapidfuzz
10
+ tqdm
11
+ transformers
12
+ chardet
13
+ sqlparse
14
+ accelerate
15
+ bitsandbytes
16
+ sql_metadata
17
+ datasets
18
+ whoosh
schema_item_filter.py ADDED
@@ -0,0 +1,343 @@
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+
5
+ from tqdm import tqdm
6
+ from transformers import AutoTokenizer
7
+ from utils.classifier_model import SchemaItemClassifier
8
+ from transformers.trainer_utils import set_seed
9
+
10
+ def prepare_inputs_and_labels(sample, tokenizer):
11
+ table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
12
+ column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
13
+ column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]]
14
+
15
+ # `column_name_word_indices` and `table_name_word_indices` record the word indices of each column and table in `input_words`; each element is an integer
16
+ column_name_word_indices, table_name_word_indices = [], []
17
+
18
+ input_words = [sample["text"]]
19
+ for table_id, table_name in enumerate(table_names):
20
+ input_words.append("|")
21
+ input_words.append(table_name)
22
+ table_name_word_indices.append(len(input_words) - 1)
23
+ input_words.append(":")
24
+
25
+ for column_name in column_names[table_id]:
26
+ input_words.append(column_name)
27
+ column_name_word_indices.append(len(input_words) - 1)
28
+ input_words.append(",")
29
+
30
+ # remove the last ","
31
+ input_words = input_words[:-1]
32
+
33
+ tokenized_inputs = tokenizer(
34
+ input_words,
35
+ return_tensors="pt",
36
+ is_split_into_words = True,
37
+ padding = "max_length",
38
+ max_length = 512,
39
+ truncation = True
40
+ )
41
+
42
+ # after tokenization, one table name or column name may be split into multiple tokens (i.e., sub-words)
43
+ # `column_name_token_indices` and `table_name_token_indices` record the token indices of each column and table in `input_ids`; each element is a list of integers
44
+ column_name_token_indices, table_name_token_indices = [], []
45
+ word_indices = tokenized_inputs.word_ids(batch_index = 0)
46
+
47
+ # obtain token indices of each column in `input_ids`
48
+ for column_name_word_index in column_name_word_indices:
49
+ column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index])
50
+
51
+ # obtain token indices of each table in `input_ids`
52
+ for table_name_word_index in table_name_word_indices:
53
+ table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index])
54
+
55
+ encoder_input_ids = tokenized_inputs["input_ids"]
56
+ encoder_input_attention_mask = tokenized_inputs["attention_mask"]
57
+
58
+ # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True)))
59
+
60
+ if torch.cuda.is_available():
61
+ encoder_input_ids = encoder_input_ids.cuda()
62
+ encoder_input_attention_mask = encoder_input_attention_mask.cuda()
63
+
64
+ return encoder_input_ids, encoder_input_attention_mask, \
65
+ column_name_token_indices, table_name_token_indices, column_num_in_each_table
66
+
67
+ def get_schema(tables_and_columns):
68
+ schema_items = []
69
+ table_names = list(dict.fromkeys([t for t, c in tables_and_columns]))
70
+ for table_name in table_names:
71
+ schema_items.append(
72
+ {
73
+ "table_name": table_name,
74
+ "column_names": [c for t, c in tables_and_columns if t == table_name]
75
+ }
76
+ )
77
+
78
+ return {"schema_items": schema_items}
79
+
80
+ def get_sequence_length(text, tables_and_columns, tokenizer):
81
+ table_names = [t for t, c in tables_and_columns]
82
+ # deduplicate `table_names` while preserving order
83
+ table_names = list(dict.fromkeys(table_names))
84
+
85
+ column_names = []
86
+ for table_name in table_names:
87
+ column_names.append([c for t, c in tables_and_columns if t == table_name])
88
+
89
+ input_words = [text]
90
+ for table_id, table_name in enumerate(table_names):
91
+ input_words.append("|")
92
+ input_words.append(table_name)
93
+ input_words.append(":")
94
+ for column_name in column_names[table_id]:
95
+ input_words.append(column_name)
96
+ input_words.append(",")
97
+ # remove the last ","
98
+ input_words = input_words[:-1]
99
+
100
+ tokenized_inputs = tokenizer(input_words, is_split_into_words = True)
101
+
102
+ return len(tokenized_inputs["input_ids"])
103
+
104
+ # handle extremely long schema sequences
105
+ def split_sample(sample, tokenizer):
106
+ text = sample["text"]
107
+
108
+ table_names = []
109
+ column_names = []
110
+ for table in sample["schema"]["schema_items"]:
111
+ table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
112
+ if table["table_comment"] != "" else table["table_name"])
113
+ column_names.append([column_name + " ( " + column_comment + " ) " \
114
+ if column_comment != "" else column_name \
115
+ for column_name, column_comment in zip(table["column_names"], table["column_comments"])])
116
+
117
+ splitted_samples = []
118
+ recorded_tables_and_columns = []
119
+
120
+ for table_idx, table_name in enumerate(table_names):
121
+ for column_name in column_names[table_idx]:
122
+ if get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500:
123
+ recorded_tables_and_columns.append([table_name, column_name])
124
+ else:
125
+ splitted_samples.append(
126
+ {
127
+ "text": text,
128
+ "schema": get_schema(recorded_tables_and_columns)
129
+ }
130
+ )
131
+ recorded_tables_and_columns = [[table_name, column_name]]
132
+
133
+ splitted_samples.append(
134
+ {
135
+ "text": text,
136
+ "schema": get_schema(recorded_tables_and_columns)
137
+ }
138
+ )
139
+
140
+ return splitted_samples
141
+
142
+ def merge_pred_results(sample, pred_results):
143
+ # table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
144
+ # column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
145
+ table_names = []
146
+ column_names = []
147
+ for table in sample["schema"]["schema_items"]:
148
+ table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
149
+ if table["table_comment"] != "" else table["table_name"])
150
+ column_names.append([column_name + " ( " + column_comment + " ) " \
151
+ if column_comment != "" else column_name \
152
+ for column_name, column_comment in zip(table["column_names"], table["column_comments"])])
153
+
154
+ merged_results = []
155
+ for table_id, table_name in enumerate(table_names):
156
+ table_prob = 0
157
+ column_probs = []
158
+ for result_dict in pred_results:
159
+ if table_name in result_dict:
160
+ if table_prob < result_dict[table_name]["table_prob"]:
161
+ table_prob = result_dict[table_name]["table_prob"]
162
+ column_probs += result_dict[table_name]["column_probs"]
163
+
164
+ merged_results.append(
165
+ {
166
+ "table_name": table_name,
167
+ "table_prob": table_prob,
168
+ "column_names": column_names[table_id],
169
+ "column_probs": column_probs
170
+ }
171
+ )
172
+
173
+ return merged_results
174
+
175
+ def filter_schema(data, sic, num_top_k_tables = 5, num_top_k_columns = 5):
176
+ filtered_schema = dict()
177
+ filtered_matched_contents = dict()
178
+ filtered_schema["schema_items"] = []
179
+ filtered_schema["foreign_keys"] = []
180
+
181
+ table_names = [table["table_name"] for table in data["schema"]["schema_items"]]
182
+ table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]]
183
+ column_names = [table["column_names"] for table in data["schema"]["schema_items"]]
184
+ column_types = [table["column_types"] for table in data["schema"]["schema_items"]]
185
+ column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]]
186
+ column_contents = [table["column_contents"] for table in data["schema"]["schema_items"]]
187
+ pk_indicators = [table["pk_indicators"] for table in data["schema"]["schema_items"]]
188
+
189
+ # predict a relevance score for each table and column
190
+ pred_results = sic.predict(data)
191
+ # keep the top-k tables for the database and the top-k columns for each kept table
192
+ table_probs = [pred_result["table_prob"] for pred_result in pred_results]
193
+ table_indices = np.argsort(-np.array(table_probs), kind="stable")[:num_top_k_tables].tolist()
194
+
195
+ for table_idx in table_indices:
196
+ column_probs = pred_results[table_idx]["column_probs"]
197
+ column_indices = np.argsort(-np.array(column_probs), kind="stable")[:num_top_k_columns].tolist()
198
+
199
+ filtered_schema["schema_items"].append(
200
+ {
201
+ "table_name": table_names[table_idx],
202
+ "table_comment": table_comments[table_idx],
203
+ "column_names": [column_names[table_idx][column_idx] for column_idx in column_indices],
204
+ "column_types": [column_types[table_idx][column_idx] for column_idx in column_indices],
205
+ "column_comments": [column_comments[table_idx][column_idx] for column_idx in column_indices],
206
+ "column_contents": [column_contents[table_idx][column_idx] for column_idx in column_indices],
207
+ "pk_indicators": [pk_indicators[table_idx][column_idx] for column_idx in column_indices]
208
+ }
209
+ )
210
+
211
+ # extract matched contents of the remaining columns
212
+ for column_name in [column_names[table_idx][column_idx] for column_idx in column_indices]:
213
+ tc_name = "{}.{}".format(table_names[table_idx], column_name)
214
+ if tc_name in data["matched_contents"]:
215
+ filtered_matched_contents[tc_name] = data["matched_contents"][tc_name]
216
+
217
+ # extract foreign keys among the remaining tables
218
+ filtered_table_names = [table_names[table_idx] for table_idx in table_indices]
219
+ for foreign_key in data["schema"]["foreign_keys"]:
220
+ source_table, source_column, target_table, target_column = foreign_key
221
+ if source_table in filtered_table_names and target_table in filtered_table_names:
222
+ filtered_schema["foreign_keys"].append(foreign_key)
223
+
224
+ # replace the old schema with the filtered schema
225
+ data["schema"] = filtered_schema
226
+ # replace the old matched contents with the filtered matched contents
227
+ data["matched_contents"] = filtered_matched_contents
228
+
229
+ return data
230
+
231
+ def lista_contains_listb(lista, listb):
232
+ for b in listb:
233
+ if b not in lista:
234
+ return 0
235
+
236
+ return 1
237
+
238
+ class SchemaItemClassifierInference():
239
+ def __init__(self, model_save_path):
240
+ set_seed(42)
241
+ # load tokenizer
242
+ self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
243
+ # initialize model
244
+ self.model = SchemaItemClassifier(model_save_path, "test")
245
+ # load fine-tuned params
246
+ self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
247
+ if torch.cuda.is_available():
248
+ self.model = self.model.cuda()
249
+ self.model.eval()
250
+
251
+ def predict_one(self, sample):
252
+ encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\
253
+ table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer)
254
+
255
+ with torch.no_grad():
256
+ model_outputs = self.model(
257
+ encoder_input_ids,
258
+ encoder_input_attention_mask,
259
+ [column_name_token_indices],
260
+ [table_name_token_indices],
261
+ [column_num_in_each_table]
262
+ )
263
+
264
+ table_logits = model_outputs["batch_table_name_cls_logits"][0]
265
+ table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist()
266
+
267
+ column_logits = model_outputs["batch_column_info_cls_logits"][0]
268
+ column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist()
269
+
270
+ splitted_column_pred_probs = []
271
+ # split predicted column probs into each table
272
+ for table_id, column_num in enumerate(column_num_in_each_table):
273
+ splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num])
274
+ column_pred_probs = splitted_column_pred_probs
275
+
276
+ result_dict = dict()
277
+ for table_idx, table in enumerate(sample["schema"]["schema_items"]):
278
+ result_dict[table["table_name"]] = {
279
+ "table_name": table["table_name"],
280
+ "table_prob": table_pred_probs[table_idx],
281
+ "column_names": table["column_names"],
282
+ "column_probs": column_pred_probs[table_idx],
283
+ }
284
+
285
+ return result_dict
286
+
287
+ def predict(self, test_sample):
288
+ splitted_samples = split_sample(test_sample, self.tokenizer)
289
+ pred_results = []
290
+ for splitted_sample in splitted_samples:
291
+ pred_results.append(self.predict_one(splitted_sample))
292
+
293
+ return merge_pred_results(test_sample, pred_results)
294
+
295
+ def evaluate_coverage(self, dataset):
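+ # for each k in 1..max_k, count how many samples have all ground-truth tables (and, per table, all ground-truth columns) ranked within the top-k predictions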
296
+ max_k = 100
297
+ total_num_for_table_coverage, total_num_for_column_coverage = 0, 0
298
+ table_coverage_results = [0]*max_k
299
+ column_coverage_results = [0]*max_k
300
+
301
+ for data in dataset:
302
+ indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1]
303
+ pred_results = self.predict(data)
304
+ # print(pred_results)
305
+ table_probs = [res["table_prob"] for res in pred_results]
306
+ for k in range(max_k):
307
+ indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist()
308
+ if lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables):
309
+ table_coverage_results[k] += 1
310
+ total_num_for_table_coverage += 1
311
+
312
+ for table_idx in range(len(data["table_labels"])):
313
+ indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1]
314
+ if len(indices_of_used_columns) == 0:
315
+ continue
316
+ column_probs = pred_results[table_idx]["column_probs"]
317
+ for k in range(max_k):
318
+ indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist()
319
+ if lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns):
320
+ column_coverage_results[k] += 1
321
+
322
+ total_num_for_column_coverage += 1
323
+
324
+ indices_of_top_10_columns = np.argsort(-np.array(column_probs), kind="stable")[:10].tolist()
325
+ if lista_contains_listb(indices_of_top_10_columns, indices_of_used_columns) == 0:
326
+ print(pred_results[table_idx])
327
+ print(data["column_labels"][table_idx])
328
+ print(data["question"])
329
+
330
+ print(total_num_for_table_coverage)
331
+ print(table_coverage_results)
332
+ print(total_num_for_column_coverage)
333
+ print(column_coverage_results)
334
+
335
+ if __name__ == "__main__":
336
+ dataset_name = "bird_with_evidence"
337
+ # dataset_name = "bird"
338
+ # dataset_name = "spider"
339
+ sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name))
340
+ import json
341
+ dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name)))
342
+
343
+ sic.evaluate_coverage(dataset)
text2sql.py ADDED
@@ -0,0 +1,182 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import copy
5
+ import re
6
+ import sqlparse
7
+ import sqlite3
8
+
9
+ from tqdm import tqdm
10
+ from utils.db_utils import get_db_schema
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from whoosh.index import create_in
13
+ from whoosh.fields import Schema, TEXT
14
+ from whoosh.qparser import QueryParser
15
+ from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
16
+ from schema_item_filter import SchemaItemClassifierInference, filter_schema
17
+
18
+ def remove_similar_comments(names, comments):
19
+ '''
20
+ Remove table (or column) comments that have a high degree of similarity with their names
21
+ '''
22
+ new_comments = []
23
+ for name, comment in zip(names, comments):
24
+ if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
25
+ new_comments.append("")
26
+ else:
27
+ new_comments.append(comment)
28
+
29
+ return new_comments
30
+
31
+ def load_db_comments(table_json_path):
32
+ additional_db_info = json.load(open(table_json_path))
33
+ db_comments = dict()
34
+ for db_info in additional_db_info:
35
+ comment_dict = dict()
36
+
37
+ column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
38
+ table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
39
+ column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]
40
+
41
+ assert len(column_names) == len(column_comments)
42
+ column_comments = remove_similar_comments(column_names, column_comments)
43
+
44
+ table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
45
+ table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]
46
+
47
+ assert len(table_names) == len(table_comments)
48
+ table_comments = remove_similar_comments(table_names, table_comments)
49
+
50
+ for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
51
+ comment_dict[table_name] = {
52
+ "table_comment": table_comment,
53
+ "column_comments": dict()
54
+ }
55
+ for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
56
+ if t_idx == table_idx:
57
+ comment_dict[table_name]["column_comments"][column_name] = column_comment
58
+
59
+ db_comments[db_info["db_id"]] = comment_dict
60
+
61
+ return db_comments
62
+
63
+ def get_db_id2schema(db_path, tables_json):
64
+ db_comments = load_db_comments(tables_json)
65
+ db_id2schema = dict()
66
+
67
+ for db_id in tqdm(os.listdir(db_path)):
68
+ db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)
69
+
70
+ return db_id2schema
71
+
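+ # reconstruct a normalized CREATE TABLE statement (DDL) for every database under `db_path`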
72
+ def get_db_id2ddl(db_path):
73
+ db_ids = os.listdir(db_path)
74
+ db_id2ddl = dict()
75
+
76
+ for db_id in db_ids:
77
+ conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
78
+ cursor = conn.cursor()
79
+ cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
80
+ tables = cursor.fetchall()
81
+ ddl = []
82
+
83
+ for table in tables:
84
+ table_name = table[0]
85
+ table_ddl = table[1]
86
+ table_ddl = table_ddl.replace("\t", " ")
87
+ while " " in table_ddl:
88
+ table_ddl = table_ddl.replace(" ", " ")
89
+
90
+ table_ddl = re.sub(r'--.*', '', table_ddl)
91
+ table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True)
92
+ table_ddl = table_ddl.replace(", ", ",\n ")
93
+
94
+ if table_ddl.endswith(";"):
95
+ table_ddl = table_ddl[:-1]
96
+ table_ddl = table_ddl[:-1] + "\n);"
97
+ table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl)
98
+
99
+ ddl.append(table_ddl)
100
+ db_id2ddl[db_id] = "\n\n".join(ddl)
101
+
102
+ return db_id2ddl
103
+
104
+ class ChatBot():
105
+ def __init__(self) -> None:
106
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
107
+ model_name = "seeklhy/codes-7b-merged"
108
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
109
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
110
+ self.max_length = 4096
111
+ self.max_new_tokens = 256
112
+ self.max_prefix_length = self.max_length - self.max_new_tokens
113
+
114
+ self.sic = SchemaItemClassifierInference("sic_ckpts/sic_bird")
115
+
116
+ self.db_id2content_searcher = dict()
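+ # build a Whoosh index over the pre-extracted column contents of each database (loaded from db_contents_index/<db_id>/<db_id>.json)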
117
+ for db_id in os.listdir("db_contents_index"):
118
+ schema = Schema(content=TEXT(stored=True))
119
+ index_dir = os.path.join("db_contents_index", db_id)
120
+ if not os.path.exists(index_dir):
121
+ os.makedirs(index_dir)
122
+ ix = create_in(index_dir, schema)
123
+ writer = ix.writer()
124
+ with open(os.path.join(index_dir, f"{db_id}.json"), "r") as file:
125
+ data = json.load(file)
126
+ for item in data:
127
+ writer.add_document(content=item['content'])
128
+ writer.commit()
129
+ self.db_id2content_searcher[db_id] = ix
130
+
131
+ self.db_ids = sorted(os.listdir("databases"))
132
+ self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
133
+ self.db_id2ddl = get_db_id2ddl("databases")
134
+
135
+ def get_response(self, question, db_id):
136
+ data = {
137
+ "text": question,
138
+ "schema": copy.deepcopy(self.db_id2schema[db_id]),
139
+ "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
140
+ }
141
+ data = filter_schema(data, self.sic, 6, 10)
142
+ data["schema_sequence"] = get_db_schema_sequence(data["schema"])
143
+ data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
144
+
145
+ prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
146
+ print(prefix_seq)
147
+
148
+ input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq , truncation = False)["input_ids"]
149
+ if len(input_ids) > self.max_prefix_length:
150
+ print("The current input sequence exceeds the maximum prefix length; truncating it.")
151
+ input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
152
+ attention_mask = [1] * len(input_ids)
153
+
154
+ inputs = {
155
+ "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
156
+ "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
157
+ }
158
+ input_length = inputs["input_ids"].shape[1]
159
+
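+ # generate several candidate SQL queries with beam search; the first one that executes without error is returned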
160
+ with torch.no_grad():
161
+ generate_ids = self.model.generate(
162
+ **inputs,
163
+ max_new_tokens = self.max_new_tokens,
164
+ num_beams = 4,
165
+ num_return_sequences = 4
166
+ )
167
+
168
+ generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
169
+ final_generated_sql = None
170
+ for generated_sql in generated_sqls:
171
+ execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
172
+ if execution_error is None:
173
+ final_generated_sql = generated_sql
174
+ break
175
+
176
+ if final_generated_sql is None:
177
+ if generated_sqls[0].strip() != "":
178
+ final_generated_sql = generated_sqls[0].strip()
179
+ else:
180
+ final_generated_sql = "Sorry, I cannot generate a suitable SQL query for your question."
181
+
182
+ return final_generated_sql.replace("\n", " ")
utils/bridge_content_encoder.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ Copyright (c) 2020, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+
7
+ Encode DB content.
8
+ """
9
+
10
+ import difflib
11
+ from typing import List, Optional, Tuple
12
+ from rapidfuzz import fuzz
13
+ import sqlite3
14
+ import functools
15
+
16
+ # fmt: off
17
+ _stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again',
18
+ 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but',
19
+ "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is',
20
+ "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being',
21
+ 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be',
22
+ 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor',
23
+ 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because',
24
+ 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have',
25
+ 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours',
26
+ 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his',
27
+ "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what',
28
+ 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan',
29
+ 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here',
30
+ 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself',
31
+ 'then', 'did', 'just', "aren't"}
32
+ # fmt: on
33
+
34
+ _commonwords = {"no", "yes", "many"}
35
+
36
+
37
+ def is_number(s: str) -> bool:
38
+ try:
39
+ float(s.replace(",", ""))
40
+ return True
41
+ except:
42
+ return False
43
+
44
+
45
+ def is_stopword(s: str) -> bool:
46
+ return s.strip() in _stopwords
47
+
48
+
49
+ def is_commonword(s: str) -> bool:
50
+ return s.strip() in _commonwords
51
+
52
+
53
+ def is_common_db_term(s: str) -> bool:
54
+ return s.strip() in ["id"]
55
+
56
+
57
+ class Match(object):
58
+ def __init__(self, start: int, size: int) -> None:
59
+ self.start = start
60
+ self.size = size
61
+
62
+
63
+ def is_span_separator(c: str) -> bool:
64
+ return c in "'\"()`,.?! "
65
+
66
+
67
+ def split(s: str) -> List[str]:
68
+ return [c.lower() for c in s.strip()]
69
+
70
+
71
+ def prefix_match(s1: str, s2: str) -> bool:
72
+ i, j = 0, 0
73
+ for i in range(len(s1)):
74
+ if not is_span_separator(s1[i]):
75
+ break
76
+ for j in range(len(s2)):
77
+ if not is_span_separator(s2[j]):
78
+ break
79
+ if i < len(s1) and j < len(s2):
80
+ return s1[i] == s2[j]
81
+ elif i >= len(s1) and j >= len(s2):
82
+ return True
83
+ else:
84
+ return False
85
+
86
+
87
+ def get_effective_match_source(s: str, start: int, end: int) -> Match:
88
+ _start = -1
89
+
90
+ for i in range(start, start - 2, -1):
91
+ if i < 0:
92
+ _start = i + 1
93
+ break
94
+ if is_span_separator(s[i]):
95
+ _start = i
96
+ break
97
+
98
+ if _start < 0:
99
+ return None
100
+
101
+ _end = -1
102
+ for i in range(end - 1, end + 3):
103
+ if i >= len(s):
104
+ _end = i - 1
105
+ break
106
+ if is_span_separator(s[i]):
107
+ _end = i
108
+ break
109
+
110
+ if _end < 0:
111
+ return None
112
+
113
+ while _start < len(s) and is_span_separator(s[_start]):
114
+ _start += 1
115
+ while _end >= 0 and is_span_separator(s[_end]):
116
+ _end -= 1
117
+
118
+ return Match(_start, _end - _start + 1)
119
+
120
+
121
+ def get_matched_entries(
122
+ s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85
123
+ ) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]:
124
+ if not field_values:
125
+ return None
126
+
127
+ if isinstance(s, str):
128
+ n_grams = split(s)
129
+ else:
130
+ n_grams = s
131
+
132
+ matched = dict()
133
+ for field_value in field_values:
134
+ if not isinstance(field_value, str):
135
+ continue
136
+ fv_tokens = split(field_value)
137
+ sm = difflib.SequenceMatcher(None, n_grams, fv_tokens)
138
+ match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens))
139
+ if match.size > 0:
140
+ source_match = get_effective_match_source(
141
+ n_grams, match.a, match.a + match.size
142
+ )
143
+ if source_match: # and source_match.size > 1
144
+ match_str = field_value[match.b : match.b + match.size]
145
+ source_match_str = s[
146
+ source_match.start : source_match.start + source_match.size
147
+ ]
148
+ c_match_str = match_str.lower().strip()
149
+ c_source_match_str = source_match_str.lower().strip()
150
+ c_field_value = field_value.lower().strip()
151
+ if c_match_str and not is_common_db_term(c_match_str): # and not is_number(c_match_str)
152
+ if (
153
+ is_stopword(c_match_str)
154
+ or is_stopword(c_source_match_str)
155
+ or is_stopword(c_field_value)
156
+ ):
157
+ continue
158
+ if c_source_match_str.endswith(c_match_str + "'s"):
159
+ match_score = 1.0
160
+ else:
161
+ if prefix_match(c_field_value, c_source_match_str):
162
+ match_score = fuzz.ratio(c_field_value, c_source_match_str) / 100
163
+ else:
164
+ match_score = 0
165
+ if (
166
+ is_commonword(c_match_str)
167
+ or is_commonword(c_source_match_str)
168
+ or is_commonword(c_field_value)
169
+ ) and match_score < 1:
170
+ continue
171
+ s_match_score = match_score
172
+ if match_score >= m_theta and s_match_score >= s_theta:
173
+ if field_value.isupper() and match_score * s_match_score < 1:
174
+ continue
175
+ matched[match_str] = (
176
+ field_value,
177
+ source_match_str,
178
+ match_score,
179
+ s_match_score,
180
+ match.size,
181
+ )
182
+
183
+ if not matched:
184
+ return None
185
+ else:
186
+ return sorted(
187
+ matched.items(),
188
+ key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]),
189
+ reverse=True,
190
+ )
191
+
192
+
193
+ @functools.lru_cache(maxsize=1000, typed=False)
194
+ def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list:
195
+ fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name)
196
+ try:
197
+ conn = sqlite3.connect(db_path)
198
+ conn.text_factory = bytes
199
+ c = conn.cursor()
200
+ c.execute(fetch_sql)
201
+ picklist = set()
202
+ for x in c.fetchall():
203
+ if isinstance(x[0], str):
204
+ picklist.add(x[0].encode("utf-8"))
205
+ elif isinstance(x[0], bytes):
206
+ try:
207
+ picklist.add(x[0].decode("utf-8"))
208
+ except UnicodeDecodeError:
209
+ picklist.add(x[0].decode("latin-1"))
210
+ else:
211
+ picklist.add(x[0])
212
+ picklist = list(picklist)
213
+ except Exception as e:
214
+ picklist = []
215
+ finally:
216
+ conn.close()
217
+ return picklist
218
+
219
+
220
+ def get_database_matches(
221
+ question: str,
222
+ table_name: str,
223
+ column_name: str,
224
+ db_path: str,
225
+ top_k_matches: int = 2,
226
+ match_threshold: float = 0.85,
227
+ ) -> List[str]:
228
+ picklist = get_column_picklist(
229
+ table_name=table_name, column_name=column_name, db_path=db_path
230
+ )
231
+ # only maintain data in ``str'' type
232
+ picklist = [ele.strip() for ele in picklist if isinstance(ele, str)]
233
+ # picklist is unordered, we sort it to ensure the reproduction stability
234
+ picklist = sorted(picklist)
235
+
236
+ matches = []
237
+ if picklist and isinstance(picklist[0], str):
238
+ matched_entries = get_matched_entries(
239
+ s=question,
240
+ field_values=picklist,
241
+ m_theta=match_threshold,
242
+ s_theta=match_threshold,
243
+ )
244
+
245
+ if matched_entries:
246
+ num_values_inserted = 0
247
+ for _match_str, (
248
+ field_value,
249
+ _s_match_str,
250
+ match_score,
251
+ s_match_score,
252
+ _match_size,
253
+ ) in matched_entries:
254
+ if "name" in column_name and match_score * s_match_score < 1:
255
+ continue
256
+ if table_name != "sqlite_sequence": # Spider database artifact
257
+ matches.append(field_value.strip())
258
+ num_values_inserted += 1
259
+ if num_values_inserted >= top_k_matches:
260
+ break
261
+ return matches
utils/classifier_model.py ADDED
@@ -0,0 +1,186 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import AutoConfig, RobertaModel
5
+
6
+ class SchemaItemClassifier(nn.Module):
7
+ def __init__(self, model_name_or_path, mode):
8
+ super(SchemaItemClassifier, self).__init__()
9
+ if mode in ["eval", "test"]:
10
+ # load config
11
+ config = AutoConfig.from_pretrained(model_name_or_path)
12
+ # randomly initialize model's parameters according to the config
13
+ self.plm_encoder = RobertaModel(config)
14
+ elif mode == "train":
15
+ self.plm_encoder = RobertaModel.from_pretrained(model_name_or_path)
16
+ else:
17
+ raise ValueError()
18
+
19
+ self.plm_hidden_size = self.plm_encoder.config.hidden_size
20
+
21
+ # column cls head
22
+ self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
23
+ self.column_info_cls_head_linear2 = nn.Linear(256, 2)
24
+
25
+ # column bi-lstm layer
26
+ self.column_info_bilstm = nn.LSTM(
27
+ input_size = self.plm_hidden_size,
28
+ hidden_size = int(self.plm_hidden_size/2),
29
+ num_layers = 2,
30
+ dropout = 0,
31
+ bidirectional = True
32
+ )
33
+
34
+ # linear layer after column bi-lstm layer
35
+ self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
36
+
37
+ # table cls head
38
+ self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
39
+ self.table_name_cls_head_linear2 = nn.Linear(256, 2)
40
+
41
+ # table bi-lstm pooling layer
42
+ self.table_name_bilstm = nn.LSTM(
43
+ input_size = self.plm_hidden_size,
44
+ hidden_size = int(self.plm_hidden_size/2),
45
+ num_layers = 2,
46
+ dropout = 0,
47
+ bidirectional = True
48
+ )
49
+ # linear layer after table bi-lstm layer
50
+ self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
51
+
52
+ # activation function
53
+ self.leakyrelu = nn.LeakyReLU()
54
+ self.tanh = nn.Tanh()
55
+
56
+ # table-column cross-attention layer
57
+ self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)
58
+
59
+ # dropout function, p=0.2 means randomly set 20% neurons to 0
60
+ self.dropout = nn.Dropout(p = 0.2)
61
+
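+ # fuse column information into each table embedding: every table attends to its own columns via multi-head attention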
62
+ def table_column_cross_attention(
63
+ self,
64
+ table_name_embeddings_in_one_db,
65
+ column_info_embeddings_in_one_db,
66
+ column_number_in_each_table
67
+ ):
68
+ table_num = table_name_embeddings_in_one_db.shape[0]
69
+ table_name_embedding_attn_list = []
70
+ for table_id in range(table_num):
71
+ table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
72
+ column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
73
+ sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]
74
+
75
+ table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
76
+ table_name_embedding,
77
+ column_info_embeddings_in_one_table,
78
+ column_info_embeddings_in_one_table
79
+ )
80
+
81
+ table_name_embedding_attn_list.append(table_name_embedding_attn)
82
+
83
+ # residual connection
84
+ table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
85
+ # row-wise L2 norm
86
+ table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1)
87
+
88
+ return table_name_embeddings_in_one_db
89
+
90
+ def table_column_cls(
91
+ self,
92
+ encoder_input_ids,
93
+ encoder_input_attention_mask,
94
+ batch_aligned_column_info_ids,
95
+ batch_aligned_table_name_ids,
96
+ batch_column_number_in_each_table
97
+ ):
98
+ batch_size = encoder_input_ids.shape[0]
99
+
100
+ encoder_output = self.plm_encoder(
101
+ input_ids = encoder_input_ids,
102
+ attention_mask = encoder_input_attention_mask,
103
+ return_dict = True
104
+ ) # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)
105
+
106
+ batch_table_name_cls_logits, batch_column_info_cls_logits = [], []
107
+
108
+ # handle each data in current batch
109
+ for batch_id in range(batch_size):
110
+ column_number_in_each_table = batch_column_number_in_each_table[batch_id]
111
+ sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] # (seq_length x hidden_size)
112
+
113
+ # obtain table ids for each table
114
+ aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
115
+ # obtain column ids for each column
116
+ aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]
117
+
118
+ table_name_embedding_list, column_info_embedding_list = [], []
119
+
120
+ # obtain table embedding via bi-lstm pooling + a non-linear layer
121
+ for table_name_ids in aligned_table_name_ids:
122
+ table_name_embeddings = sequence_embeddings[table_name_ids, :]
123
+
124
+ # BiLSTM pooling
125
+ output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
126
+ table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
127
+ table_name_embedding_list.append(table_name_embedding)
128
+ table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
129
+ # non-linear mlp layer
130
+ table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))
131
+
132
+ # obtain column embedding via bi-lstm pooling + a non-linear layer
133
+ for column_info_ids in aligned_column_info_ids:
134
+ column_info_embeddings = sequence_embeddings[column_info_ids, :]
135
+
136
+ # BiLSTM pooling
137
+ output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
138
+ column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
139
+ column_info_embedding_list.append(column_info_embedding)
140
+ column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
141
+ # non-linear mlp layer
142
+ column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))
143
+
144
+ # table-column (tc) cross-attention
145
+ table_name_embeddings_in_one_db = self.table_column_cross_attention(
146
+ table_name_embeddings_in_one_db,
147
+ column_info_embeddings_in_one_db,
148
+ column_number_in_each_table
149
+ )
150
+
151
+ # calculate table 0-1 logits
152
+ table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
153
+ table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
154
+ table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)
155
+
156
+ # calculate column 0-1 logits
157
+ column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
158
+ column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
159
+ column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)
160
+
161
+ batch_table_name_cls_logits.append(table_name_cls_logits)
162
+ batch_column_info_cls_logits.append(column_info_cls_logits)
163
+
164
+ return batch_table_name_cls_logits, batch_column_info_cls_logits
165
+
166
+ def forward(
167
+ self,
168
+ encoder_input_ids,
169
+ encoder_attention_mask,
170
+ batch_aligned_column_info_ids,
171
+ batch_aligned_table_name_ids,
172
+ batch_column_number_in_each_table,
173
+ ):
174
+ batch_table_name_cls_logits, batch_column_info_cls_logits \
175
+ = self.table_column_cls(
176
+ encoder_input_ids,
177
+ encoder_attention_mask,
178
+ batch_aligned_column_info_ids,
179
+ batch_aligned_table_name_ids,
180
+ batch_column_number_in_each_table
181
+ )
182
+
183
+ return {
184
+ "batch_table_name_cls_logits" : batch_table_name_cls_logits,
185
+ "batch_column_info_cls_logits": batch_column_info_cls_logits
186
+ }
utils/db_utils.py ADDED
@@ -0,0 +1,255 @@
1
+ import os
2
+ import json
3
+ import sqlite3
4
+
5
+ from func_timeout import func_set_timeout, FunctionTimedOut
6
+ from utils.bridge_content_encoder import get_matched_entries
7
+ from nltk.tokenize import word_tokenize
8
+ from nltk import ngrams
9
+
10
+ def add_a_record(question, db_id):
11
+ conn = sqlite3.connect('data/history/history.sqlite')
12
+ cursor = conn.cursor()
13
+ cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id))
14
+
15
+ conn.commit()
16
+ conn.close()
17
+
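+ # collect all word-level n-grams (n = 1..max_n) of a sequence; used for coarse-grained matching against database contents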
18
+ def obtain_n_grams(sequence, max_n):
19
+ tokens = word_tokenize(sequence)
20
+ all_grams = []
21
+ for n in range(1, max_n + 1):
22
+ all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])
23
+
24
+ return all_grams
25
+
26
+ # get the database cursor for a sqlite database path
27
+ def get_cursor_from_path(sqlite_path):
28
+ try:
29
+ if not os.path.exists(sqlite_path):
30
+ print("Opening a new connection %s" % sqlite_path)
31
+ connection = sqlite3.connect(sqlite_path, check_same_thread = False)
32
+ except Exception as e:
33
+ print(sqlite_path)
34
+ raise e
35
+ connection.text_factory = lambda b: b.decode(errors="ignore")
36
+ cursor = connection.cursor()
37
+ return cursor
38
+
39
+ # execute predicted sql with a time limitation
40
+ @func_set_timeout(15)
41
+ def execute_sql(cursor, sql):
42
+ cursor.execute(sql)
43
+
44
+ return cursor.fetchall()
45
+
46
+ # execute predicted sql with a long time limitation (for building the content index)
47
+ @func_set_timeout(2000)
48
+ def execute_sql_long_time_limitation(cursor, sql):
49
+ cursor.execute(sql)
50
+
51
+ return cursor.fetchall()
52
+
53
+ def check_sql_executability(generated_sql, db):
54
+ if generated_sql.strip() == "":
55
+ return "Error: empty string"
56
+ try:
57
+ cursor = get_cursor_from_path(db)
58
+ execute_sql(cursor, generated_sql)
59
+ execution_error = None
60
+ except FunctionTimedOut as fto:
61
+ print("SQL execution time out error: {}.".format(fto))
62
+ execution_error = "SQL execution times out."
63
+ except Exception as e:
64
+ print("SQL execution runtime error: {}.".format(e))
65
+ execution_error = str(e)
66
+
67
+ return execution_error
68
+
69
+ def is_number(s):
70
+ try:
71
+ float(s)
72
+ return True
73
+ except ValueError:
74
+ return False
75
+
76
+ def detect_special_char(name):
77
+ for special_char in ['(', '-', ')', ' ', '/']:
78
+ if special_char in name:
79
+ return True
80
+
81
+ return False
82
+
83
+ def add_quotation_mark(s):
84
+ return "`" + s + "`"
85
+
86
+ def get_column_contents(column_name, table_name, cursor):
87
+ select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(column_name, table_name, column_name)
88
+ results = execute_sql_long_time_limitation(cursor, select_column_sql)
89
+ column_contents = [str(result[0]).strip() for result in results]
90
+ # remove empty and extremely-long contents
91
+ column_contents = [content for content in column_contents if len(content) != 0 and len(content) <= 25]
92
+
93
+ return column_contents
94
+
95
+ def get_matched_contents(question, searcher):
96
+ # coarse-grained matching between the input text and all contents in database
97
+ grams = obtain_n_grams(question, 4)
98
+ hits = []
99
+ for query in grams:
100
+ hits.extend(searcher.search(query, k = 10))
101
+
102
+ coarse_matched_contents = dict()
103
+ for i in range(len(hits)):
104
+ matched_result = json.loads(hits[i].raw)
105
+ # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
106
+ tc_name = ".".join(matched_result["id"].split("-**-")[:2])
107
+ if tc_name in coarse_matched_contents.keys():
108
+ if matched_result["contents"] not in coarse_matched_contents[tc_name]:
109
+ coarse_matched_contents[tc_name].append(matched_result["contents"])
110
+ else:
111
+ coarse_matched_contents[tc_name] = [matched_result["contents"]]
112
+
113
+ fine_matched_contents = dict()
114
+ for tc_name, contents in coarse_matched_contents.items():
115
+ # fine-grained matching between the question and coarse matched contents
116
+ fm_contents = get_matched_entries(question, contents)
117
+
118
+ if fm_contents is None:
119
+ continue
120
+ for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents:
121
+ if match_score < 0.9:
122
+ continue
123
+ if tc_name in fine_matched_contents.keys():
124
+ if len(fine_matched_contents[tc_name]) < 25:
125
+ fine_matched_contents[tc_name].append(field_value.strip())
126
+ else:
127
+ fine_matched_contents[tc_name] = [field_value.strip()]
128
+
129
+ return fine_matched_contents
130
+
131
+ def get_db_schema_sequence(schema):
132
+ schema_sequence = "database schema :\n"
133
+ for table in schema["schema_items"]:
134
+ table_name, table_comment = table["table_name"], table["table_comment"]
135
+ if detect_special_char(table_name):
136
+ table_name = add_quotation_mark(table_name)
137
+
138
+ # if table_comment != "":
139
+ # table_name += " ( comment : " + table_comment + " )"
140
+
141
+ column_info_list = []
142
+ for column_name, column_type, column_comment, column_content, pk_indicator in \
143
+ zip(table["column_names"], table["column_types"], table["column_comments"], table["column_contents"], table["pk_indicators"]):
144
+ if detect_special_char(column_name):
145
+ column_name = add_quotation_mark(column_name)
146
+ additional_column_info = []
147
+ # column type
148
+ additional_column_info.append(column_type)
149
+ # pk indicator
150
+ if pk_indicator != 0:
151
+ additional_column_info.append("primary key")
152
+ # column comment
153
+ if column_comment != "":
154
+ additional_column_info.append("comment : " + column_comment)
155
+ # representative column values
156
+ if len(column_content) != 0:
157
+ additional_column_info.append("values : " + " , ".join(column_content))
158
+
159
+ column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )")
160
+
161
+ schema_sequence += "table "+ table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n"
162
+
163
+ if len(schema["foreign_keys"]) != 0:
164
+ schema_sequence += "foreign keys :\n"
165
+ for foreign_key in schema["foreign_keys"]:
166
+ for i in range(len(foreign_key)):
167
+ if detect_special_char(foreign_key[i]):
168
+ foreign_key[i] = add_quotation_mark(foreign_key[i])
169
+ schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3])
170
+ else:
171
+ schema_sequence += "foreign keys : None\n"
172
+
173
+ return schema_sequence.strip()
174
+
175
+ def get_matched_content_sequence(matched_contents):
176
+ content_sequence = ""
177
+ if len(matched_contents) != 0:
178
+ content_sequence += "matched contents :\n"
179
+ for tc_name, contents in matched_contents.items():
180
+ table_name = tc_name.split(".")[0]
181
+ column_name = tc_name.split(".")[1]
182
+ if detect_special_char(table_name):
183
+ table_name = add_quotation_mark(table_name)
184
+ if detect_special_char(column_name):
185
+ column_name = add_quotation_mark(column_name)
186
+
187
+ content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
188
+ else:
189
+ content_sequence = "matched contents : None"
190
+
191
+ return content_sequence.strip()
192
+
193
+ def get_db_schema(db_path, db_comments, db_id):
194
+ if db_id in db_comments:
195
+ db_comment = db_comments[db_id]
196
+ else:
197
+ db_comment = None
198
+
199
+ cursor = get_cursor_from_path(db_path)
200
+
201
+ # obtain table names
202
+ results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
203
+ table_names = [result[0].lower() for result in results]
204
+
205
+ schema = dict()
206
+ schema["schema_items"] = []
207
+ foreign_keys = []
208
+ # for each table
209
+ for table_name in table_names:
210
+ # skip SQLite system table: sqlite_sequence
211
+ if table_name == "sqlite_sequence":
212
+ continue
213
+ # obtain column names in the current table
214
+ results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
215
+ column_names_in_one_table = [result[0].lower() for result in results]
216
+ column_types_in_one_table = [result[1].lower() for result in results]
217
+ pk_indicators_in_one_table = [result[2] for result in results]
218
+
219
+ column_contents = []
220
+ for column_name in column_names_in_one_table:
221
+ column_contents.append(get_column_contents(column_name, table_name, cursor))
222
+
223
+ # obtain foreign keys in the current table
224
+ results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
225
+ for result in results:
226
+ if None not in [result[3], result[2], result[4]]:
227
+ foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()])
228
+
229
+ # obtain comments for each schema item
230
+ if db_comment is not None:
231
+ if table_name in db_comment: # record comments for tables and columns
232
+ table_comment = db_comment[table_name]["table_comment"]
233
+ column_comments = [db_comment[table_name]["column_comments"][column_name] \
234
+ if column_name in db_comment[table_name]["column_comments"] else "" \
235
+ for column_name in column_names_in_one_table]
236
+ else: # current database has comment information, but the current table does not
237
+ table_comment = ""
238
+ column_comments = ["" for _ in column_names_in_one_table]
239
+ else: # current database has no comment information
240
+ table_comment = ""
241
+ column_comments = ["" for _ in column_names_in_one_table]
242
+
243
+ schema["schema_items"].append({
244
+ "table_name": table_name,
245
+ "table_comment": table_comment,
246
+ "column_names": column_names_in_one_table,
247
+ "column_types": column_types_in_one_table,
248
+ "column_comments": column_comments,
249
+ "column_contents": column_contents,
250
+ "pk_indicators": pk_indicators_in_one_table
251
+ })
252
+
253
+ schema["foreign_keys"] = foreign_keys
254
+
255
+ return schema
utils/translate_utils.py ADDED
@@ -0,0 +1,23 @@
1
+ import requests
2
+ import random
3
+ import json
4
+
5
+ def translate_zh_to_en(question, token):
6
+ url = 'https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1?access_token=' + token
7
+
8
+ from_lang = 'auto'
9
+ to_lang = 'en'
10
+ term_ids = ''
11
+
12
+ # Build request
13
+ headers = {'Content-Type': 'application/json'}
14
+ payload = {'q': question, 'from': from_lang, 'to': to_lang, 'termIds' : term_ids}
15
+
16
+ # Send request
17
+ r = requests.post(url, data=json.dumps(payload), headers=headers)
18
+ result = r.json()
19
+
20
+ return result["result"]["trans_result"][0]["dst"]
21
+
22
+ if __name__ == "__main__":
23
+ print(translate_zh_to_en("你好啊!", "YOUR_BAIDU_ACCESS_TOKEN"))  # placeholder token; replace with a real Baidu access token