File size: 2,085 Bytes
b759b87
50405fd
b053c71
 
 
 
 
b759b87
 
 
 
b053c71
b759b87
b053c71
b759b87
b053c71
b759b87
b053c71
b759b87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b053c71
b759b87
 
b053c71
b759b87
 
 
 
 
b053c71
b759b87
b5be522
b759b87
 
b053c71
b759b87
 
 
b053c71
 
b759b87
 
b053c71
b759b87
b053c71
 
b759b87
ebacf16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import streamlit as st
from text2sql import ChatBot
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    logging as hf_logging
)
from langdetect import detect
from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record
from langdetect.lang_detect_exception import LangDetectException
import os

# -------- Streamlit Web Application --------


@st.cache_resource
def _load_chatbot() -> ChatBot:
    """Construct the Text-to-SQL ChatBot once and reuse it across reruns.

    Streamlit re-executes this script from the top on every user interaction.
    Without caching, ``ChatBot()`` would be rebuilt on every rerun. Its
    constructor is not visible here; it presumably loads model weights, which
    makes repeated construction expensive.
    """
    return ChatBot()


text2sql_bot = _load_chatbot()

# Baidu API token used for translating non-English questions. It is read from
# the BAIDU_API_TOKEN environment variable instead of being hard-coded. When
# the variable is unset or empty the token is None and translation is skipped.
baidu_api_token = os.environ.get("BAIDU_API_TOKEN") or None

# DDL for the demo "singer" database. In the code visible here this text is
# only displayed in the sidebar; it is not passed to the model.
_SINGER_SCHEMA = """
    CREATE TABLE "singer" (
        "Singer_ID" int,
        "Name" text,
        "Birth_Year" real,
        "Net_Worth_Millions" real,
        "Citizenship" text,
        PRIMARY KEY ("Singer_ID")
    );

    CREATE TABLE "song" (
        "Song_ID" int,
        "Title" text,
        "Singer_ID" int,
        "Sales" real,
        "Highest_Position" real,
        PRIMARY KEY ("Song_ID"),
        FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
    );
    """

# Demo databases offered in the sidebar.
# - Each key is the db_id handed to the chatbot and to add_a_record.
# - Each value is the schema text shown to the user.
# NOTE(review): presumably each key must match a database ChatBot knows about; verify.
db_schemas = {
    "singer": _SINGER_SCHEMA,
    # More schemas can be added here
}

# Streamlit interface
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")
selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
# Shows the selected database's schema. Edits made in this box are not used
# anywhere below.
st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# Get user input for the question. Whitespace-only input counts as empty.
question = st.text_input("Enter your question:").strip()
db_id = selected_db

if question:
    # Streamlit re-runs this script on every widget interaction, including
    # edits to the sidebar text area. Only log the question and query the
    # model when the (question, database) pair actually changes. Otherwise
    # each rerun would insert a duplicate record and repeat inference.
    request_key = (question, db_id)
    if st.session_state.get("_last_request") != request_key:
        # Store the raw (untranslated) question.
        add_a_record(question, db_id)

        model_question = question
        try:
            # Translate only when a token is configured and the text is not
            # detected as English.
            if baidu_api_token and detect(question) != "en":
                model_question = translate_zh_to_en(question, baidu_api_token)
        except LangDetectException as e:
            # e.g. input with no detectable language features; fall back to
            # sending the question untranslated.
            st.warning(f"Language detection error: {e}")

        # Get the model's response (in this case, SQL query or logits).
        st.session_state["_last_response"] = text2sql_bot.get_response(
            model_question, db_id
        )
        # Mark the request as handled only after the model call succeeds, so
        # a failure is retried on the next rerun.
        st.session_state["_last_request"] = request_key

    response = st.session_state["_last_response"]
    st.write(f"**Database:** {db_id}")
    st.write(f"**Results:** {response}")