Roxanne-WANG committed
Commit b053c71
1 Parent(s): b5be522

update token

Files changed (1):
  1. app.py +64 -21
app.py CHANGED
@@ -65,34 +65,76 @@


 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+    logging as hf_logging
+)
 from langdetect import detect
 from utils.translate_utils import translate_zh_to_en
 from utils.db_utils import add_a_record
 from langdetect.lang_detect_exception import LangDetectException
+import os

-class SchemaItemClassifierInference:
-    def __init__(self, model_name):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
-        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=True)
+# Suppress excessive warnings from Hugging Face transformers library
+hf_logging.set_verbosity_error()

-    def predict(self, text):
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
+# SchemaItemClassifierInference class for loading the Hugging Face model
+class SchemaItemClassifierInference:
+    def __init__(self, model_name: str, token=None):
+        """
+        model_name: Hugging Face repository path, e.g., "Roxanne-WANG/LangSQL"
+        token: Authentication token for Hugging Face (if the model is private)
+        """
+        # Load the tokenizer and model from Hugging Face, trust remote code if needed
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            use_auth_token=token,    # Pass the token for accessing private models
+            trust_remote_code=True   # Trust custom model code from Hugging Face repo
+        )
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            use_auth_token=token,
+            trust_remote_code=True
+        )
+
+    def predict(self, text: str):
+        # Tokenize the input text and get predictions from the model
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True
+        )
         outputs = self.model(**inputs)
         return outputs.logits

+
+# ChatBot class that interacts with SchemaItemClassifierInference
 class ChatBot:
     def __init__(self):
+        # Specify the Hugging Face model name (replace with your model's path)
         model_name = "Roxanne-WANG/LangSQL"
-        self.sic = SchemaItemClassifierInference(model_name)
+        hf_token = os.getenv('HF_TOKEN')  # Get token from environment variables
+
+        if hf_token is None:
+            raise ValueError("Hugging Face token is required. Please set HF_TOKEN.")

-    def get_response(self, question, db_id):
-        prediction = self.sic.predict(question)
-        return prediction
+        # Initialize the schema item classifier with Hugging Face token
+        self.sic = SchemaItemClassifierInference(model_name, token=hf_token)

+    def get_response(self, question: str, db_id: str):
+        # Get the model's prediction (logits) for the input question
+        logits = self.sic.predict(question)
+        # For now, return logits as a placeholder for the actual SQL query
+        return logits
+
+
+# -------- Streamlit Web Application --------
 text2sql_bot = ChatBot()
-baidu_api_token = None
+baidu_api_token = None  # Your Baidu API token (if needed for translation)

+# Define some database schemas for demonstration purposes
 db_schemas = {
     "singer": """
     CREATE TABLE "singer" (
@@ -114,30 +156,31 @@ db_schemas = {
     FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
     );
     """,
+    # More schemas can be added here
 }

-# Streamlit UI
+# Streamlit interface
 st.title("Text-to-SQL Chatbot")
 st.sidebar.header("Select a Database")
-
 selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
-
 st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

+# Get user input for the question
 question = st.text_input("Enter your question:")
 db_id = selected_db

 if question:
+    # Store the question in the database (or perform any additional processing)
     add_a_record(question, db_id)

     try:
-        if baidu_api_token is not None and detect(question) != "en":
-            print("Before translation:", question)
+        # If translation is required, handle it here
+        if baidu_api_token and detect(question) != "en":
             question = translate_zh_to_en(question, baidu_api_token)
-            print("After translation:", question)
     except LangDetectException as e:
-        print("Language detection error:", str(e))
+        st.warning(f"Language detection error: {e}")

-    predicted_sql = text2sql_bot.get_response(question, db_id)
+    # Get the model's response (in this case, SQL query or logits)
+    response = text2sql_bot.get_response(question, db_id)
     st.write(f"**Database:** {db_id}")
-    st.write(f"**Predicted SQL query:** {predicted_sql}")
+    st.write(f"**Model logits (Example Output):** {response}")