Roxanne-WANG commited on
Commit
b5be522
·
1 Parent(s): b759b87

update weight loading

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +87 -8
README.md CHANGED
@@ -2,7 +2,7 @@
2
  title: LangSQL
3
  emoji: 🦕
4
  colorFrom: blue
5
- colorTo: gray
6
  sdk: streamlit
7
  sdk_version: 1.44.1
8
  app_file: app.py
 
2
  title: LangSQL
3
  emoji: 🦕
4
  colorFrom: blue
5
+ colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.44.1
8
  app_file: app.py
app.py CHANGED
@@ -1,15 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
- from text2sql import ChatBot
3
  from langdetect import detect
4
  from utils.translate_utils import translate_zh_to_en
5
  from utils.db_utils import add_a_record
6
  from langdetect.lang_detect_exception import LangDetectException
7
 
8
- # Initialize chatbot and other variables
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  text2sql_bot = ChatBot()
10
  baidu_api_token = None
11
 
12
- # Define database schemas for demonstration
13
  db_schemas = {
14
  "singer": """
15
  CREATE TABLE "singer" (
@@ -31,22 +114,18 @@ db_schemas = {
31
  FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
32
  );
33
  """,
34
- # Add other schemas as needed
35
  }
36
 
37
  # Streamlit UI
38
  st.title("Text-to-SQL Chatbot")
39
  st.sidebar.header("Select a Database")
40
 
41
- # Sidebar for selecting a database
42
  selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
43
 
44
- # Display the selected schema
45
  st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)
46
 
47
- # User input section
48
  question = st.text_input("Enter your question:")
49
- db_id = selected_db # Use selected database for DB ID
50
 
51
  if question:
52
  add_a_record(question, db_id)
 
1
+ # import streamlit as st
2
+ # from text2sql import ChatBot
3
+ # from langdetect import detect
4
+ # from utils.translate_utils import translate_zh_to_en
5
+ # from utils.db_utils import add_a_record
6
+ # from langdetect.lang_detect_exception import LangDetectException
7
+
8
+ # # Initialize chatbot and other variables
9
+ # text2sql_bot = ChatBot()
10
+ # baidu_api_token = None
11
+
12
+ # # Define database schemas for demonstration
13
+ # db_schemas = {
14
+ # "singer": """
15
+ # CREATE TABLE "singer" (
16
+ # "Singer_ID" int,
17
+ # "Name" text,
18
+ # "Birth_Year" real,
19
+ # "Net_Worth_Millions" real,
20
+ # "Citizenship" text,
21
+ # PRIMARY KEY ("Singer_ID")
22
+ # );
23
+
24
+ # CREATE TABLE "song" (
25
+ # "Song_ID" int,
26
+ # "Title" text,
27
+ # "Singer_ID" int,
28
+ # "Sales" real,
29
+ # "Highest_Position" real,
30
+ # PRIMARY KEY ("Song_ID"),
31
+ # FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
32
+ # );
33
+ # """,
34
+ # # Add other schemas as needed
35
+ # }
36
+
37
+ # # Streamlit UI
38
+ # st.title("Text-to-SQL Chatbot")
39
+ # st.sidebar.header("Select a Database")
40
+
41
+ # # Sidebar for selecting a database
42
+ # selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
43
+
44
+ # # Display the selected schema
45
+ # st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)
46
+
47
+ # # User input section
48
+ # question = st.text_input("Enter your question:")
49
+ # db_id = selected_db # Use selected database for DB ID
50
+
51
+ # if question:
52
+ # add_a_record(question, db_id)
53
+
54
+ # try:
55
+ # if baidu_api_token is not None and detect(question) != "en":
56
+ # print("Before translation:", question)
57
+ # question = translate_zh_to_en(question, baidu_api_token)
58
+ # print("After translation:", question)
59
+ # except LangDetectException as e:
60
+ # print("Language detection error:", str(e))
61
+
62
+ # predicted_sql = text2sql_bot.get_response(question, db_id)
63
+ # st.write(f"**Database:** {db_id}")
64
+ # st.write(f"**Predicted SQL query:** {predicted_sql}")
65
+
66
+
67
  import streamlit as st
68
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
69
  from langdetect import detect
70
  from utils.translate_utils import translate_zh_to_en
71
  from utils.db_utils import add_a_record
72
  from langdetect.lang_detect_exception import LangDetectException
73
 
74
+ class SchemaItemClassifierInference:
75
+ def __init__(self, model_name):
76
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
77
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=True)
78
+
79
+ def predict(self, text):
80
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
81
+ outputs = self.model(**inputs)
82
+ return outputs.logits
83
+
84
+ class ChatBot:
85
+ def __init__(self):
86
+ model_name = "Roxanne-WANG/LangSQL"
87
+ self.sic = SchemaItemClassifierInference(model_name)
88
+
89
+ def get_response(self, question, db_id):
90
+ prediction = self.sic.predict(question)
91
+ return prediction
92
+
93
  text2sql_bot = ChatBot()
94
  baidu_api_token = None
95
 
 
96
  db_schemas = {
97
  "singer": """
98
  CREATE TABLE "singer" (
 
114
  FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
115
  );
116
  """,
 
117
  }
118
 
119
  # Streamlit UI
120
  st.title("Text-to-SQL Chatbot")
121
  st.sidebar.header("Select a Database")
122
 
 
123
  selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys()))
124
 
 
125
  st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)
126
 
 
127
  question = st.text_input("Enter your question:")
128
+ db_id = selected_db
129
 
130
  if question:
131
  add_a_record(question, db_id)