# Source: gemma/services/rag_pipeline.py
# Uploaded by dasomaru via huggingface_hub (commit 9b14ff1, verified), 678 bytes.
# from retriever.vectordb import search_documents
from retriever.vectordb_rerank import search_documents
from generator.prompt_builder import build_prompt
from generator.llm_inference import generate_answer
def rag_pipeline(query: str, top_k: int = 5) -> str:
    """Run the retrieve -> build-prompt -> generate chain for one query.

    The steps, in order:
      1. Look up documents for ``query`` via ``search_documents``.
      2. Combine ``query`` and those documents with ``build_prompt``.
      3. Pass the resulting prompt to ``generate_answer``.

    Args:
        query: The user question; given to both the retrieval step and the
            prompt builder.
        top_k: Forwarded to ``search_documents`` as its ``top_k`` keyword
            argument (default 5).

    Returns:
        The value produced by ``generate_answer`` for the built prompt.
    """
    # Step 1: retrieval.
    retrieved = search_documents(query, top_k=top_k)
    # Steps 2 and 3: assemble the prompt, then run model inference on it.
    return generate_answer(build_prompt(query, retrieved))