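"""Gradio demo: convert Korean natural-language questions into SQL queries.

Loads huggingface-KREW/Llama-3.1-8B-Spider-SQL-Ko (a Llama 3.1 8B fine-tune on
a Korean translation of the Spider dataset) in 4-bit via Unsloth and exposes a
simple text-to-SQL interface.
"""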
import gradio as gr
import torch

try:
    from unsloth import FastLanguageModel
except ImportError:
    # Install Unsloth on the fly if it is missing (e.g., on a fresh environment).
    print("Unsloth is not installed. Installing...")
    import subprocess
    subprocess.check_call(["pip", "install", "unsloth"])
    from unsloth import FastLanguageModel

MODEL_NAME = "huggingface-KREW/Llama-3.1-8B-Spider-SQL-Ko"

print(f"Loading model from Hugging Face: {MODEL_NAME}")

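# Assumed runtime note: load_in_4bit=True depends on bitsandbytes, which
# typically requires a CUDA-capable GPU. On CPU-only machines, consider
# loading the model without 4-bit quantization.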
try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        dtype=None,          # let Unsloth choose the dtype for the hardware
        load_in_4bit=True,   # 4-bit quantization keeps the 8B model on a small GPU
    )

    # Switch the model into Unsloth's faster inference mode.
    FastLanguageModel.for_inference(model)
    print("Model loaded successfully with Unsloth!")

except Exception as e:
    print(f"Error loading model with Unsloth: {e}")
    print("\nThe model may not have been uploaded to Hugging Face correctly.")
    print("Use a local copy of the model or re-upload it.")
    raise

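# Example databases from the Spider benchmark. Each entry holds a db_id, a
# sample natural-language question, and the plain-text schema block that is
# inserted into the prompt (tables, columns, primary and foreign keys).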
examples = [
    {
        "db_id": "department_management",
        "question": "Show the number of employees for each department.",
        "schema": """Database schema:
Table: department
Columns:
- Department_ID (number) (primary key)
- Name (text)
- Creation (text)
- Ranking (number)
- Budget_in_Billions (number)
- Num_Employees (number)
Table: head
Columns:
- head_ID (number) (primary key)
- name (text)
- born_state (text)
- age (number)
Table: management
Columns:
- department_ID (number) (primary key)
- head_ID (number)
- temporary_acting (text)

Foreign key relationships:
- management.head_ID → head.head_ID
- management.department_ID → department.Department_ID"""
    },
    {
        "db_id": "concert_singer",
        "question": "Which singer performed in the most concerts?",
        "schema": """Database schema:
Table: singer
Columns:
- Singer_ID (number) (primary key)
- Name (text)
- Country (text)
- Song_Name (text)
- Song_release_year (text)
- Age (number)
- Is_male (text)
Table: concert
Columns:
- concert_ID (number) (primary key)
- concert_Name (text)
- Theme (text)
- Stadium_ID (number)
- Year (text)
Table: singer_in_concert
Columns:
- concert_ID (number)
- Singer_ID (number)

Foreign key relationships:
- singer_in_concert.Singer_ID → singer.Singer_ID
- singer_in_concert.concert_ID → concert.concert_ID"""
    },
    {
        "db_id": "pets_1",
        "question": "What is the name of the oldest dog?",
        "schema": """Database schema:
Table: Student
Columns:
- StuID (number) (primary key)
- LName (text)
- Fname (text)
- Age (number)
- Sex (text)
- Major (number)
- Advisor (number)
- city_code (text)
Table: Has_Pet
Columns:
- StuID (number)
- PetID (number)
Table: Pets
Columns:
- PetID (number) (primary key)
- PetType (text)
- pet_age (number)
- weight (number)"""
    },
    {
        "db_id": "car_1",
        "question": "Among American-made cars, which one is the fastest?",
        "schema": """Database schema:
Table: continents
Columns:
- ContId (number) (primary key)
- Continent (text)
Table: countries
Columns:
- CountryId (number) (primary key)
- CountryName (text)
- Continent (number)
Table: car_makers
Columns:
- Id (number) (primary key)
- Maker (text)
- FullName (text)
- Country (number)
Table: model_list
Columns:
- ModelId (number) (primary key)
- Maker (number)
- Model (text)
Table: car_names
Columns:
- MakeId (number) (primary key)
- Model (text)
- Make (text)
Table: cars_data
Columns:
- Id (number) (primary key)
- MPG (text)
- Cylinders (number)
- Edispl (text)
- Horsepower (text)
- Weight (number)
- Accelerate (number)
- Year (number)"""
    },
    {
        "db_id": "tvshow",
        "question": "Which TV show received the highest rating?",
        "schema": """Database schema:
Table: TV_Channel
Columns:
- id (number) (primary key)
- series_name (text)
- Country (text)
- Language (text)
- Content (text)
- Pixel_aspect_ratio_PAR (text)
- Hight_definition_TV (text)
- Pay_per_view_PPV (text)
- Package_Option (text)
Table: TV_series
Columns:
- id (number)
- Episode (text)
- Air_Date (text)
- Rating (text)
- Share (text)
- 18_49_Rating_Share (text)
- Viewers_m (text)
- Weekly_Rank (number)
- Channel (number)
Table: Cartoon
Columns:
- id (number) (primary key)
- Title (text)
- Directed_by (text)
- Written_by (text)
- Original_air_date (text)
- Production_code (number)
- Channel (number)"""
    }
]

def generate_sql(question, db_id, schema_info):
    """Generate a SQL query for the given question and schema using the model."""

    prompt = f"""You are an expert AI assistant that converts natural-language questions into SQL queries. When a user asks a question in everyday language to get information from a database, you must convert that question into an accurate SQL query.

{schema_info}

Question: {question}
SQL:"""

    messages = [{"role": "user", "content": prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,
            temperature=0.1,   # near-greedy sampling for more deterministic SQL
            top_p=0.95,
            do_sample=True,
            use_cache=True
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip the echoed prompt so only the newly generated text remains.
    if prompt in response:
        sql_part = response.split(prompt)[-1].strip()
    else:
        sql_part = response

    # Some chat templates prepend the assistant role name to the completion.
    if sql_part.startswith("assistant"):
        sql_part = sql_part[len("assistant"):].strip()

    # Keep only the first SQL statement: start at a SELECT/WITH line and stop
    # at a blank line, a new "Question:"/"Database" marker, or a semicolon.
    lines = sql_part.split('\n')
    sql_query = ""
    for i, line in enumerate(lines):
        line = line.strip()
        if line.lower().startswith(('select', 'with', '(select')):
            sql_query = line

            for next_line in lines[i + 1:]:
                next_line = next_line.strip()
                if not next_line or next_line.startswith(('Question', 'Database')):
                    break
                sql_query += " " + next_line
                if next_line.endswith(';'):
                    break
            break

    # Drop a trailing semicolon for consistency.
    sql_query = sql_query.strip()
    if sql_query.endswith(';'):
        sql_query = sql_query[:-1]

    return sql_query if sql_query else "Failed to generate SQL."

def process_question(question, db_id, custom_schema=None):
    """Process a user question and generate a SQL query."""
    if not question or not db_id:
        return "Please enter a question and a database ID."

    # Prefer a user-supplied schema; otherwise fall back to the bundled
    # example schema for the given db_id.
    if custom_schema and custom_schema.strip():
        schema_info = custom_schema
    else:
        schema_info = None
        for example in examples:
            if example["db_id"] == db_id:
                schema_info = example["schema"]
                break

    if not schema_info:
        return "Schema information not found. Please provide a custom schema."

    try:
        sql_query = generate_sql(question, db_id, schema_info)
        return sql_query
    except Exception as e:
        return f"An error occurred: {str(e)}"

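# Gradio UI: model description at the top, inputs (database ID, Korean
# question, optional schema override) on the left, the generated SQL on the
# right, followed by clickable examples and model details.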
with gr.Blocks(title="Spider SQL Generator - Korean", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Spider SQL Generator - Korean

A Llama 3.1 8B model that converts Korean questions into SQL queries.

## Performance
- **Exact Match**: 42.65%
- **Execution Accuracy**: 65.47%
- **Training**: Spider dataset (Korean translation)
""")

    with gr.Row():
        with gr.Column():
            db_id_input = gr.Textbox(
                label="Database ID",
                placeholder="e.g., department_management",
                value="department_management"
            )

            question_input = gr.Textbox(
                label="Question (Korean)",
                placeholder="e.g., Show the number of employees for each department.",
                lines=2
            )

            with gr.Accordion("Schema information (optional)", open=False):
                schema_input = gr.Textbox(
                    label="Custom schema",
                    placeholder="Enter a custom schema. If left empty, the example schema is used.",
                    lines=10
                )

            submit_btn = gr.Button("Generate SQL", variant="primary", size="lg")

        with gr.Column():
            output = gr.Textbox(
                label="Generated SQL query",
                lines=4,
                elem_classes=["code"]
            )

            gr.Markdown("""
### Usage tips
- Pick a database ID from the examples or enter one directly
- Write your question naturally in Korean
- The schema information is optional
""")

    gr.Markdown("### Examples (click to use)")
    gr.Examples(
        # Rows follow process_question's argument order (question, db_id,
        # schema), matching the inputs list below.
        examples=[
            [ex["question"], ex["db_id"], ex["schema"]] for ex in examples
        ],
        inputs=[question_input, db_id_input, schema_input],
        outputs=output,
        fn=process_question,
        cache_examples=False
    )

    submit_btn.click(
        fn=process_question,
        inputs=[question_input, db_id_input, schema_input],
        outputs=output
    )

    gr.Markdown(f"""
---
### Model information
- **Hugging Face**: [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
- **Base Model**: Llama 3.1 8B
- **Fine-tuning**: LoRA with Unsloth
- **Dataset**: Spider (Korean translated)
""")

if __name__ == "__main__":
    # share=True serves the demo locally and also creates a temporary public
    # Gradio share link.
    demo.launch(share=True)