File size: 1,901 Bytes
b759b87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st
from text2sql import ChatBot
from langdetect import detect
from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record
from langdetect.lang_detect_exception import LangDetectException

# Initialize chatbot and other variables.
#
# Streamlit re-executes this whole script on every widget interaction, so a
# bare ``ChatBot()`` call at module level rebuilds the bot on each rerun.
# Building it through ``st.cache_resource`` keeps a single instance alive
# for the lifetime of the server process instead.
# NOTE(review): a cached resource is shared by all sessions -- confirm that
# ChatBot keeps no per-user state.
# ``st.cache_resource`` requires Streamlit >= 1.18; on older versions we fall
# back to an identity decorator, i.e. the original uncached behaviour.
_cache_resource = getattr(st, "cache_resource", lambda fn: fn)


@_cache_resource
def _load_chatbot() -> ChatBot:
    """Build the text-to-SQL chatbot (cached across reruns when supported)."""
    return ChatBot()


text2sql_bot = _load_chatbot()

# Token handed to translate_zh_to_en; while it is None the translation step
# further down the script is skipped entirely.
baidu_api_token = None

# Define database schemas for demonstration.
# Maps a database id (shown in the sidebar selector and passed to the bot as
# the DB id) to the DDL text displayed for that database.
db_schemas = {
    "singer": """
    CREATE TABLE "singer" (
        "Singer_ID" int,
        "Name" text,
        "Birth_Year" real,
        "Net_Worth_Millions" real,
        "Citizenship" text,
        PRIMARY KEY ("Singer_ID")
    );

    CREATE TABLE "song" (
        "Song_ID" int,
        "Title" text,
        "Singer_ID" int,
        "Sales" real,
        "Highest_Position" real,
        PRIMARY KEY ("Song_ID"),
        FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
    );
    """,
    # Add other schemas as needed
}

# --- Page layout -----------------------------------------------------------
st.title("Text-to-SQL Chatbot")

# Sidebar: a database picker followed by the schema text of the chosen entry.
_sidebar = st.sidebar
_sidebar.header("Select a Database")
_available_dbs = list(db_schemas)
selected_db = _sidebar.selectbox("Choose a database:", _available_dbs)
_sidebar.text_area(
    "Database Schema",
    db_schemas[selected_db],
    height=600,
)

# Main pane: free-text question box. The database picked in the sidebar
# doubles as the DB id used by the rest of the script.
question = st.text_input("Enter your question:")
db_id = selected_db

# Handle a submitted question: record it, optionally translate it to English,
# ask the chatbot for SQL, and show the result. Blank / whitespace-only input
# is ignored so nothing is recorded or sent to the bot for an empty query.
if question and question.strip():
    # Record the question exactly as entered (before any translation).
    add_a_record(question, db_id)

    # Translation is opt-in: it only runs when a token has been configured.
    if baidu_api_token is not None:
        # Keep the try body down to the language-detection call so the
        # handler below covers detection only, not the translation request.
        try:
            needs_translation = detect(question) != "en"
        except LangDetectException as e:
            # Best effort: if detection fails, pass the question on untranslated.
            print("Language detection error:", str(e))
            needs_translation = False

        if needs_translation:
            print("Before translation:", question)
            question = translate_zh_to_en(question, baidu_api_token)
            print("After translation:", question)

    predicted_sql = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write("**Predicted SQL query:**")
    # Show the SQL in a code block rather than interpolating it into Markdown:
    # characters such as `*` (e.g. in `count(*)`) would otherwise be treated
    # as Markdown markup and distort the displayed query.
    st.code(str(predicted_sql), language="sql")