|
|
|
import faiss
|
|
import numpy as np
|
|
import os
|
|
from chromadb import PersistentClient
|
|
from chromadb.utils import embedding_functions
|
|
from sentence_transformers import SentenceTransformer
|
|
from retriever.reranker import rerank_documents
|
|
|
|
|
|
# Candidate sentence-embedding checkpoints. The first entry is the one
# selected below as EMBEDDING_MODEL_NAME; reorder the list to switch models.
embedding_models = [
    "upskyy/bge-m3-korean",
    "jhgan/ko-sbert-sts",
    "BM-K/KoSimCSE-roberta",
    "BM-K/KoSimCSE-v2-multitask",
    "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
    "beomi/KcELECTRA-small-v2022",
]

# Absolute path of the persistent Chroma store, resolved against the current
# working directory at import time.
CHROMA_PATH = os.path.abspath(os.path.join("data", "index", "law_db"))

# Name of the Chroma collection that is queried.
COLLECTION_NAME = "law_all"

# Embedding checkpoint currently in use.
EMBEDDING_MODEL_NAME = embedding_models[0]
|
|
|
|
|
|
|
|
|
|
# Stand-alone SentenceTransformer instance for the selected checkpoint.
# NOTE(review): nothing in the visible code reads this name; it is kept
# because other modules may import it — confirm before removing.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Embedding function handed to Chroma so that text queries are embedded with
# the same checkpoint name.
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=EMBEDDING_MODEL_NAME,
)

# Open the on-disk Chroma store and look up the existing collection.
# get_collection (not get_or_create) is used, so the collection is expected
# to have been built beforehand.
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(
    name=COLLECTION_NAME,
    embedding_function=embedding_fn,
)
|
|
|
|
|
|
|
|
|
|
def search_documents(query: str, top_k: int = 5):
    """Query the Chroma collection, rerank the hits, print and return them.

    The collection is asked for the ``top_k`` nearest documents to ``query``.
    The retrieved texts are passed to ``rerank_documents`` and each reranked
    text is paired again with the metadata and distance it was retrieved with.

    Args:
        query: Free-text search query.
        top_k: Number of results requested from the collection and passed on
            to the reranker.

    Returns:
        A list of ``(document, metadata, distance)`` tuples in reranked order.

    Raises:
        ValueError: If the reranker returns a document that was not among the
            retrieved documents.
    """
    print(f"\n๐ ๊ฒ์์ด: '{query}'")

    results = collection.query(
        query_texts=[query],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    # Only one query text is sent, so only the first result list is relevant.
    docs = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    reranked_docs = rerank_documents(query, docs, top_k=top_k)

    # Record every position at which each document text was retrieved.
    # A plain ``docs.index(doc)`` always yields the first position, so
    # identical texts would all receive the first hit's metadata/distance.
    # NOTE(review): assumes the reranker returns the same (hashable) text
    # objects it was given — confirm against retriever.reranker.
    positions = {}
    for idx, doc in enumerate(docs):
        positions.setdefault(doc, []).append(idx)

    consumed = {}
    reranked_data = []
    for doc in reranked_docs:
        occurrences = positions.get(doc)
        if occurrences is None:
            raise ValueError("reranked document was not among the retrieved documents")
        # Hand out occurrences in retrieval order; if the reranker repeats a
        # text more often than it was retrieved, reuse the last occurrence.
        used = consumed.get(doc, 0)
        idx = occurrences[min(used, len(occurrences) - 1)]
        consumed[doc] = used + 1
        reranked_data.append((doc, metadatas[idx], distances[idx]))

    for rank, (doc, meta, dist) in enumerate(reranked_data, start=1):
        # ``1 - dist`` is shown as a similarity score.
        # NOTE(review): only meaningful if the collection uses a distance for
        # which that holds (e.g. cosine) — confirm the collection's metric.
        print(f"\n๐ ๊ฒฐ๊ณผ {rank} (์ ์ฌ๋: {1 - dist:.2f})")
        print(f"๋ฌธ์: {doc[:150]}...")
        print("๋ฉํ๋ฐ์ดํฐ:")
        print(meta)

    return reranked_data
|
|
|