import os

import streamlit as st
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    logging as hf_logging,
)
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record

# Suppress excessive warnings from Hugging Face transformers library
hf_logging.set_verbosity_error()

# SchemaItemClassifierInference class for loading the Hugging Face model
class SchemaItemClassifierInference:
    def __init__(self, model_name: str, token=None):
        """
        model_name: Hugging Face repository path, e.g., "Roxanne-WANG/LangSQL"
        token: Hugging Face access token (required if the model repo is private)
        """
        # Load the tokenizer and model from the Hugging Face Hub.
        # `token` replaces the deprecated `use_auth_token` argument;
        # `trust_remote_code=True` allows custom model code from the repo.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )
        # This class is used for inference only, so disable dropout etc.
        self.model.eval()

    def predict(self, text: str):
        # Tokenize the input text and run a forward pass without gradient
        # tracking, since no training happens here.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits
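
    @staticmethod
    def logits_to_probs(logits):
        # Illustrative helper (an addition, not part of the original class):
        # turn raw logits into class probabilities with a softmax. This
        # assumes a standard sequence-classification head; adjust if the
        # LangSQL checkpoint defines a custom output format.
        return torch.softmax(logits, dim=-1)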


# ChatBot class that interacts with SchemaItemClassifierInference
class ChatBot:
    def __init__(self):
        # Hugging Face model repository (replace with your own model's path)
        model_name = "Roxanne-WANG/LangSQL"
        hf_token = os.getenv("HF_TOKEN")  # Read the token from the environment

        if hf_token is None:
            raise ValueError(
                "Hugging Face token is required. Please set the HF_TOKEN "
                "environment variable."
            )

        # Initialize the schema item classifier with the Hugging Face token
        self.sic = SchemaItemClassifierInference(model_name, token=hf_token)

    def get_response(self, question: str, db_id: str):
        # Run the schema item classifier on the question. In a complete
        # text-to-SQL pipeline these scores would prune the schema of the
        # selected database (`db_id`) before SQL generation; for now the raw
        # logits are returned as a placeholder for the actual SQL query.
        logits = self.sic.predict(question)
        return logits


# -------- Streamlit Web Application --------
text2sql_bot = ChatBot()
baidu_api_token = None  # Your Baidu API token (if needed for translation)

# Define some database schemas for demonstration purposes
db_schemas = {
    "singer": """
    CREATE TABLE "singer" (
        "Singer_ID" int,
        "Name" text,
        "Birth_Year" real,
        "Net_Worth_Millions" real,
        "Citizenship" text,
        PRIMARY KEY ("Singer_ID")
    );

    CREATE TABLE "song" (
        "Song_ID" int,
        "Title" text,
        "Singer_ID" int,
        "Sales" real,
        "Highest_Position" real,
        PRIMARY KEY ("Song_ID"),
        FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
    );
    """,
    # More schemas can be added here
}

# Streamlit interface
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")
selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# Get user input for the question
question = st.text_input("Enter your question:")
db_id = selected_db

if question:
    # Store the question in the database (or perform any additional processing)
    add_a_record(question, db_id)

    try:
        # If translation is required, handle it here
        if baidu_api_token and detect(question) != "en":
            question = translate_zh_to_en(question, baidu_api_token)
    except LangDetectException as e:
        st.warning(f"Language detection error: {e}")

    # Display the model's output (placeholder logits rather than a SQL query)
    response = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Model logits (placeholder output):** {response}")