Spaces:
Sleeping
Sleeping
import asyncio | |
import time | |
from typing import List | |
from service.rerank import RerankService | |
from search_service.base_search import BaseSearchService | |
from utils.bio_logger import bio_logger as logger | |
from dto.bio_document import BaseBioDocument | |
from bio_requests.rag_request import RagRequest | |
class RagService:
    """Fan a RAG request out to every search service, merge the hits, optionally rerank."""

    def __init__(self):
        self.rerank_service = RerankService()
        # Instantiate one of every subclass reported by BaseSearchService.get_subclasses()
        # (original intent: "make sure all subclasses are loaded").
        # NOTE(review): presumably only subclasses whose modules were already imported are
        # returned here — confirm they are imported before RagService is constructed.
        self.search_services = [
            subclass() for subclass in BaseSearchService.get_subclasses()
        ]
        logger.info(
            f"Loaded search services: {[service.__class__.__name__ for service in self.search_services]}"
        )

    async def multi_query(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Query all search services concurrently and return the best documents.

        A search service that fails is logged and skipped, so the remaining services
        still contribute results.

        Args:
            rag_request: Forwarded to every service's ``filter_search`` and, when
                ``rag_request.is_rerank`` is truthy, to ``RerankService.rerank``.

        Returns:
            The merged (and, if requested, reranked) documents, sliced to the first
            ``rag_request.top_k`` items.
        """
        # perf_counter is monotonic, unlike time.time(), so durations cannot go negative
        # or jump when the wall clock is adjusted.
        start_time = time.perf_counter()
        batch_search = [
            service.filter_search(rag_request=rag_request)
            for service in self.search_services
        ]
        task_result = await asyncio.gather(*batch_search, return_exceptions=True)

        all_results = []
        # gather() returns results in the same order as its awaitables, so zipping with
        # self.search_services attributes each failure to the right service.
        for service, result in zip(self.search_services, task_result):
            # With return_exceptions=True, gather also returns BaseException subclasses
            # such as asyncio.CancelledError; testing only for Exception would pass them
            # on to extend() and crash the whole query with a TypeError.
            if isinstance(result, BaseException):
                # repr() keeps the exception type visible even when its message is empty.
                logger.error(
                    f"Error in search service {service.__class__.__name__}: {result!r}"
                )
                continue
            all_results.extend(result)

        end_search_time = time.perf_counter()
        logger.info(
            f"Found {len(all_results)} results in total,time used:{end_search_time - start_time:.2f}s"
        )

        if rag_request.is_rerank:
            logger.info("RerankService: is_rerank is True")
            reranked_results = await self.rerank_service.rerank(
                rag_request=rag_request, documents=all_results
            )
            end_rerank_time = time.perf_counter()
            logger.info(
                f"Reranked {len(reranked_results)} results,time used:{end_rerank_time - end_search_time:.2f}s"
            )
        else:
            logger.info("RerankService: is_rerank is False, skip rerank")
            reranked_results = all_results

        return reranked_results[0 : rag_request.top_k]