nuatmochoi committed on
Commit
b5ebf9e
·
verified ·
1 Parent(s): 59cdb28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -50
app.py CHANGED
@@ -1,64 +1,374 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
8
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- response += token
40
- yield response
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
import gradio as gr
import torch

# Unsloth is required to load the fine-tuned Llama model efficiently. If it
# is missing at startup (e.g. a fresh container), install it on the fly and
# retry the import.
try:
    from unsloth import FastLanguageModel
except ImportError:
    print("Unsloth๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์„ค์น˜ ์ค‘...")
    import subprocess
    import sys

    # BUG FIX: invoke pip via the *current* interpreter so the package is
    # installed into the same environment this process imports from — a bare
    # "pip" executable may belong to a different Python installation.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "unsloth"])
    from unsloth import FastLanguageModel
11
 
12
# Model hosted on the Hugging Face Hub.
MODEL_NAME = "huggingface-KREW/Llama-3.1-8B-Spider-SQL-Ko"

print(f"Loading model from Hugging Face: {MODEL_NAME}")

# Load model and tokenizer with Unsloth, then switch the model into
# inference mode. Any failure is reported and re-raised so the Space
# fails fast instead of serving a broken app.
try:
    _load_options = dict(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        dtype=None,         # let Unsloth auto-detect the best dtype
        load_in_4bit=True,  # 4-bit quantization to fit on small GPUs
    )
    model, tokenizer = FastLanguageModel.from_pretrained(**_load_options)

    # Enable Unsloth's optimized inference path.
    FastLanguageModel.for_inference(model)
    print("Model loaded successfully with Unsloth!")

except Exception as e:
    print(f"Error loading model with Unsloth: {e}")
    print("\n๋ชจ๋ธ์ด Hugging Face์— ์ œ๋Œ€๋กœ ์—…๋กœ๋“œ๋˜์ง€ ์•Š์•˜์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
    print("๋กœ์ปฌ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๊ฑฐ๋‚˜ ๋ชจ๋ธ์„ ๋‹ค์‹œ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.")
    raise
35
 
36
# Example databases and questions.
# Each entry mirrors one Spider database: "db_id" identifies the database,
# "question" is a Korean natural-language query, and "schema" is the
# human-readable schema text injected verbatim into the model prompt by
# generate_sql. process_question also uses "db_id" to look up "schema"
# when the user supplies no custom schema.
examples = [
    # Government departments and their heads (3 tables, 2 foreign keys).
    {
        "db_id": "department_management",
        "question": "๊ฐ ๋ถ€์„œ๋ณ„ ์ง์› ์ˆ˜๋ฅผ ๋ณด์—ฌ์ฃผ์„ธ์š”.",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: department
์ปฌ๋Ÿผ:
- Department_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- Name (text)
- Creation (text)
- Ranking (number)
- Budget_in_Billions (number)
- Num_Employees (number)
ํ…Œ์ด๋ธ”: head
์ปฌ๋Ÿผ:
- head_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- name (text)
- born_state (text)
- age (number)
ํ…Œ์ด๋ธ”: management
์ปฌ๋Ÿผ:
- department_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- head_ID (number)
- temporary_acting (text)

์™ธ๋ž˜ ํ‚ค ๊ด€๊ณ„:
- management.head_ID โ†’ head.head_ID
- management.department_ID โ†’ department.Department_ID"""
    },
    # Singers and concerts (join table singer_in_concert).
    {
        "db_id": "concert_singer",
        "question": "๊ฐ€์žฅ ๋งŽ์€ ์ฝ˜์„œํŠธ๋ฅผ ์—ฐ ๊ฐ€์ˆ˜๋Š” ๋ˆ„๊ตฌ์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: singer
์ปฌ๋Ÿผ:
- Singer_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- Name (text)
- Country (text)
- Song_Name (text)
- Song_release_year (text)
- Age (number)
- Is_male (text)
ํ…Œ์ด๋ธ”: concert
์ปฌ๋Ÿผ:
- concert_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- concert_Name (text)
- Theme (text)
- Stadium_ID (number)
- Year (text)
ํ…Œ์ด๋ธ”: singer_in_concert
์ปฌ๋Ÿผ:
- concert_ID (number)
- Singer_ID (number)

์™ธ๋ž˜ ํ‚ค ๊ด€๊ณ„:
- singer_in_concert.Singer_ID โ†’ singer.Singer_ID
- singer_in_concert.concert_ID โ†’ concert.concert_ID"""
    },
    # Students and their pets (no foreign-key section in the prompt text).
    {
        "db_id": "pets_1",
        "question": "๊ฐ€์žฅ ๋‚˜์ด๊ฐ€ ๋งŽ์€ ๊ฐœ์˜ ์ด๋ฆ„์€ ๋ฌด์—‡์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: Student
์ปฌ๋Ÿผ:
- StuID (number) (๊ธฐ๋ณธ ํ‚ค)
- LName (text)
- Fname (text)
- Age (number)
- Sex (text)
- Major (number)
- Advisor (number)
- city_code (text)
ํ…Œ์ด๋ธ”: Has_Pet
์ปฌ๋Ÿผ:
- StuID (number)
- PetID (number)
ํ…Œ์ด๋ธ”: Pets
์ปฌ๋Ÿผ:
- PetID (number) (๊ธฐ๋ณธ ํ‚ค)
- PetType (text)
- pet_age (number)
- weight (number)"""
    },
    # Cars, makers and countries (6 tables).
    {
        "db_id": "car_1",
        "question": "๋ฏธ๊ตญ์‚ฐ ์ž๋™์ฐจ ์ค‘ ๊ฐ€์žฅ ๋น ๋ฅธ ์ž๋™์ฐจ๋Š” ๋ฌด์—‡์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: continents
์ปฌ๋Ÿผ:
- ContId (number) (๊ธฐ๋ณธ ํ‚ค)
- Continent (text)
ํ…Œ์ด๋ธ”: countries
์ปฌ๋Ÿผ:
- CountryId (number) (๊ธฐ๋ณธ ํ‚ค)
- CountryName (text)
- Continent (number)
ํ…Œ์ด๋ธ”: car_makers
์ปฌ๋Ÿผ:
- Id (number) (๊ธฐ๋ณธ ํ‚ค)
- Maker (text)
- FullName (text)
- Country (number)
ํ…Œ์ด๋ธ”: model_list
์ปฌ๋Ÿผ:
- ModelId (number) (๊ธฐ๋ณธ ํ‚ค)
- Maker (number)
- Model (text)
ํ…Œ์ด๋ธ”: car_names
์ปฌ๋Ÿผ:
- MakeId (number) (๊ธฐ๋ณธ ํ‚ค)
- Model (text)
- Make (text)
ํ…Œ์ด๋ธ”: cars_data
์ปฌ๋Ÿผ:
- Id (number) (๊ธฐ๋ณธ ํ‚ค)
- MPG (text)
- Cylinders (number)
- Edispl (text)
- Horsepower (text)
- Weight (number)
- Accelerate (number)
- Year (number)"""
    },
    # TV channels, series and cartoons.
    {
        "db_id": "tvshow",
        "question": "๊ฐ€์žฅ ๋†’์€ ํ‰์ ์„ ๋ฐ›์€ TV ์‡ผ๋Š” ๋ฌด์—‡์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: TV_Channel
์ปฌ๋Ÿผ:
- id (number) (๊ธฐ๋ณธ ํ‚ค)
- series_name (text)
- Country (text)
- Language (text)
- Content (text)
- Pixel_aspect_ratio_PAR (text)
- Hight_definition_TV (text)
- Pay_per_view_PPV (text)
- Package_Option (text)
ํ…Œ์ด๋ธ”: TV_series
์ปฌ๋Ÿผ:
- id (number)
- Episode (text)
- Air_Date (text)
- Rating (text)
- Share (text)
- 18_49_Rating_Share (text)
- Viewers_m (text)
- Weekly_Rank (number)
- Channel (number)
ํ…Œ์ด๋ธ”: Cartoon
์ปฌ๋Ÿผ:
- id (number) (๊ธฐ๋ณธ ํ‚ค)
- Title (text)
- Directed_by (text)
- Written_by (text)
- Original_air_date (text)
- Production_code (number)
- Channel (number)"""
    }
]
197
+
198
def generate_sql(question, db_id, schema_info):
    """Generate a SQL query for a Korean natural-language question.

    Builds a text-to-SQL prompt from the schema description, runs the
    fine-tuned Llama model (module-level ``model`` / ``tokenizer``), and
    extracts the first SQL statement from the decoded output.

    Args:
        question: User question in Korean.
        db_id: Database identifier. Kept for interface compatibility with
            callers; not used in the prompt itself.
        schema_info: Human-readable schema description injected verbatim
            into the prompt.

    Returns:
        The extracted SQL query (trailing ';' removed), or a Korean
        failure message when no SQL statement could be found.
    """
    # Create prompt with schema.
    prompt = f"""๋‹น์‹ ์€ ์ž์—ฐ์–ด ์งˆ๋ฌธ์„ SQL ์ฟผ๋ฆฌ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์ „๋ฌธ AI ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ์ž๊ฐ€ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ์ •๋ณด๋ฅผ ์–ป๊ธฐ ์œ„ํ•ด ์ผ์ƒ ์–ธ์–ด๋กœ ์งˆ๋ฌธํ•˜๋ฉด, ๋‹น์‹ ์€ ํ•ด๋‹น ์งˆ๋ฌธ์„ ์ •ํ™•ํ•œ SQL ์ฟผ๋ฆฌ๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

{schema_info}

์งˆ๋ฌธ: {question}
SQL:"""

    # Format as a single-turn chat message.
    messages = [{"role": "user", "content": prompt}]

    # Apply the model's chat template and move the tensors to the model device.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Generate with low temperature for near-deterministic SQL output.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,
            temperature=0.1,
            top_p=0.95,
            do_sample=True,
            use_cache=True
        )

    # Decode the full sequence (prompt + completion).
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip the echoed prompt when it survives decoding; chat templating may
    # alter the text, so fall back to the whole response otherwise.
    if prompt in response:
        sql_part = response.split(prompt)[-1].strip()
    else:
        sql_part = response

    # Drop a leading "assistant" role marker left by the chat template.
    if sql_part.startswith("assistant"):
        sql_part = sql_part[len("assistant"):].strip()

    # Extract the first SQL statement from the remaining text.
    lines = sql_part.split('\n')
    sql_query = ""
    for idx, line in enumerate(lines):
        line = line.strip()
        if line.lower().startswith(('select', 'with', '(select')):
            sql_query = line
            # Collect continuation lines until a semicolon, blank line, or a
            # new prompt section. BUG FIX: the original used
            # lines.index(line) with the *stripped* line against the
            # unstripped list, which raises ValueError when the matched line
            # had leading whitespace and can resume from the wrong position
            # when duplicate lines exist; the enumerate index is always
            # correct.
            for next_line in lines[idx + 1:]:
                next_line = next_line.strip()
                if not next_line or next_line.startswith(('์งˆ๋ฌธ', '๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค')):
                    break
                sql_query += " " + next_line
                if next_line.endswith(';'):
                    break
            break

    # Normalize: trim and drop a trailing semicolon.
    sql_query = sql_query.strip()
    if sql_query.endswith(';'):
        sql_query = sql_query[:-1]

    return sql_query if sql_query else "SQL ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
266
 
267
def process_question(question, db_id, custom_schema=None):
    """Resolve a schema for *db_id* and generate a SQL query for *question*.

    A non-empty ``custom_schema`` takes precedence; otherwise the schema is
    looked up in the bundled ``examples``. Returns either the generated SQL
    or a Korean user-facing message describing what went wrong.
    """
    # Guard: both the question and the database id are required.
    if not question or not db_id:
        return "์งˆ๋ฌธ๊ณผ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."

    # Prefer a user-supplied schema; fall back to the example catalog.
    if custom_schema and custom_schema.strip():
        schema_info = custom_schema
    else:
        schema_info = next(
            (entry["schema"] for entry in examples if entry["db_id"] == db_id),
            None,
        )

    if not schema_info:
        return "์Šคํ‚ค๋งˆ ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ์ปค์Šคํ…€ ์Šคํ‚ค๋งˆ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."

    # Delegate to the model; surface any failure as a message, not a crash.
    try:
        return generate_sql(question, db_id, schema_info)
    except Exception as e:
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
292
 
293
# Create Gradio interface.
# Layout: left column holds the inputs (database id, Korean question,
# optional custom schema) and the submit button; right column shows the
# generated SQL plus usage tips.
with gr.Blocks(title="Spider SQL Generator - Korean", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ๐Ÿ•ท๏ธ Spider SQL Generator - Korean

    ํ•œ๊ตญ์–ด ์งˆ๋ฌธ์„ SQL ์ฟผ๋ฆฌ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” Llama 3.1 8B ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.

    ## ๐Ÿ“Š ์„ฑ๋Šฅ
    - **Exact Match**: 42.65%
    - **Execution Accuracy**: 65.47%
    - **Training**: Spider ๋ฐ์ดํ„ฐ์…‹ (ํ•œ๊ตญ์–ด ๋ฒˆ์—ญ)
    """)

    with gr.Row():
        with gr.Column():
            db_id_input = gr.Textbox(
                label="๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID",
                placeholder="์˜ˆ: department_management",
                value="department_management"
            )

            question_input = gr.Textbox(
                label="์งˆ๋ฌธ (ํ•œ๊ตญ์–ด)",
                placeholder="์˜ˆ: ๊ฐ ๋ถ€์„œ๋ณ„ ์ง์› ์ˆ˜๋ฅผ ๋ณด์—ฌ์ฃผ์„ธ์š”.",
                lines=2
            )

            with gr.Accordion("์Šคํ‚ค๋งˆ ์ •๋ณด (์„ ํƒ์‚ฌํ•ญ)", open=False):
                schema_input = gr.Textbox(
                    label="์ปค์Šคํ…€ ์Šคํ‚ค๋งˆ",
                    placeholder="์ปค์Šคํ…€ ์Šคํ‚ค๋งˆ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”. ๋น„์›Œ๋‘๋ฉด ์˜ˆ์ œ ์Šคํ‚ค๋งˆ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.",
                    lines=10
                )

            submit_btn = gr.Button("SQL ์ƒ์„ฑ", variant="primary", size="lg")

        with gr.Column():
            output = gr.Textbox(
                label="์ƒ์„ฑ๋œ SQL ์ฟผ๋ฆฌ",
                lines=4,
                elem_classes=["code"]
            )

            gr.Markdown("""
            ### ๐Ÿ’ก ์‚ฌ์šฉ ํŒ
            - ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID๋Š” ์˜ˆ์ œ์—์„œ ์„ ํƒํ•˜๊ฑฐ๋‚˜ ์ง์ ‘ ์ž…๋ ฅํ•˜์„ธ์š”
            - ์งˆ๋ฌธ์€ ํ•œ๊ตญ์–ด๋กœ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•˜์„ธ์š”
            - ์Šคํ‚ค๋งˆ ์ •๋ณด๋Š” ์„ ํƒ์‚ฌํ•ญ์ž…๋‹ˆ๋‹ค
            """)

    # Examples — clicking a row fills the three input components.
    # BUG FIX: the example columns and `inputs` previously listed db_id first
    # while fn=process_question expects (question, db_id, custom_schema); if
    # Gradio invokes fn for an example, question and db_id were swapped.
    # Columns/inputs are now ordered to match the function signature, and the
    # textboxes are still populated correctly.
    gr.Markdown("### ๐Ÿ“š ์˜ˆ์ œ (ํด๋ฆญํ•˜์—ฌ ์‚ฌ์šฉ)")
    gr.Examples(
        examples=[
            [ex["question"], ex["db_id"], ex["schema"]] for ex in examples
        ],
        inputs=[question_input, db_id_input, schema_input],
        outputs=output,
        fn=process_question,
        cache_examples=False
    )

    # Submit action: input order matches process_question's signature.
    submit_btn.click(
        fn=process_question,
        inputs=[question_input, db_id_input, schema_input],
        outputs=output
    )

    # Model info footer.
    gr.Markdown(f"""
    ---
    ### ๐Ÿค– ๋ชจ๋ธ ์ •๋ณด
    - **Hugging Face**: [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
    - **Base Model**: Llama 3.1 8B
    - **Fine-tuning**: LoRA with Unsloth
    - **Dataset**: Spider (Korean translated)
    """)
371
 
372
# Launch the app
if __name__ == "__main__":
    # NOTE(review): share=True requests a public *.gradio.live tunnel; on
    # Hugging Face Spaces the flag is unnecessary (the Space is already
    # public) — confirm it is intended for local runs.
    demo.launch(share=True)