import gradio as gr import torch try: from unsloth import FastLanguageModel except ImportError: print("Unsloth가 설치되어 있지 않습니다. 설치 중...") import subprocess subprocess.check_call(["pip", "install", "unsloth"]) from unsloth import FastLanguageModel # Hugging Face에 업로드된 모델 사용 MODEL_NAME = "huggingface-KREW/Llama-3.1-8B-Spider-SQL-Ko" print(f"Loading model from Hugging Face: {MODEL_NAME}") # Unsloth를 사용하여 모델과 토크나이저 로드 try: model, tokenizer = FastLanguageModel.from_pretrained( model_name=MODEL_NAME, max_seq_length=2048, dtype=None, # 자동 감지 load_in_4bit=True, # 4비트 양자화 사용 ) # 추론 모드로 설정 FastLanguageModel.for_inference(model) print("Model loaded successfully with Unsloth!") except Exception as e: print(f"Error loading model with Unsloth: {e}") print("\n모델이 Hugging Face에 제대로 업로드되지 않았을 수 있습니다.") print("로컬 모델을 사용하거나 모델을 다시 업로드해주세요.") raise # Example databases and questions examples = [ { "db_id": "department_management", "question": "각 부서별 직원 수를 보여주세요.", "schema": """데이터베이스 스키마: 테이블: department 컬럼: - Department_ID (number) (기본 키) - Name (text) - Creation (text) - Ranking (number) - Budget_in_Billions (number) - Num_Employees (number) 테이블: head 컬럼: - head_ID (number) (기본 키) - name (text) - born_state (text) - age (number) 테이블: management 컬럼: - department_ID (number) (기본 키) - head_ID (number) - temporary_acting (text) 외래 키 관계: - management.head_ID → head.head_ID - management.department_ID → department.Department_ID""" }, { "db_id": "concert_singer", "question": "가장 많은 콘서트를 연 가수는 누구인가요?", "schema": """데이터베이스 스키마: 테이블: singer 컬럼: - Singer_ID (number) (기본 키) - Name (text) - Country (text) - Song_Name (text) - Song_release_year (text) - Age (number) - Is_male (text) 테이블: concert 컬럼: - concert_ID (number) (기본 키) - concert_Name (text) - Theme (text) - Stadium_ID (number) - Year (text) 테이블: singer_in_concert 컬럼: - concert_ID (number) - Singer_ID (number) 외래 키 관계: - singer_in_concert.Singer_ID → singer.Singer_ID - singer_in_concert.concert_ID → concert.concert_ID""" }, { "db_id": "pets_1", "question": "가장 나이가 많은 개의 이름은 무엇인가요?", "schema": """데이터베이스 스키마: 테이블: Student 컬럼: - StuID (number) (기본 키) - LName (text) - Fname (text) - Age (number) - Sex (text) - Major (number) - Advisor (number) - city_code (text) 테이블: Has_Pet 컬럼: - StuID (number) - PetID (number) 테이블: Pets 컬럼: - PetID (number) (기본 키) - PetType (text) - pet_age (number) - weight (number)""" }, { "db_id": "car_1", "question": "미국산 자동차 중 가장 빠른 자동차는 무엇인가요?", "schema": """데이터베이스 스키마: 테이블: continents 컬럼: - ContId (number) (기본 키) - Continent (text) 테이블: countries 컬럼: - CountryId (number) (기본 키) - CountryName (text) - Continent (number) 테이블: car_makers 컬럼: - Id (number) (기본 키) - Maker (text) - FullName (text) - Country (number) 테이블: model_list 컬럼: - ModelId (number) (기본 키) - Maker (number) - Model (text) 테이블: car_names 컬럼: - MakeId (number) (기본 키) - Model (text) - Make (text) 테이블: cars_data 컬럼: - Id (number) (기본 키) - MPG (text) - Cylinders (number) - Edispl (text) - Horsepower (text) - Weight (number) - Accelerate (number) - Year (number)""" }, { "db_id": "tvshow", "question": "가장 높은 평점을 받은 TV 쇼는 무엇인가요?", "schema": """데이터베이스 스키마: 테이블: TV_Channel 컬럼: - id (number) (기본 키) - series_name (text) - Country (text) - Language (text) - Content (text) - Pixel_aspect_ratio_PAR (text) - Hight_definition_TV (text) - Pay_per_view_PPV (text) - Package_Option (text) 테이블: TV_series 컬럼: - id (number) - Episode (text) - Air_Date (text) - Rating (text) - Share (text) - 18_49_Rating_Share (text) - Viewers_m (text) - Weekly_Rank (number) - Channel (number) 테이블: Cartoon 컬럼: - id (number) (기본 키) - Title (text) - Directed_by (text) - Written_by (text) - Original_air_date (text) - Production_code (number) - Channel (number)""" } ] def generate_sql(question, db_id, schema_info): """Generate SQL query using the model.""" # Create prompt with schema prompt = f"""당신은 자연어 질문을 SQL 쿼리로 변환하는 전문 AI 어시스턴트입니다. 사용자가 데이터베이스에서 정보를 얻기 위해 일상 언어로 질문하면, 당신은 해당 질문을 정확한 SQL 쿼리로 변환해야 합니다. {schema_info} 질문: {question} SQL:""" # 채팅 메시지로 포맷팅 messages = [{"role": "user", "content": prompt}] # 채팅 템플릿 적용 inputs = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to(model.device) # Generate with torch.no_grad(): outputs = model.generate( inputs, max_new_tokens=256, temperature=0.1, top_p=0.95, do_sample=True, use_cache=True ) # Decode response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract SQL after the prompt if prompt in response: sql_part = response.split(prompt)[-1].strip() else: sql_part = response # Clean up the response if sql_part.startswith("assistant"): sql_part = sql_part[len("assistant"):].strip() # Extract SQL query lines = sql_part.split('\n') sql_query = "" for line in lines: line = line.strip() if line.lower().startswith(('select', 'with', '(select')): sql_query = line # Continue collecting lines until we hit a semicolon or empty line for next_line in lines[lines.index(line)+1:]: next_line = next_line.strip() if not next_line or next_line.startswith(('질문', '데이터베이스')): break sql_query += " " + next_line if next_line.endswith(';'): break break # Clean up SQL sql_query = sql_query.strip() if sql_query.endswith(';'): sql_query = sql_query[:-1] return sql_query if sql_query else "SQL 생성에 실패했습니다." def process_question(question, db_id, custom_schema=None): """Process user question and generate SQL query.""" if not question or not db_id: return "질문과 데이터베이스 ID를 입력해주세요." # Use custom schema if provided, otherwise find from examples if custom_schema and custom_schema.strip(): schema_info = custom_schema else: # Find schema from examples schema_info = None for example in examples: if example["db_id"] == db_id: schema_info = example["schema"] break if not schema_info: return "스키마 정보를 찾을 수 없습니다. 커스텀 스키마를 입력해주세요." # Generate SQL try: sql_query = generate_sql(question, db_id, schema_info) return sql_query except Exception as e: return f"오류 발생: {str(e)}" # Create Gradio interface with gr.Blocks(title="Spider SQL Generator - Korean", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🕷️ Spider SQL Generator - Korean 한국어 질문을 SQL 쿼리로 변환하는 Llama 3.1 8B 모델입니다. ## 📊 성능 - **Exact Match**: 42.65% - **Execution Accuracy**: 65.47% - **Training**: Spider 데이터셋 (한국어 번역) """) with gr.Row(): with gr.Column(): db_id_input = gr.Textbox( label="데이터베이스 ID", placeholder="예: department_management", value="department_management" ) question_input = gr.Textbox( label="질문 (한국어)", placeholder="예: 각 부서별 직원 수를 보여주세요.", lines=2 ) with gr.Accordion("스키마 정보 (선택사항)", open=False): schema_input = gr.Textbox( label="커스텀 스키마", placeholder="커스텀 스키마를 입력하세요. 비워두면 예제 스키마를 사용합니다.", lines=10 ) submit_btn = gr.Button("SQL 생성", variant="primary", size="lg") with gr.Column(): output = gr.Textbox( label="생성된 SQL 쿼리", lines=4, elem_classes=["code"] ) gr.Markdown(""" ### 💡 사용 팁 - 데이터베이스 ID는 예제에서 선택하거나 직접 입력하세요 - 질문은 한국어로 자연스럽게 작성하세요 - 스키마 정보는 선택사항입니다 """) # Examples gr.Markdown("### 📚 예제 (클릭하여 사용)") gr.Examples( examples=[ [ex["db_id"], ex["question"], ex["schema"]] for ex in examples ], inputs=[db_id_input, question_input, schema_input], outputs=output, fn=process_question, cache_examples=False ) # Submit action submit_btn.click( fn=process_question, inputs=[question_input, db_id_input, schema_input], outputs=output ) # Model info gr.Markdown(f""" --- ### 🤖 모델 정보 - **Hugging Face**: [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) - **Base Model**: Llama 3.1 8B - **Fine-tuning**: LoRA with Unsloth - **Dataset**: Spider (Korean translated) """) # Launch the app if __name__ == "__main__": demo.launch(share=True)