LangSQL / app.py
Roxanne-WANG's picture
update final code
ebacf16
raw
history blame
2.09 kB
import streamlit as st
from text2sql import ChatBot
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
logging as hf_logging
)
from langdetect import detect
from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record
from langdetect.lang_detect_exception import LangDetectException
import os
# -------- Streamlit Web Application --------


@st.cache_resource
def _load_text2sql_bot():
    """Build the text-to-SQL ChatBot once and reuse it across reruns.

    Streamlit re-executes this entire script on every widget interaction, so a
    bare module-level ``ChatBot()`` would be reconstructed on each rerun.
    ``st.cache_resource`` keeps one instance that is shared across reruns and
    sessions.

    NOTE(review): the shared instance is used by all concurrent sessions;
    confirm ``ChatBot.get_response`` is safe to call concurrently.
    """
    return ChatBot()


text2sql_bot = _load_text2sql_bot()

# Baidu API token used to translate non-English questions. It is read from the
# BAIDU_API_TOKEN environment variable so the secret never lives in source.
# When the variable is unset or empty, the value is None and translation is skipped.
baidu_api_token = os.environ.get("BAIDU_API_TOKEN") or None
# DDL for the demo "singer" database: a singer table plus a song table that
# references it through Singer_ID.
_SINGER_SCHEMA = """
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Birth_Year" real,
"Net_Worth_Millions" real,
"Citizenship" text,
PRIMARY KEY ("Singer_ID")
);
CREATE TABLE "song" (
"Song_ID" int,
"Title" text,
"Singer_ID" int,
"Sales" real,
"Highest_Position" real,
PRIMARY KEY ("Song_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);
"""

# Demo databases offered in the sidebar, mapping database id -> schema DDL.
# The key is what the app later passes to the chatbot as ``db_id``.
# To support another database, add a new id/DDL pair here.
db_schemas = {
    "singer": _SINGER_SCHEMA,
}
# Streamlit interface: page title, then a sidebar where the user picks one of
# the demo databases and can read its schema.
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")
selected_db = st.sidebar.selectbox(
    label="Choose a database:",
    options=list(db_schemas),
)
# Echo the chosen database's DDL so the user can see which tables and
# columns are available to ask about.
st.sidebar.text_area(
    label="Database Schema",
    value=db_schemas[selected_db],
    height=600,
)
# Read the user's natural-language question. Surrounding whitespace is
# stripped so that a blank or whitespace-only input counts as "no question".
question = st.text_input("Enter your question:").strip()
db_id = selected_db

if question:
    # Store the question together with the selected database id. This happens
    # before any translation, so the text is recorded exactly as the user typed it.
    # NOTE(review): Streamlit re-runs this script on every widget interaction,
    # so this record is written again on each rerun while the input is
    # non-empty. Confirm that duplicate records are acceptable.
    add_a_record(question, db_id)

    # Translation only runs when a Baidu token is configured.
    if baidu_api_token:
        try:
            detected_lang = detect(question)
        except LangDetectException as e:
            # Detection failed; warn and continue with the untranslated text.
            st.warning(f"Language detection error: {e}")
        else:
            # NOTE(review): every non-English language is routed through the
            # zh->en translator. Confirm translate_zh_to_en handles non-Chinese
            # input, or narrow this check to Chinese codes such as "zh-cn"/"zh-tw".
            if detected_lang != "en":
                question = translate_zh_to_en(question, baidu_api_token)

    # Query the text-to-SQL bot for the selected database and show its reply.
    response = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Results:** {response}")