import asyncio
import re
import time
from typing import Dict, List
from dto.bio_document import BaseBioDocument, create_bio_document
from search_service.base_search import BaseSearchService
from bio_requests.rag_request import RagRequest
from utils.bio_logger import bio_logger as logger
from service.query_rewrite import QueryRewriteService
from service.pubmed_api import PubMedApi
from service.pubmed_async_api import PubMedAsyncApi
from config.global_storage import get_model_config


class PubMedSearchService(BaseSearchService):
    def __init__(self):
        self.query_rewrite_service = QueryRewriteService()
        self.model_config = get_model_config()
        self.pubmed_topk = self.model_config["recall"]["pubmed_topk"]
        self.es_topk = self.model_config["recall"]["es_topk"]
        self.data_source = "pubmed"

    async def get_query_list(self, rag_request: RagRequest) -> List[Dict]:
        """Build the list of search queries from a RagRequest."""
        if rag_request.is_rewrite:
            query_list = await self.query_rewrite_service.query_split(rag_request.query)
            logger.info(f"length of query_list after query_split: {len(query_list)}")
            if len(query_list) == 0:
                logger.info("query_list is empty, use query_split_for_simple")
                query_list = await self.query_rewrite_service.query_split_for_simple(
                    rag_request.query
                )
                logger.info(
                    f"length of query_list after query_split_for_simple: {len(query_list)}"
                )
            self.pubmed_topk = rag_request.pubmed_topk
            self.es_topk = rag_request.pubmed_topk
        else:
            self.pubmed_topk = rag_request.top_k
            self.es_topk = rag_request.top_k
            query_list = [
                {
                    "query_item": rag_request.query,
                    "search_type": rag_request.search_type,
                }
            ]
        return query_list

    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Asynchronously search the PubMed database."""
        if not rag_request.query:
            return []
        start_time = time.time()
        query_list = await self.get_query_list(rag_request)

        # Use async concurrency instead of a thread pool
        articles_id_list = []
        es_articles = []
        try:
            # Create the list of async tasks, using PubMedApi's search_database method
            async_tasks = []
            for query in query_list:
                task = self._search_pubmed_with_sync_api(
                    query["query_item"], self.pubmed_topk, query["search_type"]
                )
                async_tasks.append((query, task))

            # Run all search tasks concurrently
            results = await asyncio.gather(
                *[task for _, task in async_tasks], return_exceptions=True
            )

            # Collect the results
            for i, (query, _) in enumerate(async_tasks):
                result = results[i]
                if isinstance(result, Exception):
                    logger.error(f"Error in search pubmed: {result}")
                else:
                    articles_id_list.extend(result)
        except Exception as e:
            logger.error(f"Error in concurrent PubMed search: {e}")

        # Fetch the article details
        pubmed_docs = await self.fetch_article_details(articles_id_list)

        # Merge the results
        all_results = []
        all_results.extend(pubmed_docs)
        all_results.extend(es_articles)
        logger.info(
            f"""Finished searching PubMed, query:{rag_request.query},
            total articles: {len(articles_id_list)}, total time: {time.time() - start_time:.2f}s"""
        )
        return all_results

    async def _search_pubmed_with_sync_api(
        self, query: str, top_k: int, search_type: str
    ) -> List[str]:
        """
        Wrap PubMedApi's synchronous search_database method in an executor so that
        multiple queries can run concurrently.

        Args:
            query: the search query
            top_k: number of results to return
            search_type: the search type

        Returns:
            A list of article IDs.
        """
        try:
            # Run the synchronous search_database method in a thread pool
            loop = asyncio.get_running_loop()
            pubmed_api = PubMedApi()
            # run_in_executor executes the sync call without blocking the event loop
            id_list = await loop.run_in_executor(
                None,  # use the default thread pool
                pubmed_api.search_database,
                query,
                top_k,
                search_type,
            )
            return id_list
        except Exception as e:
            logger.error(f"Error in PubMed search for query '{query}': {e}")
            raise e

    async def fetch_article_details(
        self, articles_id_list: List[str]
    ) -> List[BaseBioDocument]:
        """Fetch article details from PubMed for the given article IDs."""
        if not articles_id_list:
            return []
        # Deduplicate the article IDs
        articles_id_list = list(set(articles_id_list))
        # Split the IDs into batches of group_size
        group_size = 80
        articles_id_groups = [
            articles_id_list[i : i + group_size]
            for i in range(0, len(articles_id_list), group_size)
        ]
        try:
            # Fetch the details of all batches concurrently
            batch_tasks = []
            for ids in articles_id_groups:
                pubmed_async_api = PubMedAsyncApi()
                task = pubmed_async_api.fetch_details(id_list=ids)
                batch_tasks.append(task)
            task_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            fetch_results = []
            for result in task_results:
                if isinstance(result, Exception):
                    logger.error(f"Error in fetch_details: {result}")
                    continue
                fetch_results.extend(result)
        except Exception as e:
            logger.error(f"Error in concurrent fetch_details: {e}")
            return []

        # Convert to BioDocument objects
        all_results = [
            create_bio_document(
                title=result["title"],
                abstract=result["abstract"],
                authors=self.process_authors(result["authors"]),
                doi=result["doi"],
                source=self.data_source,
                source_id=result["pmid"],
                pub_date=result["pub_date"],
                journal=result["journal"],
                text=result["abstract"],
                url=f'https://pubmed.ncbi.nlm.nih.gov/{result["pmid"]}',
            )
            for result in fetch_results
        ]
        return all_results

    def process_authors(self, author_list: List[Dict]) -> str:
        """Join the author list into a single comma-separated string."""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )
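
# Minimal usage sketch (illustrative, not part of the original module). It assumes
# RagRequest can be constructed with the attribute names the service reads above
# (query, is_rewrite, top_k, pubmed_topk, search_type) and that the config and
# service dependencies imported at the top are available; the field values below
# are placeholders.
if __name__ == "__main__":
    _request = RagRequest(
        query="CRISPR off-target effects",
        is_rewrite=False,
        top_k=10,
        pubmed_topk=10,
        search_type="keyword",  # assumed value; use whatever types PubMedApi supports
    )
    _docs = asyncio.run(PubMedSearchService().search(_request))
    print(f"Retrieved {len(_docs)} documents from PubMed")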