# import streamlit as st # from text2sql import ChatBot # from langdetect import detect # from utils.translate_utils import translate_zh_to_en # from utils.db_utils import add_a_record # from langdetect.lang_detect_exception import LangDetectException # # Initialize chatbot and other variables # text2sql_bot = ChatBot() # baidu_api_token = None # # Define database schemas for demonstration # db_schemas = { # "singer": """ # CREATE TABLE "singer" ( # "Singer_ID" int, # "Name" text, # "Birth_Year" real, # "Net_Worth_Millions" real, # "Citizenship" text, # PRIMARY KEY ("Singer_ID") # ); # CREATE TABLE "song" ( # "Song_ID" int, # "Title" text, # "Singer_ID" int, # "Sales" real, # "Highest_Position" real, # PRIMARY KEY ("Song_ID"), # FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID") # ); # """, # # Add other schemas as needed # } # # Streamlit UI # st.title("Text-to-SQL Chatbot") # st.sidebar.header("Select a Database") # # Sidebar for selecting a database # selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys())) # # Display the selected schema # st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600) # # User input section # question = st.text_input("Enter your question:") # db_id = selected_db # Use selected database for DB ID # if question: # add_a_record(question, db_id) # try: # if baidu_api_token is not None and detect(question) != "en": # print("Before translation:", question) # question = translate_zh_to_en(question, baidu_api_token) # print("After translation:", question) # except LangDetectException as e: # print("Language detection error:", str(e)) # predicted_sql = text2sql_bot.get_response(question, db_id) # st.write(f"**Database:** {db_id}") # st.write(f"**Predicted SQL query:** {predicted_sql}") import streamlit as st from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, logging as hf_logging ) from langdetect import detect from utils.translate_utils import translate_zh_to_en 
from utils.db_utils import add_a_record
from langdetect.lang_detect_exception import LangDetectException
import os

# Keep the transformers library quiet: show errors only, so download/config
# warnings do not clutter the Streamlit page.
hf_logging.set_verbosity_error()


class SchemaItemClassifierInference:
    """Thin wrapper around a Hugging Face sequence-classification model."""

    def __init__(self, model_name: str, token=None):
        """
        Args:
            model_name: Hugging Face repository path, e.g. "Roxanne-WANG/LangSQL".
            token: Hugging Face auth token (required when the repo is private).
        """
        # NOTE: `use_auth_token` is deprecated in recent transformers releases;
        # the supported keyword for private-repo authentication is `token`.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,  # allow custom model code shipped in the repo
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )

    def predict(self, text: str):
        """Tokenize ``text`` and return the model's raw classification logits."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        outputs = self.model(**inputs)
        return outputs.logits


class ChatBot:
    """Text-to-SQL chatbot backed by ``SchemaItemClassifierInference``."""

    def __init__(self):
        # Model repository on the Hugging Face Hub.
        model_name = "Roxanne-WANG/LangSQL"
        hf_token = os.getenv("HF_TOKEN")  # read the token from the environment
        if hf_token is None:
            # The message was split across two physical lines in the original
            # source (a syntax error); rejoined here.
            raise ValueError("Hugging Face token is required. Please set HF_TOKEN.")
        self.sic = SchemaItemClassifierInference(model_name, token=hf_token)

    def get_response(self, question: str, db_id: str):
        """Return the model output for ``question``.

        Currently returns raw logits as a placeholder until actual SQL
        generation is wired up; ``db_id`` is accepted for interface
        compatibility but not yet used.
        """
        return self.sic.predict(question)


# -------- Streamlit Web Application --------

@st.cache_resource
def _load_chatbot() -> ChatBot:
    """Build the ChatBot once per server process.

    Streamlit reruns this whole script on every user interaction; without
    caching, the model would be re-loaded (and potentially re-downloaded)
    on every keystroke.
    """
    return ChatBot()


try:
    text2sql_bot = _load_chatbot()
except ValueError as err:
    # Surface a missing HF_TOKEN in the UI instead of a raw stack trace.
    st.error(str(err))
    st.stop()

baidu_api_token = None  # Your Baidu API token (if needed for translation)

# Define some database schemas for demonstration purposes
db_schemas = {
    "singer": """
CREATE TABLE "singer" (
  "Singer_ID" int,
  "Name" text,
  "Birth_Year" real,
  "Net_Worth_Millions" real,
  "Citizenship" text,
  PRIMARY KEY ("Singer_ID")
);
CREATE TABLE "song" (
  "Song_ID" int,
  "Title" text,
  "Singer_ID" int,
  "Sales" real,
  "Highest_Position" real,
  PRIMARY KEY ("Song_ID"),
  FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);
""",
    # More schemas can be added here
}

# Streamlit interface
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")

selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# Get user input for the question
question = st.text_input("Enter your question:")
db_id = selected_db

if question:
    # Store the question in the database (or perform any additional processing)
    add_a_record(question, db_id)

    try:
        # Translate non-English questions only when a Baidu token is configured.
        if baidu_api_token and detect(question) != "en":
            question = translate_zh_to_en(question, baidu_api_token)
    except LangDetectException as e:
        st.warning(f"Language detection error: {e}")

    # Get the model's response (in this case, SQL query or logits)
    response = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Model logits (Example Output):** {response}")