Roxanne-WANG commited on
Commit
6670a17
·
1 Parent(s): b053c71

update model

Browse files
Files changed (2) hide show
  1. app.py +0 -54
  2. text2sql.py +2 -2
app.py CHANGED
@@ -76,60 +76,6 @@ from utils.db_utils import add_a_record
76
  from langdetect.lang_detect_exception import LangDetectException
77
  import os
78
 
79
- # Suppress excessive warnings from Hugging Face transformers library
80
- hf_logging.set_verbosity_error()
81
-
82
- # SchemaItemClassifierInference class for loading the Hugging Face model
83
- class SchemaItemClassifierInference:
84
- def __init__(self, model_name: str, token=None):
85
- """
86
- model_name: Hugging Face repository path, e.g., "Roxanne-WANG/LangSQL"
87
- token: Authentication token for Hugging Face (if the model is private)
88
- """
89
- # Load the tokenizer and model from Hugging Face, trust remote code if needed
90
- self.tokenizer = AutoTokenizer.from_pretrained(
91
- model_name,
92
- use_auth_token=token, # Pass the token for accessing private models
93
- trust_remote_code=True # Trust custom model code from Hugging Face repo
94
- )
95
- self.model = AutoModelForSequenceClassification.from_pretrained(
96
- model_name,
97
- use_auth_token=token,
98
- trust_remote_code=True
99
- )
100
-
101
- def predict(self, text: str):
102
- # Tokenize the input text and get predictions from the model
103
- inputs = self.tokenizer(
104
- text,
105
- return_tensors="pt",
106
- padding=True,
107
- truncation=True
108
- )
109
- outputs = self.model(**inputs)
110
- return outputs.logits
111
-
112
-
113
- # ChatBot class that interacts with SchemaItemClassifierInference
114
- class ChatBot:
115
- def __init__(self):
116
- # Specify the Hugging Face model name (replace with your model's path)
117
- model_name = "Roxanne-WANG/LangSQL"
118
- hf_token = os.getenv('HF_TOKEN') # Get token from environment variables
119
-
120
- if hf_token is None:
121
- raise ValueError("Hugging Face token is required. Please set HF_TOKEN.")
122
-
123
- # Initialize the schema item classifier with Hugging Face token
124
- self.sic = SchemaItemClassifierInference(model_name, token=hf_token)
125
-
126
- def get_response(self, question: str, db_id: str):
127
- # Get the model's prediction (logits) for the input question
128
- logits = self.sic.predict(question)
129
- # For now, return logits as a placeholder for the actual SQL query
130
- return logits
131
-
132
-
133
  # -------- Streamlit Web Application --------
134
  text2sql_bot = ChatBot()
135
  baidu_api_token = None # Your Baidu API token (if needed for translation)
 
76
  from langdetect.lang_detect_exception import LangDetectException
77
  import os
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # -------- Streamlit Web Application --------
80
  text2sql_bot = ChatBot()
81
  baidu_api_token = None # Your Baidu API token (if needed for translation)
text2sql.py CHANGED
@@ -104,14 +104,14 @@ def get_db_id2ddl(db_path):
104
  class ChatBot():
105
  def __init__(self) -> None:
106
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
107
- model_name = "seeklhy/codes-7b-merged"
108
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
109
  self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
110
  self.max_length = 4096
111
  self.max_new_tokens = 256
112
  self.max_prefix_length = self.max_length - self.max_new_tokens
113
 
114
- self.sic = SchemaItemClassifierInference("sic_ckpts/sic_bird")
115
 
116
  self.db_id2content_searcher = dict()
117
  for db_id in os.listdir("db_contents_index"):
 
104
  class ChatBot():
105
  def __init__(self) -> None:
106
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
107
+ model_name = "seeklhy/codes-1b"
108
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
109
  self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
110
  self.max_length = 4096
111
  self.max_new_tokens = 256
112
  self.max_prefix_length = self.max_length - self.max_new_tokens
113
 
114
+ self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL", token=os.getenv('HF_TOKEN'))
115
 
116
  self.db_id2content_searcher = dict()
117
  for db_id in os.listdir("db_contents_index"):