|
import os |
|
|
|
from huggingface_hub import InferenceClient |
|
|
|
from rag_demo.rag.base.query import Query |
|
from rag_demo.rag.base.template_factory import RAGStep |
|
from rag_demo.preprocessing.embed import EmbeddedChunk |
|
|
|
|
|
class Reranker(RAGStep):
    """Rerank retrieved chunks by semantic similarity to the query.

    Uses the Hugging Face Inference API with a multilingual E5 instruct
    model to score each chunk's content against the query, then keeps the
    top-scoring chunks.
    """

    # Sentence-similarity model served via the HF Inference API.
    _MODEL_ID = "intfloat/multilingual-e5-large-instruct"

    def generate(
        self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        """Score ``chunks`` against ``query`` and return the best ``keep_top_k``.

        Args:
            query: The user query whose ``content`` is compared against chunks.
            chunks: Candidate chunks; each gets a ``similarity`` attribute set
                as a side effect of this call.
            keep_top_k: Maximum number of chunks to return.

        Returns:
            The ``keep_top_k`` chunks with the highest similarity, sorted in
            descending order of similarity. Empty list if ``chunks`` is empty.
        """
        # Avoid a pointless API round-trip (which may also error) on empty input.
        if not chunks:
            return []

        # Lazily build and cache the client so repeated calls reuse one
        # connection instead of constructing a new InferenceClient each time.
        client = getattr(self, "_client", None)
        if client is None:
            client = InferenceClient(
                model=self._MODEL_ID,
                token=os.getenv("HF_API_TOKEN"),
            )
            self._client = client

        scores = client.sentence_similarity(
            query.content, [chunk.content for chunk in chunks]
        )

        # Attach the score to each chunk so downstream steps can inspect it.
        for chunk, score in zip(chunks, scores):
            chunk.similarity = score

        return sorted(chunks, key=lambda c: c.similarity, reverse=True)[:keep_top_k]
|
|