# NOTE(review): the lines "Spaces: / Paused / Paused" were a Hugging Face
# Spaces status banner accidentally pasted into the file; kept as a comment
# so the module parses.
# NOTE(review): a commented-out earlier draft of this app lived here. It was
# byte-for-byte superseded by the live implementation below (same schemas,
# same Streamlit flow, print() swapped for st.warning), so it has been
# removed rather than kept as dead commented-out code.
# Standard library
import os

# Third-party
import streamlit as st
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    logging as hf_logging,
)

# Project-local helpers
from utils.db_utils import add_a_record
from utils.translate_utils import translate_zh_to_en

# Suppress excessive warnings from the Hugging Face transformers library.
hf_logging.set_verbosity_error()
# SchemaItemClassifierInference class for loading the Hugging Face model
class SchemaItemClassifierInference:
    def __init__(self, model_name: str, token=None):
        """
        Load a sequence-classification model and its tokenizer from Hugging Face.

        model_name: Hugging Face repository path, e.g., "Roxanne-WANG/LangSQL"
        token: Authentication token for Hugging Face (if the model is private)
        """
        # `token=` replaces the deprecated `use_auth_token=` keyword
        # (deprecated since transformers 4.32 and removed in v5).
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,              # Needed to access private models
            trust_remote_code=True,   # Trust custom model code from the HF repo
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )

    def predict(self, text: str):
        """Tokenize `text` and return the model's raw classification logits."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Forward pass; callers receive the unnormalized logits tensor.
        outputs = self.model(**inputs)
        return outputs.logits
# ChatBot class that interacts with SchemaItemClassifierInference
class ChatBot:
    def __init__(self):
        """Wire up the schema-item classifier that backs this chatbot."""
        # Hugging Face repository hosting the classifier weights.
        repo_id = "Roxanne-WANG/LangSQL"
        # Credentials come from the environment; fail fast when absent.
        auth_token = os.getenv('HF_TOKEN')
        if auth_token is None:
            raise ValueError("Hugging Face token is required. Please set HF_TOKEN.")
        self.sic = SchemaItemClassifierInference(repo_id, token=auth_token)

    def get_response(self, question: str, db_id: str):
        """Run the classifier on `question` and return its raw logits.

        The logits are a placeholder for a generated SQL query; `db_id`
        is accepted for interface compatibility but not yet used.
        """
        return self.sic.predict(question)
# -------- Streamlit Web Application --------
text2sql_bot = ChatBot()
baidu_api_token = None  # Your Baidu API token (if needed for translation)

# Demonstration database schemas, keyed by database id.
db_schemas = {
    "singer": """
        CREATE TABLE "singer" (
            "Singer_ID" int,
            "Name" text,
            "Birth_Year" real,
            "Net_Worth_Millions" real,
            "Citizenship" text,
            PRIMARY KEY ("Singer_ID")
        );
        CREATE TABLE "song" (
            "Song_ID" int,
            "Title" text,
            "Singer_ID" int,
            "Sales" real,
            "Highest_Position" real,
            PRIMARY KEY ("Song_ID"),
            FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
        );
    """,
    # More schemas can be added here
}

# Streamlit interface: sidebar picks the database, main pane takes the question.
st.title("Text-to-SQL Chatbot")
st.sidebar.header("Select a Database")
db_id = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
st.sidebar.text_area("Database Schema", db_schemas[db_id], height=600)

question = st.text_input("Enter your question:")

if question:
    # Persist the raw question before any translation happens.
    add_a_record(question, db_id)
    try:
        # Translate non-English questions only when a Baidu token is configured.
        if baidu_api_token and detect(question) != "en":
            question = translate_zh_to_en(question, baidu_api_token)
    except LangDetectException as e:
        st.warning(f"Language detection error: {e}")

    # Ask the model; for now this surfaces raw logits rather than SQL.
    response = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Model logits (Example Output):** {response}")