File size: 1,039 Bytes
9b14ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# 1. Load the reranker: a tokenizer plus a sequence-classification model,
#    both taken from the same Hugging Face checkpoint.
_RERANKER_CHECKPOINT = "BAAI/bge-reranker-base"
reranker_tokenizer = AutoTokenizer.from_pretrained(_RERANKER_CHECKPOINT)
reranker_model = AutoModelForSequenceClassification.from_pretrained(_RERANKER_CHECKPOINT)

def rerank_documents(query: str, docs: list, top_k: int = 5, max_length: int = 512) -> list:
    """
    Re-order retrieved documents by their relevance to ``query``.

    Each (query, doc) pair is scored with the module-level ``reranker_model`` /
    ``reranker_tokenizer`` cross-encoder, and documents are returned from the
    highest score to the lowest.

    Args:
        query: The search query.
        docs: Candidate documents (text) to rerank. Any iterable is accepted.
        top_k: Maximum number of documents to return (applied as a slice).
        max_length: Token limit per (query, doc) pair; longer pairs are
            truncated. Defaults to 512, the previously hard-coded value.

    Returns:
        Up to ``top_k`` documents, highest score first. Documents with equal
        scores keep their input order. An empty ``docs`` yields ``[]``.
    """
    docs = list(docs)
    # Nothing to score: do not run the tokenizer/model on an empty batch.
    if not docs:
        return []

    # Encode all pairs as one padded batch: the query is the first segment and
    # the document the second. Calling the tokenizer directly replaces the
    # deprecated ``batch_encode_plus``.
    inputs = reranker_tokenizer(
        [query] * len(docs),
        text_pair=docs,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    )
    # Put the batch on the model's device so this keeps working if the model
    # is moved (e.g. to a GPU) after it is loaded.
    inputs = inputs.to(reranker_model.device)

    with torch.no_grad():
        # Squeeze the trailing logit dimension -> one score per pair, (batch_size,).
        scores = reranker_model(**inputs).logits.squeeze(-1).tolist()

    # Sort indices by score, highest first. Python's sort is stable (also with
    # reverse=True), so ties keep their original input order.
    ranked = sorted(range(len(docs)), key=scores.__getitem__, reverse=True)
    return [docs[i] for i in ranked[:top_k]]