File size: 1,650 Bytes
c1f976c
 
 
 
336f693
 
c1f976c
 
 
a91f908
d3411be
a91f908
 
 
 
 
 
 
 
d3411be
 
 
c1f976c
d3411be
2e53579
 
f0633ef
2a32abb
ccfdfe2
 
 
 
d3411be
ccfdfe2
d3411be
 
 
 
 
ccfdfe2
 
d3411be
 
 
 
 
 
 
 
 
ccfdfe2
d3411be
c1f976c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# from retriever.vectordb_rerank import search_documents  # 🧠 RAG 검색기 불러오기
from services.rag_pipeline import rag_pipeline

model_name = "dasomaru/gemma-3-4bit-it-demo"


# 1. 모델/토크나이저 1회 로딩
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# 🚀 model은 CPU로만 먼저 올림 (GPU 아직 없음)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # 4bit model이니까
    trust_remote_code=True,
)

# 2. 캐시 관리
search_cache = {}

@spaces.GPU(duration=300)
def generate_response(query: str):
    tokenizer = AutoTokenizer.from_pretrained("dasomaru/gemma-3-4bit-it-demo")
    model = AutoModelForCausalLM.from_pretrained("dasomaru/gemma-3-4bit-it-demo")
    model.to("cuda")    

    if query in search_cache:
        print(f"⚡ 캐시 사용: '{query}'")
        return search_cache[query]

    # 🔥 rag_pipeline을 호출해서 검색 + 생성
    results = rag_pipeline(query)

    # 결과가 list일 경우 합치기
    if isinstance(results, list):
        results = "\n\n".join(results)

    search_cache[query] = results
    return results
    
# 3. Gradio 인터페이스
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="질문을 입력하세요"),
    outputs="text",
    title="Law RAG Assistant",
    description="법령 기반 RAG 파이프라인 테스트",
)

# demo.launch(server_name="0.0.0.0", server_port=7860)  # 🚀 API 배포 준비 가능
demo.launch()