# LangSQL / app.py
# (Hugging Face Space page header: Roxanne-WANG's picture · "update token" ·
#  commit b053c71 · raw · history · blame · 6.22 kB)
# import streamlit as st
# from text2sql import ChatBot
# from langdetect import detect
# from utils.translate_utils import translate_zh_to_en
# from utils.db_utils import add_a_record
# from langdetect.lang_detect_exception import LangDetectException
# # Initialize chatbot and other variables
# text2sql_bot = ChatBot()
# baidu_api_token = None
# # Define database schemas for demonstration
# db_schemas = {
# "singer": """
# CREATE TABLE "singer" (
# "Singer_ID" int,
# "Name" text,
# "Birth_Year" real,
# "Net_Worth_Millions" real,
# "Citizenship" text,
# PRIMARY KEY ("Singer_ID")
# );
# CREATE TABLE "song" (
# "Song_ID" int,
# "Title" text,
# "Singer_ID" int,
# "Sales" real,
# "Highest_Position" real,
# PRIMARY KEY ("Song_ID"),
# FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
# );
# """,
# # Add other schemas as needed
# }
# # Streamlit UI
# st.title("Text-to-SQL Chatbot")
# st.sidebar.header("Select a Database")
# # Sidebar for selecting a database
# selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
# # Display the selected schema
# st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)
# # User input section
# question = st.text_input("Enter your question:")
# db_id = selected_db # Use selected database for DB ID
# if question:
# add_a_record(question, db_id)
# try:
# if baidu_api_token is not None and detect(question) != "en":
# print("Before translation:", question)
# question = translate_zh_to_en(question, baidu_api_token)
# print("After translation:", question)
# except LangDetectException as e:
# print("Language detection error:", str(e))
# predicted_sql = text2sql_bot.get_response(question, db_id)
# st.write(f"**Database:** {db_id}")
# st.write(f"**Predicted SQL query:** {predicted_sql}")
import os

import streamlit as st
import torch
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    logging as hf_logging,
)

from utils.translate_utils import translate_zh_to_en
from utils.db_utils import add_a_record
# Suppress excessive warnings from Hugging Face transformers library
# (raises the transformers logger threshold to ERROR for the whole process).
hf_logging.set_verbosity_error()
# SchemaItemClassifierInference class for loading the Hugging Face model
class SchemaItemClassifierInference:
    """Load a Hugging Face sequence-classification model and run inference.

    Used by :class:`ChatBot` to score a text-to-SQL question; ``predict``
    returns the model's raw logits (a ``torch.Tensor``).
    """

    def __init__(self, model_name: str, token=None):
        """
        model_name: Hugging Face repository path, e.g., "Roxanne-WANG/LangSQL"
        token: Authentication token for Hugging Face (if the model is private)
        """
        # `token=` replaces the deprecated `use_auth_token=` keyword, which
        # has been removed from recent transformers releases.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,             # auth for private repositories
            trust_remote_code=True,  # trust custom model code from the repo
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )
        # Inference only: switch off dropout / other train-mode behavior.
        self.model.eval()

    def predict(self, text: str):
        """Tokenize ``text`` and return the model's raw logits."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # No gradients are needed at inference time; skipping autograd
        # bookkeeping saves memory and compute.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits
# ChatBot class that interacts with SchemaItemClassifierInference
class ChatBot:
    """Routes a natural-language question through the schema classifier."""

    def __init__(self):
        # Repository that hosts the classifier weights; the access token must
        # come from the environment so it never lands in source control.
        repo_id = "Roxanne-WANG/LangSQL"
        auth_token = os.getenv('HF_TOKEN')
        if auth_token is None:
            raise ValueError("Hugging Face token is required. Please set HF_TOKEN.")
        # Schema item classifier backed by the Hugging Face model above.
        self.sic = SchemaItemClassifierInference(repo_id, token=auth_token)

    def get_response(self, question: str, db_id: str):
        """Return the classifier's raw logits for ``question``.

        NOTE: placeholder behavior — ``db_id`` is accepted for interface
        stability, and the logits stand in for a decoded SQL query.
        """
        return self.sic.predict(question)
# -------- Streamlit Web Application --------
text2sql_bot = ChatBot()
baidu_api_token = None # Your Baidu API token (if needed for translation)
# Define some database schemas for demonstration purposes
db_schemas = {
"singer": """
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Birth_Year" real,
"Net_Worth_Millions" real,
"Citizenship" text,
PRIMARY KEY ("Singer_ID")
);
CREATE TABLE "song" (
"Song_ID" int,
"Title" text,
"Singer_ID" int,
"Sales" real,
"Highest_Position" real,
PRIMARY KEY ("Song_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);
""",
# More schemas can be added here
}
# -------------------- Streamlit interface --------------------
st.title("Text-to-SQL Chatbot")

# Sidebar: database picker plus a read-only view of the chosen schema.
st.sidebar.header("Select a Database")
db_id = st.sidebar.selectbox("Choose a database:", list(db_schemas))
st.sidebar.text_area("Database Schema", db_schemas[db_id], height=600)

# Main panel: free-text question input.
question = st.text_input("Enter your question:")

if question:
    # Persist the raw (pre-translation) question for this database.
    add_a_record(question, db_id)
    try:
        # Translate only when a Baidu token is configured and the
        # question is detected as non-English.
        if baidu_api_token and detect(question) != "en":
            question = translate_zh_to_en(question, baidu_api_token)
    except LangDetectException as e:
        # Detection can fail on very short / ambiguous input; warn and
        # fall through with the untranslated question.
        st.warning(f"Language detection error: {e}")
    # Run the model and render its (placeholder) output.
    response = text2sql_bot.get_response(question, db_id)
    for rendered in (
        f"**Database:** {db_id}",
        f"**Model logits (Example Output):** {response}",
    ):
        st.write(rendered)