# jackkuo's picture
# add QA
# 79899c0
import asyncio
import time
from typing import List
from service.rerank import RerankService
from search_service.base_search import BaseSearchService
from utils.bio_logger import bio_logger as logger
from dto.bio_document import BaseBioDocument
from bio_requests.rag_request import RagRequest
class RagService:
    """Run one RAG request against every search service and optionally rerank.

    The service owns a ``RerankService`` and one instance of each class
    returned by ``BaseSearchService.get_subclasses()``.
    """

    def __init__(self):
        """Create the reranker and instantiate all available search services."""
        self.rerank_service = RerankService()
        # Instantiate every subclass reported by the base class so that all of
        # them take part in multi_query.
        self.search_services = [
            subclass() for subclass in BaseSearchService.get_subclasses()
        ]
        logger.info(
            f"Loaded search services: {[service.__class__.__name__ for service in self.search_services]}"
        )

    async def multi_query(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Query all search services concurrently and merge their results.

        Each service's ``filter_search`` is awaited through ``asyncio.gather``
        with ``return_exceptions=True``, so a failing service is logged and
        skipped instead of failing the whole request. When
        ``rag_request.is_rerank`` is truthy the merged list is passed to the
        rerank service; otherwise it is used as is.

        Args:
            rag_request: The request forwarded to every search service and to
                the reranker; ``is_rerank`` and ``top_k`` are read from it.

        Returns:
            The slice ``[0:rag_request.top_k]`` of the (possibly reranked)
            merged results.
        """
        # perf_counter is monotonic, so the logged durations cannot be skewed
        # by system clock adjustments.
        start_time = time.perf_counter()
        batch_search = [
            service.filter_search(rag_request=rag_request)
            for service in self.search_services
        ]
        task_result = await asyncio.gather(*batch_search, return_exceptions=True)

        all_results = []
        for result in task_result:
            # gather() may hand back asyncio.CancelledError, which derives from
            # BaseException (not Exception) on Python 3.8+. Checking against
            # BaseException keeps such entries out of extend(), which would
            # otherwise raise TypeError and abort the whole query.
            if isinstance(result, BaseException):
                logger.error(f"Error in search service: {result}")
                continue
            all_results.extend(result)

        end_search_time = time.perf_counter()
        logger.info(
            f"Found {len(all_results)} results in total,time used:{end_search_time - start_time:.2f}s"
        )

        if rag_request.is_rerank:
            logger.info("RerankService: is_rerank is True")
            reranked_results = await self.rerank_service.rerank(
                rag_request=rag_request, documents=all_results
            )
            end_rerank_time = time.perf_counter()
            logger.info(
                f"Reranked {len(reranked_results)} results,time used:{end_rerank_time - end_search_time:.2f}s"
            )
        else:
            logger.info("RerankService: is_rerank is False, skip rerank")
            reranked_results = all_results

        return reranked_results[0 : rag_request.top_k]