File size: 4,440 Bytes
b5be522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b759b87
b5be522
b759b87
 
 
 
 
b5be522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b759b87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5be522
b759b87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# import streamlit as st
# from text2sql import ChatBot
# from langdetect import detect
# from utils.translate_utils import translate_zh_to_en
# from utils.db_utils import add_a_record
# from langdetect.lang_detect_exception import LangDetectException

# # Initialize chatbot and other variables
# text2sql_bot = ChatBot()
# baidu_api_token = None

# # Define database schemas for demonstration
# db_schemas = {
#     "singer": """
#     CREATE TABLE "singer" (
#         "Singer_ID" int,
#         "Name" text,
#         "Birth_Year" real,
#         "Net_Worth_Millions" real,
#         "Citizenship" text,
#         PRIMARY KEY ("Singer_ID")
#     );

#     CREATE TABLE "song" (
#         "Song_ID" int,
#         "Title" text,
#         "Singer_ID" int,
#         "Sales" real,
#         "Highest_Position" real,
#         PRIMARY KEY ("Song_ID"),
#         FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
#     );
#     """,
#     # Add other schemas as needed
# }

# # Streamlit UI
# st.title("Text-to-SQL Chatbot")
# st.sidebar.header("Select a Database")

# # Sidebar for selecting a database
# selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))

# # Display the selected schema
# st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# # User input section
# question = st.text_input("Enter your question:")
# db_id = selected_db  # Use selected database for DB ID

# if question:
#     add_a_record(question, db_id)

#     try:
#         if baidu_api_token is not None and detect(question) != "en":
#             print("Before translation:", question)
#             question = translate_zh_to_en(question, baidu_api_token)
#             print("After translation:", question)
#     except LangDetectException as e:
#         print("Language detection error:", str(e))

#     predicted_sql = text2sql_bot.get_response(question, db_id)
#     st.write(f"**Database:** {db_id}")
#     st.write(f"**Predicted SQL query:** {predicted_sql}")


import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langdetect import detect
from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record
from langdetect.lang_detect_exception import LangDetectException

class SchemaItemClassifierInference:
    """Inference wrapper around a Hugging Face sequence-classification model.

    The tokenizer and model are loaded once in ``__init__``; ``predict``
    tokenizes a text and returns the model's raw logits.
    """

    def __init__(self, model_name):
        """Load the tokenizer and the model identified by ``model_name``.

        Args:
            model_name: Identifier forwarded to ``from_pretrained`` for both
                the tokenizer and the model.
        """
        # NOTE(review): recent transformers releases deprecate
        # `use_auth_token` in favour of `token` -- confirm against the pinned
        # transformers version before switching.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=True)

    def predict(self, text):
        """Return the classification logits for ``text``.

        Args:
            text: Input handed to the tokenizer (padded and truncated,
                returned as PyTorch tensors).

        Returns:
            The ``logits`` attribute of the model output, computed without
            gradient tracking.
        """
        # Imported here because the module does not import torch at the top;
        # the tokenizer already requires it for return_tensors="pt".
        import torch

        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Pure inference: skip building the autograd graph so the returned
        # logits do not keep activations alive.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits

class ChatBot:
    """Front-end object the UI talks to; delegates to the classifier wrapper."""

    def __init__(self):
        """Create the underlying classifier for the hosted LangSQL checkpoint."""
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

    def get_response(self, question, db_id):
        """Return the classifier's prediction for ``question``.

        Args:
            question: Text forwarded to ``self.sic.predict``.
            db_id: Accepted for the caller's benefit; not used by this method.

        Returns:
            Whatever ``self.sic.predict(question)`` returns.
        """
        return self.sic.predict(question)

@st.cache_resource
def _load_chatbot():
    """Build the ChatBot once and let Streamlit reuse it across reruns."""
    return ChatBot()


# Streamlit re-executes this script on every widget interaction. Constructing
# ChatBot() directly here would reload the tokenizer and model each time, so
# the instance is obtained through the cached factory instead.
# NOTE(review): st.cache_resource needs Streamlit >= 1.18 -- confirm the
# deployed version.
text2sql_bot = _load_chatbot()
# Token for the Baidu translation call; while it is None the translation step
# in the UI section is skipped.
baidu_api_token = None

# Demo database catalogue: maps a database id to the SQL DDL text of its
# tables. The keys populate the sidebar database selector and the value of
# the selected entry is shown verbatim in the sidebar schema text area.
db_schemas = {
    "singer": """
    CREATE TABLE "singer" (
        "Singer_ID" int,
        "Name" text,
        "Birth_Year" real,
        "Net_Worth_Millions" real,
        "Citizenship" text,
        PRIMARY KEY ("Singer_ID")
    );

    CREATE TABLE "song" (
        "Song_ID" int,
        "Title" text,
        "Singer_ID" int,
        "Sales" real,
        "Highest_Position" real,
        PRIMARY KEY ("Song_ID"),
        FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
    );
    """,
}

# Streamlit UI
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")

# Sidebar: choose one of the known databases and show its schema text.
selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# Main panel: the selected database doubles as the database id for the query.
db_id = selected_db
question = st.text_input("Enter your question:")

if question:
    # Record the question exactly as typed, before any translation.
    add_a_record(question, db_id)

    # Translation is only attempted when a Baidu token is configured and the
    # detected language is not English.
    if baidu_api_token is not None:
        try:
            if detect(question) != "en":
                print("Before translation:", question)
                question = translate_zh_to_en(question, baidu_api_token)
                print("After translation:", question)
        except LangDetectException as err:
            # Detection failed: keep the untranslated question and carry on.
            print("Language detection error:", str(err))

    predicted_sql = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Predicted SQL query:** {predicted_sql}")