import asyncio
import time
from typing import List

from bio_requests.rag_request import RagRequest
from dto.bio_document import BaseBioDocument
from search_service.base_search import BaseSearchService
from service.rerank import RerankService
from utils.bio_logger import bio_logger as logger


class RagService:
    def __init__(self):
        self.rerank_service = RerankService()
        # Ensure every BaseSearchService subclass is loaded and instantiated
        self.search_services = [
            subclass() for subclass in BaseSearchService.get_subclasses()
        ]
        logger.info(
            f"Loaded search services: {[service.__class__.__name__ for service in self.search_services]}"
        )

    async def multi_query(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Fan the request out to all registered search services concurrently,
        merge the results, optionally rerank them, and return the top_k documents."""
        start_time = time.time()
        # Each service searches concurrently; return_exceptions=True lets a
        # failing backend be logged and skipped instead of aborting the batch.
        batch_search = [
            service.filter_search(rag_request=rag_request)
            for service in self.search_services
        ]
        task_result = await asyncio.gather(*batch_search, return_exceptions=True)
        all_results = []
        for result in task_result:
            if isinstance(result, Exception):
                logger.error(f"Error in search service: {result}")
                continue
            all_results.extend(result)
        end_search_time = time.time()
        logger.info(
            f"Found {len(all_results)} results in total, time used: {end_search_time - start_time:.2f}s"
        )
        if rag_request.is_rerank:
            logger.info("RerankService: is_rerank is True")
            reranked_results = await self.rerank_service.rerank(
                rag_request=rag_request, documents=all_results
            )
            end_rerank_time = time.time()
            logger.info(
                f"Reranked {len(reranked_results)} results, time used: {end_rerank_time - end_search_time:.2f}s"
            )
        else:
            logger.info("RerankService: is_rerank is False, skip rerank")
            reranked_results = all_results

        return reranked_results[: rag_request.top_k]
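

# Usage sketch (illustrative only, not part of the original module): one way a
# caller might drive RagService.multi_query end to end. The RagRequest keyword
# arguments below (query, top_k, is_rerank) are assumptions inferred from the
# fields accessed in this file; adjust them to the real DTO definition.
if __name__ == "__main__":

    async def _demo() -> None:
        rag_service = RagService()
        request = RagRequest(  # field names assumed from usage above
            query="BRCA1 variant pathogenicity",
            top_k=10,
            is_rerank=True,
        )
        documents = await rag_service.multi_query(request)
        logger.info(f"Top documents returned: {len(documents)}")

    asyncio.run(_demo())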