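"""PubMed search service: optionally rewrites the incoming query into
sub-queries, searches PubMed for each one concurrently, and converts the
matching articles into BaseBioDocument objects."""
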
import asyncio
import re
import time
from typing import Dict, List

from dto.bio_document import BaseBioDocument, create_bio_document
from search_service.base_search import BaseSearchService
from bio_requests.rag_request import RagRequest
from utils.bio_logger import bio_logger as logger
from service.query_rewrite import QueryRewriteService
from service.pubmed_api import PubMedApi
from service.pubmed_async_api import PubMedAsyncApi
from config.global_storage import get_model_config


class PubMedSearchService(BaseSearchService):
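    """Search backend for PubMed: builds a query list (optionally via
    rewrite), runs the searches concurrently, and hydrates article IDs
    into BaseBioDocument objects."""
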
    def __init__(self):
        self.query_rewrite_service = QueryRewriteService()
        self.model_config = get_model_config()

        self.pubmed_topk = self.model_config["recall"]["pubmed_topk"]
        self.es_topk = self.model_config["recall"]["es_topk"]
        self.data_source = "pubmed"

    async def get_query_list(self, rag_request: RagRequest) -> List[Dict]:
        """根据RagRequest获取查询列表"""
        if rag_request.is_rewrite:
            query_list = await self.query_rewrite_service.query_split(rag_request.query)
            logger.info(f"length of query_list after query_split: {len(query_list)}")
            if len(query_list) == 0:
                logger.info("query_list is empty, use query_split_for_simple")
                query_list = await self.query_rewrite_service.query_split_for_simple(
                    rag_request.query
                )
                logger.info(
                    f"length of query_list after query_split_for_simple: {len(query_list)}"
                )
            # Per-request override of the configured top-k values
            self.pubmed_topk = rag_request.pubmed_topk
            self.es_topk = rag_request.pubmed_topk
        else:
            self.pubmed_topk = rag_request.top_k
            self.es_topk = rag_request.top_k
            query_list = [
                {
                    "query_item": rag_request.query,
                    "search_type": rag_request.search_type,
                }
            ]
        return query_list

    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """异步搜索PubMed数据库"""
        if not rag_request.query:
            return []

        start_time = time.time()
        query_list = await self.get_query_list(rag_request)

        # Run the per-query searches with async concurrency instead of a thread pool
        articles_id_list = []
        es_articles = []  # reserved for ES results; never populated here

        try:
            # Build one async task per query, each wrapping PubMedApi.search_database
            async_tasks = []
            for query in query_list:
                task = self._search_pubmed_with_sync_api(
                    query["query_item"], self.pubmed_topk, query["search_type"]
                )
                async_tasks.append((query, task))

            # Run all search tasks concurrently
            results = await asyncio.gather(
                *[task for _, task in async_tasks], return_exceptions=True
            )

            # Collect the results, logging any per-query failures
            for i, (query, _) in enumerate(async_tasks):
                result = results[i]

                if isinstance(result, Exception):
                    logger.error(
                        f"Error in PubMed search for query '{query['query_item']}': {result}"
                    )
                else:
                    articles_id_list.extend(result)

        except Exception as e:
            logger.error(f"Error in concurrent PubMed search: {e}")

        # Fetch detailed info for the collected article IDs
        pubmed_docs = await self.fetch_article_details(articles_id_list)

        # Merge results from all sources
        all_results = []
        all_results.extend(pubmed_docs)
        all_results.extend(es_articles)

        logger.info(
            f"Finished searching PubMed, query: {rag_request.query}, "
            f"total articles: {len(articles_id_list)}, total time: {time.time() - start_time:.2f}s"
        )
        return all_results

    async def _search_pubmed_with_sync_api(
        self, query: str, top_k: int, search_type: str
    ) -> List[str]:
        """
        使用PubMedApi的search_database方法,但通过异步包装来提升并发效率

        Args:
            query: 搜索查询
            top_k: 返回结果数量
            search_type: 搜索类型

        Returns:
            文章ID列表
        """
        try:
            # Run the synchronous search_database method on a worker thread
            loop = asyncio.get_running_loop()
            pubmed_api = PubMedApi()

            # run_in_executor keeps the event loop free while the sync call runs
            # (on Python 3.9+, asyncio.to_thread(...) is an equivalent shortcut)
            id_list = await loop.run_in_executor(
                None,  # default thread pool
                pubmed_api.search_database,
                query,
                top_k,
                search_type,
            )
            return id_list
        except Exception as e:
            logger.error(f"Error in PubMed search for query '{query}': {e}")
            raise

    async def fetch_article_details(
        self, articles_id_list: List[str]
    ) -> List[BaseBioDocument]:
        """根据文章ID从pubmed获取文章详细信息"""
        if not articles_id_list:
            return []

        # Deduplicate the article IDs
        articles_id_list = list(set(articles_id_list))

        # Split the IDs into batches of group_size for the fetch API
        group_size = 80
        articles_id_groups = [
            articles_id_list[i : i + group_size]
            for i in range(0, len(articles_id_list), group_size)
        ]

        try:
            # Fetch details for all batches concurrently
            batch_tasks = []
            for ids in articles_id_groups:
                pubmed_async_api = PubMedAsyncApi()
                task = pubmed_async_api.fetch_details(id_list=ids)
                batch_tasks.append(task)

            task_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            fetch_results = []
            for result in task_results:
                if isinstance(result, Exception):
                    logger.error(f"Error in fetch_details: {result}")
                    continue
                fetch_results.extend(result)

        except Exception as e:
            logger.error(f"Error in concurrent fetch_details: {e}")
            return []

        # Convert the raw records into BioDocument objects
        all_results = [
            create_bio_document(
                title=result["title"],
                abstract=result["abstract"],
                authors=self.process_authors(result["authors"]),
                doi=result["doi"],
                source=self.data_source,
                source_id=result["pmid"],
                pub_date=result["pub_date"],
                journal=result["journal"],
                text=result["abstract"],
                url=f'https://pubmed.ncbi.nlm.nih.gov/{result["pmid"]}',
            )
            for result in fetch_results
        ]
        return all_results

    def process_authors(self, author_list: List[Dict]) -> str:
        """处理作者列表,将其转换为字符串"""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )
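

# --- Minimal usage sketch (illustrative only) ---
# Assumes RagRequest can be constructed with the fields this service reads
# (query, is_rewrite, top_k, search_type) and that BaseBioDocument exposes
# the title/url attributes set in create_bio_document; adjust to the actual
# signatures in this repo.
if __name__ == "__main__":

    async def _demo():
        service = PubMedSearchService()
        request = RagRequest(
            query="CRISPR off-target effects",
            is_rewrite=False,
            top_k=10,
            search_type="keyword",  # hypothetical value; depends on PubMedApi.search_database
        )
        docs = await service.search(request)
        for doc in docs[:3]:
            print(doc.title, doc.url)

    asyncio.run(_demo())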