import asyncio
import re
import time
from typing import Dict, List

from dto.bio_document import BaseBioDocument, create_bio_document
from search_service.base_search import BaseSearchService
from bio_requests.rag_request import RagRequest
from utils.bio_logger import bio_logger as logger
from service.query_rewrite import QueryRewriteService
from service.pubmed_api import PubMedApi
from service.pubmed_async_api import PubMedAsyncApi
from config.global_storage import get_model_config


class PubMedSearchService(BaseSearchService):
    def __init__(self):
        self.query_rewrite_service = QueryRewriteService()
        self.model_config = get_model_config()
        self.pubmed_topk = self.model_config["recall"]["pubmed_topk"]
        self.es_topk = self.model_config["recall"]["es_topk"]
        self.data_source = "pubmed"
    async def get_query_list(self, rag_request: RagRequest) -> List[Dict]:
        """Build the list of search queries for a RagRequest.

        When rewriting is enabled, the original query is split into sub-queries;
        otherwise the original query is used as-is.
        """
        if rag_request.is_rewrite:
            query_list = await self.query_rewrite_service.query_split(rag_request.query)
            logger.info(f"length of query_list after query_split: {len(query_list)}")
            if len(query_list) == 0:
                logger.info("query_list is empty, use query_split_for_simple")
                query_list = await self.query_rewrite_service.query_split_for_simple(
                    rag_request.query
                )
                logger.info(
                    f"length of query_list after query_split_for_simple: {len(query_list)}"
                )
            self.pubmed_topk = rag_request.pubmed_topk
            self.es_topk = rag_request.pubmed_topk
        else:
            self.pubmed_topk = rag_request.top_k
            self.es_topk = rag_request.top_k
            query_list = [
                {
                    "query_item": rag_request.query,
                    "search_type": rag_request.search_type,
                }
            ]
        return query_list
    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Asynchronously search the PubMed database."""
        if not rag_request.query:
            return []
        start_time = time.time()

        query_list = await self.get_query_list(rag_request)

        # Run the per-query searches with async concurrency instead of a thread pool.
        articles_id_list = []
        es_articles = []
        try:
            # Create one task per query, each wrapping PubMedApi.search_database.
            async_tasks = []
            for query in query_list:
                task = self._search_pubmed_with_sync_api(
                    query["query_item"], self.pubmed_topk, query["search_type"]
                )
                async_tasks.append((query, task))

            # Execute all search tasks concurrently.
            results = await asyncio.gather(
                *[task for _, task in async_tasks], return_exceptions=True
            )

            # Collect article IDs, logging any per-query failures.
            for (query, _), result in zip(async_tasks, results):
                if isinstance(result, Exception):
                    logger.error(f"Error in search pubmed: {result}")
                else:
                    articles_id_list.extend(result)
        except Exception as e:
            logger.error(f"Error in concurrent PubMed search: {e}")

        # Fetch the detailed article records.
        pubmed_docs = await self.fetch_article_details(articles_id_list)

        # Merge results.
        all_results = []
        all_results.extend(pubmed_docs)
        all_results.extend(es_articles)

        logger.info(
            f"Finished searching PubMed, query: {rag_request.query}, "
            f"total articles: {len(articles_id_list)}, "
            f"total time: {time.time() - start_time:.2f}s"
        )
        return all_results
    async def _search_pubmed_with_sync_api(
        self, query: str, top_k: int, search_type: str
    ) -> List[str]:
        """
        Wrap the synchronous PubMedApi.search_database call in an executor so that
        multiple searches can run concurrently.

        Args:
            query: Search query.
            top_k: Number of results to return.
            search_type: Search type.

        Returns:
            List of article IDs.
        """
        try:
            # Run the synchronous search_database method in the default thread pool.
            loop = asyncio.get_running_loop()
            pubmed_api = PubMedApi()
            id_list = await loop.run_in_executor(
                None,  # use the default executor
                pubmed_api.search_database,
                query,
                top_k,
                search_type,
            )
            return id_list
        except Exception as e:
            logger.error(f"Error in PubMed search for query '{query}': {e}")
            raise
    async def fetch_article_details(
        self, articles_id_list: List[str]
    ) -> List[BaseBioDocument]:
        """Fetch article details from PubMed for the given article IDs."""
        if not articles_id_list:
            return []

        # Deduplicate the article IDs.
        articles_id_list = list(set(articles_id_list))

        # Split the IDs into batches of group_size for concurrent fetching.
        group_size = 80
        articles_id_groups = [
            articles_id_list[i : i + group_size]
            for i in range(0, len(articles_id_list), group_size)
        ]
        try:
            # Fetch the details of all batches concurrently.
            batch_tasks = []
            for ids in articles_id_groups:
                pubmed_async_api = PubMedAsyncApi()
                task = pubmed_async_api.fetch_details(id_list=ids)
                batch_tasks.append(task)

            task_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            fetch_results = []
            for result in task_results:
                if isinstance(result, Exception):
                    logger.error(f"Error in fetch_details: {result}")
                    continue
                fetch_results.extend(result)
        except Exception as e:
            logger.error(f"Error in concurrent fetch_details: {e}")
            return []

        # Convert the raw records into BioDocument objects.
        all_results = [
            create_bio_document(
                title=result["title"],
                abstract=result["abstract"],
                authors=self.process_authors(result["authors"]),
                doi=result["doi"],
                source=self.data_source,
                source_id=result["pmid"],
                pub_date=result["pub_date"],
                journal=result["journal"],
                text=result["abstract"],
                url=f'https://pubmed.ncbi.nlm.nih.gov/{result["pmid"]}',
            )
            for result in fetch_results
        ]
        return all_results
    def process_authors(self, author_list: List[Dict]) -> str:
        """Join the author list into a single comma-separated string."""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )