# Source: Hugging Face Space app.py by nuatmochoi (commit b5ebf9e, verified)
import gradio as gr
import torch

try:
    from unsloth import FastLanguageModel
except ImportError:
    # Unsloth is missing; install it at runtime so the Space can still boot.
    print("Unsloth๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์„ค์น˜ ์ค‘...")
    import subprocess
    import sys

    # Use the current interpreter's pip (a bare "pip" on PATH may belong to
    # a different Python environment, leaving the import still broken).
    subprocess.check_call([sys.executable, "-m", "pip", "install", "unsloth"])
    from unsloth import FastLanguageModel
# Model hosted on the Hugging Face Hub (Llama 3.1 8B fine-tuned for
# Korean text-to-SQL on the Spider dataset).
MODEL_NAME = "huggingface-KREW/Llama-3.1-8B-Spider-SQL-Ko"
print(f"Loading model from Hugging Face: {MODEL_NAME}")

# Load model and tokenizer via Unsloth; any failure is logged and re-raised
# so the Space fails fast instead of serving a broken app.
try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        dtype=None,  # auto-detect dtype
        load_in_4bit=True,  # 4-bit quantization to fit the 8B model in memory
    )
    # Switch Unsloth to inference mode (disables training-only paths).
    FastLanguageModel.for_inference(model)
    print("Model loaded successfully with Unsloth!")
except Exception as e:
    print(f"Error loading model with Unsloth: {e}")
    print("\n๋ชจ๋ธ์ด Hugging Face์— ์ œ๋Œ€๋กœ ์—…๋กœ๋“œ๋˜์ง€ ์•Š์•˜์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
    print("๋กœ์ปฌ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๊ฑฐ๋‚˜ ๋ชจ๋ธ์„ ๋‹ค์‹œ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.")
    raise
# Example databases and questions.
# Each entry bundles a Spider db_id, a sample Korean question, and a
# human-readable schema description (in Korean) that is injected verbatim
# into the model prompt as context.
examples = [
    {
        "db_id": "department_management",
        "question": "๊ฐ ๋ถ€์„œ๋ณ„ ์ง์› ์ˆ˜๋ฅผ ๋ณด์—ฌ์ฃผ์„ธ์š”.",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: department
์ปฌ๋Ÿผ:
- Department_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- Name (text)
- Creation (text)
- Ranking (number)
- Budget_in_Billions (number)
- Num_Employees (number)
ํ…Œ์ด๋ธ”: head
์ปฌ๋Ÿผ:
- head_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- name (text)
- born_state (text)
- age (number)
ํ…Œ์ด๋ธ”: management
์ปฌ๋Ÿผ:
- department_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- head_ID (number)
- temporary_acting (text)
์™ธ๋ž˜ ํ‚ค ๊ด€๊ณ„:
- management.head_ID โ†’ head.head_ID
- management.department_ID โ†’ department.Department_ID""",
    },
    {
        "db_id": "concert_singer",
        "question": "๊ฐ€์žฅ ๋งŽ์€ ์ฝ˜์„œํŠธ๋ฅผ ์—ฐ ๊ฐ€์ˆ˜๋Š” ๋ˆ„๊ตฌ์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: singer
์ปฌ๋Ÿผ:
- Singer_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- Name (text)
- Country (text)
- Song_Name (text)
- Song_release_year (text)
- Age (number)
- Is_male (text)
ํ…Œ์ด๋ธ”: concert
์ปฌ๋Ÿผ:
- concert_ID (number) (๊ธฐ๋ณธ ํ‚ค)
- concert_Name (text)
- Theme (text)
- Stadium_ID (number)
- Year (text)
ํ…Œ์ด๋ธ”: singer_in_concert
์ปฌ๋Ÿผ:
- concert_ID (number)
- Singer_ID (number)
์™ธ๋ž˜ ํ‚ค ๊ด€๊ณ„:
- singer_in_concert.Singer_ID โ†’ singer.Singer_ID
- singer_in_concert.concert_ID โ†’ concert.concert_ID""",
    },
    {
        "db_id": "pets_1",
        "question": "๊ฐ€์žฅ ๋‚˜์ด๊ฐ€ ๋งŽ์€ ๊ฐœ์˜ ์ด๋ฆ„์€ ๋ฌด์—‡์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: Student
์ปฌ๋Ÿผ:
- StuID (number) (๊ธฐ๋ณธ ํ‚ค)
- LName (text)
- Fname (text)
- Age (number)
- Sex (text)
- Major (number)
- Advisor (number)
- city_code (text)
ํ…Œ์ด๋ธ”: Has_Pet
์ปฌ๋Ÿผ:
- StuID (number)
- PetID (number)
ํ…Œ์ด๋ธ”: Pets
์ปฌ๋Ÿผ:
- PetID (number) (๊ธฐ๋ณธ ํ‚ค)
- PetType (text)
- pet_age (number)
- weight (number)""",
    },
    {
        "db_id": "car_1",
        "question": "๋ฏธ๊ตญ์‚ฐ ์ž๋™์ฐจ ์ค‘ ๊ฐ€์žฅ ๋น ๋ฅธ ์ž๋™์ฐจ๋Š” ๋ฌด์—‡์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: continents
์ปฌ๋Ÿผ:
- ContId (number) (๊ธฐ๋ณธ ํ‚ค)
- Continent (text)
ํ…Œ์ด๋ธ”: countries
์ปฌ๋Ÿผ:
- CountryId (number) (๊ธฐ๋ณธ ํ‚ค)
- CountryName (text)
- Continent (number)
ํ…Œ์ด๋ธ”: car_makers
์ปฌ๋Ÿผ:
- Id (number) (๊ธฐ๋ณธ ํ‚ค)
- Maker (text)
- FullName (text)
- Country (number)
ํ…Œ์ด๋ธ”: model_list
์ปฌ๋Ÿผ:
- ModelId (number) (๊ธฐ๋ณธ ํ‚ค)
- Maker (number)
- Model (text)
ํ…Œ์ด๋ธ”: car_names
์ปฌ๋Ÿผ:
- MakeId (number) (๊ธฐ๋ณธ ํ‚ค)
- Model (text)
- Make (text)
ํ…Œ์ด๋ธ”: cars_data
์ปฌ๋Ÿผ:
- Id (number) (๊ธฐ๋ณธ ํ‚ค)
- MPG (text)
- Cylinders (number)
- Edispl (text)
- Horsepower (text)
- Weight (number)
- Accelerate (number)
- Year (number)""",
    },
    {
        "db_id": "tvshow",
        "question": "๊ฐ€์žฅ ๋†’์€ ํ‰์ ์„ ๋ฐ›์€ TV ์‡ผ๋Š” ๋ฌด์—‡์ธ๊ฐ€์š”?",
        "schema": """๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์Šคํ‚ค๋งˆ:
ํ…Œ์ด๋ธ”: TV_Channel
์ปฌ๋Ÿผ:
- id (number) (๊ธฐ๋ณธ ํ‚ค)
- series_name (text)
- Country (text)
- Language (text)
- Content (text)
- Pixel_aspect_ratio_PAR (text)
- Hight_definition_TV (text)
- Pay_per_view_PPV (text)
- Package_Option (text)
ํ…Œ์ด๋ธ”: TV_series
์ปฌ๋Ÿผ:
- id (number)
- Episode (text)
- Air_Date (text)
- Rating (text)
- Share (text)
- 18_49_Rating_Share (text)
- Viewers_m (text)
- Weekly_Rank (number)
- Channel (number)
ํ…Œ์ด๋ธ”: Cartoon
์ปฌ๋Ÿผ:
- id (number) (๊ธฐ๋ณธ ํ‚ค)
- Title (text)
- Directed_by (text)
- Written_by (text)
- Original_air_date (text)
- Production_code (number)
- Channel (number)""",
    },
]
def generate_sql(question, db_id, schema_info):
    """Generate a SQL query for a Korean natural-language question.

    Args:
        question: The user's question in Korean.
        db_id: Database identifier (not used in the prompt; kept for API parity).
        schema_info: Human-readable schema description injected into the prompt.

    Returns:
        The extracted SQL statement (trailing semicolon stripped), or a Korean
        failure message if no SELECT/WITH statement is found in the output.
    """
    # Build the instruction prompt: Korean preamble + schema + question.
    prompt = f"""๋‹น์‹ ์€ ์ž์—ฐ์–ด ์งˆ๋ฌธ์„ SQL ์ฟผ๋ฆฌ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์ „๋ฌธ AI ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ์ž๊ฐ€ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ์ •๋ณด๋ฅผ ์–ป๊ธฐ ์œ„ํ•ด ์ผ์ƒ ์–ธ์–ด๋กœ ์งˆ๋ฌธํ•˜๋ฉด, ๋‹น์‹ ์€ ํ•ด๋‹น ์งˆ๋ฌธ์„ ์ •ํ™•ํ•œ SQL ์ฟผ๋ฆฌ๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
{schema_info}
์งˆ๋ฌธ: {question}
SQL:"""

    # Format as a single-turn chat and apply the model's chat template.
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Low-temperature sampling keeps the SQL output near-deterministic.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,
            temperature=0.1,
            top_p=0.95,
            do_sample=True,
            use_cache=True,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text generated after the prompt, when the prompt echoes.
    if prompt in response:
        sql_part = response.split(prompt)[-1].strip()
    else:
        sql_part = response

    # Some chat templates prefix the reply with the role name.
    if sql_part.startswith("assistant"):
        sql_part = sql_part[len("assistant"):].strip()

    # Extract the first SQL statement from the (possibly chatty) response.
    lines = sql_part.split('\n')
    sql_query = ""
    for i, line in enumerate(lines):
        line = line.strip()
        if line.lower().startswith(('select', 'with', '(select')):
            sql_query = line
            # Collect continuation lines until a blank line, a new prompt
            # section, or a terminating semicolon.
            # BUG FIX: the original used lines.index(line) with the *stripped*
            # line, which raises ValueError when the raw line had surrounding
            # whitespace, and returns the wrong position on duplicate lines.
            for next_line in lines[i + 1:]:
                next_line = next_line.strip()
                if not next_line or next_line.startswith(('์งˆ๋ฌธ', '๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค')):
                    break
                sql_query += " " + next_line
                if next_line.endswith(';'):
                    break
            break

    # Normalize: strip whitespace and a trailing semicolon.
    sql_query = sql_query.strip()
    if sql_query.endswith(';'):
        sql_query = sql_query[:-1]

    return sql_query if sql_query else "SQL ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
def process_question(question, db_id, custom_schema=None):
    """Translate a Korean question into a SQL query for the given database.

    A user-supplied schema takes precedence; otherwise the schema is looked
    up by db_id among the bundled examples. Returns a Korean error message
    for missing input, an unknown db_id, or a generation failure.
    """
    # Guard: both a question and a database ID are required.
    if not question or not db_id:
        return "์งˆ๋ฌธ๊ณผ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."

    # Prefer a non-blank custom schema; otherwise search the examples.
    if custom_schema and custom_schema.strip():
        schema_info = custom_schema
    else:
        schema_info = next(
            (ex["schema"] for ex in examples if ex["db_id"] == db_id),
            None,
        )

    if not schema_info:
        return "์Šคํ‚ค๋งˆ ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ์ปค์Šคํ…€ ์Šคํ‚ค๋งˆ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."

    # Delegate to the model; surface any failure as a readable message.
    try:
        return generate_sql(question, db_id, schema_info)
    except Exception as e:
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Spider SQL Generator - Korean", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ๐Ÿ•ท๏ธ Spider SQL Generator - Korean
    ํ•œ๊ตญ์–ด ์งˆ๋ฌธ์„ SQL ์ฟผ๋ฆฌ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” Llama 3.1 8B ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.
    ## ๐Ÿ“Š ์„ฑ๋Šฅ
    - **Exact Match**: 42.65%
    - **Execution Accuracy**: 65.47%
    - **Training**: Spider ๋ฐ์ดํ„ฐ์…‹ (ํ•œ๊ตญ์–ด ๋ฒˆ์—ญ)
    """)

    with gr.Row():
        with gr.Column():
            db_id_input = gr.Textbox(
                label="๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID",
                placeholder="์˜ˆ: department_management",
                value="department_management",
            )
            question_input = gr.Textbox(
                label="์งˆ๋ฌธ (ํ•œ๊ตญ์–ด)",
                placeholder="์˜ˆ: ๊ฐ ๋ถ€์„œ๋ณ„ ์ง์› ์ˆ˜๋ฅผ ๋ณด์—ฌ์ฃผ์„ธ์š”.",
                lines=2,
            )
            with gr.Accordion("์Šคํ‚ค๋งˆ ์ •๋ณด (์„ ํƒ์‚ฌํ•ญ)", open=False):
                schema_input = gr.Textbox(
                    label="์ปค์Šคํ…€ ์Šคํ‚ค๋งˆ",
                    placeholder="์ปค์Šคํ…€ ์Šคํ‚ค๋งˆ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”. ๋น„์›Œ๋‘๋ฉด ์˜ˆ์ œ ์Šคํ‚ค๋งˆ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.",
                    lines=10,
                )
            submit_btn = gr.Button("SQL ์ƒ์„ฑ", variant="primary", size="lg")

        with gr.Column():
            output = gr.Textbox(
                label="์ƒ์„ฑ๋œ SQL ์ฟผ๋ฆฌ",
                lines=4,
                elem_classes=["code"],
            )
            gr.Markdown("""
            ### ๐Ÿ’ก ์‚ฌ์šฉ ํŒ
            - ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID๋Š” ์˜ˆ์ œ์—์„œ ์„ ํƒํ•˜๊ฑฐ๋‚˜ ์ง์ ‘ ์ž…๋ ฅํ•˜์„ธ์š”
            - ์งˆ๋ฌธ์€ ํ•œ๊ตญ์–ด๋กœ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•˜์„ธ์š”
            - ์Šคํ‚ค๋งˆ ์ •๋ณด๋Š” ์„ ํƒ์‚ฌํ•ญ์ž…๋‹ˆ๋‹ค
            """)

    # Examples
    gr.Markdown("### ๐Ÿ“š ์˜ˆ์ œ (ํด๋ฆญํ•˜์—ฌ ์‚ฌ์šฉ)")
    # BUG FIX: the original passed inputs as [db_id, question, schema] while
    # fn=process_question expects (question, db_id, custom_schema), so running
    # an example would swap the question and database ID. The example tuples
    # and inputs now follow the function's parameter order.
    gr.Examples(
        examples=[
            [ex["question"], ex["db_id"], ex["schema"]] for ex in examples
        ],
        inputs=[question_input, db_id_input, schema_input],
        outputs=output,
        fn=process_question,
        cache_examples=False,
    )

    # Submit action
    submit_btn.click(
        fn=process_question,
        inputs=[question_input, db_id_input, schema_input],
        outputs=output,
    )

    # Model info
    gr.Markdown(f"""
    ---
    ### ๐Ÿค– ๋ชจ๋ธ ์ •๋ณด
    - **Hugging Face**: [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
    - **Base Model**: Llama 3.1 8B
    - **Fine-tuning**: LoRA with Unsloth
    - **Dataset**: Spider (Korean translated)
    """)
# Launch the app
if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the
    # local server (NOTE(review): on Hugging Face Spaces this flag is
    # typically unnecessary — confirm the intended deployment target).
    demo.launch(share=True)