File size: 2,232 Bytes
2c5f455 baaabde 2c5f455 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# vectordb_relank_law.py
import faiss
import numpy as np
import os
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from retriever.reranker import rerank_documents
# chroma vector config v2
# Candidate Korean sentence-embedding models (Hugging Face model IDs).
# EMBEDDING_MODEL_NAME below picks one of them.
embedding_models = [
    "upskyy/bge-m3-korean",
    "jhgan/ko-sbert-sts",
    "BM-K/KoSimCSE-roberta",
    "BM-K/KoSimCSE-v2-multitask",
    "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
    "beomi/KcELECTRA-small-v2022",
]

# law_db config v2
# NOTE(review): abspath() resolves against the current working directory, not
# this file's location, so the script must be launched from the project root.
CHROMA_PATH = os.path.abspath("data/index/law_db")
COLLECTION_NAME = "law_all"
EMBEDDING_MODEL_NAME = embedding_models[0]  # select the embedding model to use

# 1. Load the embedding model (v2).
# NOTE(review): nothing in this module's visible code uses `embedding_model`;
# embedding_fn below is built from the model *name* and appears to load its
# own copy. Kept because other modules may import this name.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# 2. Configure the embedding function Chroma uses to embed query texts.
#    It must match the model the collection was indexed with.
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=EMBEDDING_MODEL_NAME
)

# 3. Load the Chroma client and collection.
#    get_collection (unlike get_or_create_collection) does not create the
#    collection, so the index is presumably built beforehand elsewhere.
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_fn)
# 4. Search function
def search_documents(query: str, top_k: int = 5) -> list:
    """Search the law collection for *query*, rerank the hits, and print them.

    The ``top_k`` nearest chunks are fetched from the module-level Chroma
    ``collection``. They are reordered with ``rerank_documents``, and each hit
    is printed with its metadata and a score derived from its vector distance.

    Args:
        query: Free-text search query.
        top_k: Number of candidates fetched from Chroma. It is also passed to
            the reranker as its ``top_k``.

    Returns:
        A list of ``(document, metadata, distance)`` tuples in reranked order.
        The list is empty when the collection returns no hits.

    Raises:
        ValueError: if the reranker returns a document text that was not among
            the retrieved candidates.
    """
    # Local import so this block does not depend on a file-level import.
    from collections import deque

    # NOTE(review): the non-ASCII string literals in this function look
    # mojibake-damaged (UTF-8 read as TIS-620); restore them from the original.
    print(f"\n๐ ๊ฒ์์ด: '{query}'")
    results = collection.query(
        query_texts=[query],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    # Chroma returns one result list per query text; exactly one was sent.
    docs = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    if not docs:
        # Nothing retrieved: skip the reranker rather than hand it an empty list.
        return []

    reranked_docs = rerank_documents(query, docs, top_k=top_k)

    # The reranker returns document texts only, so map each text back to its
    # original position(s) to carry metadata and distance along. Using one
    # index queue per text has two benefits:
    #   - duplicate chunks keep their own metadata (docs.index() would always
    #     return the first match);
    #   - it avoids an O(n^2) scan.
    positions = {}
    for idx, doc in enumerate(docs):
        positions.setdefault(doc, deque()).append(idx)

    reranked_data = []
    for doc in reranked_docs:
        queue = positions.get(doc)
        if queue:
            idx = queue.popleft()
        else:
            # Unknown text, or more repeats than were retrieved: fall back to
            # the first-match lookup (raises ValueError for unknown text).
            idx = docs.index(doc)
        reranked_data.append((doc, metadatas[idx], distances[idx]))

    # NOTE(review): `1 - dist` is a similarity only when the collection uses
    # cosine distance; Chroma's default space is L2. Confirm `hnsw:space`.
    for i, (doc, meta, dist) in enumerate(reranked_data):
        print(f"\n๐ ๊ฒฐ๊ณผ {i+1} (์ ์ฌ๋: {1 - dist:.2f})")
        print(f"๋ฌธ์: {doc[:150]}...")
        print("๋ฉํ๋ฐ์ดํฐ:")
        print(meta)

    return reranked_data
|