import asyncio
import time
from typing import Dict, List
from dto.bio_document import BaseBioDocument, create_bio_document
from search_service.base_search import BaseSearchService
from bio_requests.rag_request import RagRequest
from utils.bio_logger import bio_logger as logger
from service.query_rewrite import QueryRewriteService
from service.pubmed_api import PubMedApi
from service.pubmed_async_api import PubMedAsyncApi
from config.global_storage import get_model_config
class PubMedSearchService(BaseSearchService):
def __init__(self):
self.query_rewrite_service = QueryRewriteService()
self.model_config = get_model_config()
self.pubmed_topk = self.model_config["recall"]["pubmed_topk"]
self.es_topk = self.model_config["recall"]["es_topk"]
self.data_source = "pubmed"
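    # The sketch below shows the model config shape this constructor reads. It is
    # inferred from the keys accessed above, not a confirmed schema; the values
    # are illustrative.
    #
    #   model_config = {
    #       "recall": {
    #           "pubmed_topk": 20,  # max PubMed IDs per rewritten query
    #           "es_topk": 20,      # max Elasticsearch hits per query
    #       },
    #   }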
async def get_query_list(self, rag_request: RagRequest) -> List[Dict]:
"""根据RagRequest获取查询列表"""
if rag_request.is_rewrite:
query_list = await self.query_rewrite_service.query_split(rag_request.query)
logger.info(f"length of query_list after query_split: {len(query_list)}")
if len(query_list) == 0:
logger.info("query_list is empty, use query_split_for_simple")
query_list = await self.query_rewrite_service.query_split_for_simple(
rag_request.query
)
logger.info(
f"length of query_list after query_split_for_simple: {len(query_list)}"
)
self.pubmed_topk = rag_request.pubmed_topk
self.es_topk = rag_request.pubmed_topk
else:
self.pubmed_topk = rag_request.top_k
self.es_topk = rag_request.top_k
query_list = [
{
"query_item": rag_request.query,
"search_type": rag_request.search_type,
}
]
return query_list
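    # Example of the returned shape (illustrative values; the dict keys match
    # those constructed above):
    #
    #   [
    #       {"query_item": "BRCA1 breast cancer risk", "search_type": "keyword"},
    #       {"query_item": "BRCA1 variant penetrance", "search_type": "keyword"},
    #   ]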
async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
"""异步搜索PubMed数据库"""
if not rag_request.query:
return []
start_time = time.time()
query_list = await self.get_query_list(rag_request)
        # Use async concurrency instead of a thread pool
articles_id_list = []
es_articles = []
try:
            # Build the async task list using PubMedApi's search_database method
async_tasks = []
for query in query_list:
task = self._search_pubmed_with_sync_api(
query["query_item"], self.pubmed_topk, query["search_type"]
)
async_tasks.append((query, task))
            # Run all search tasks concurrently
results = await asyncio.gather(
*[task for _, task in async_tasks], return_exceptions=True
)
            # Process the results
for i, (query, _) in enumerate(async_tasks):
result = results[i]
if isinstance(result, Exception):
logger.error(f"Error in search pubmed: {result}")
else:
articles_id_list.extend(result)
except Exception as e:
logger.error(f"Error in concurrent PubMed search: {e}")
        # Fetch detailed article information
pubmed_docs = await self.fetch_article_details(articles_id_list)
        # Merge results
all_results = []
all_results.extend(pubmed_docs)
all_results.extend(es_articles)
        logger.info(
            f"Finished searching PubMed, query: {rag_request.query}, "
            f"total articles: {len(articles_id_list)}, "
            f"total time: {time.time() - start_time:.2f}s"
        )
return all_results
async def _search_pubmed_with_sync_api(
self, query: str, top_k: int, search_type: str
) -> List[str]:
"""
使用PubMedApi的search_database方法,但通过异步包装来提升并发效率
Args:
query: 搜索查询
top_k: 返回结果数量
search_type: 搜索类型
Returns:
文章ID列表
"""
try:
            # Run the synchronous search_database method in a worker thread
            loop = asyncio.get_running_loop()
            pubmed_api = PubMedApi()
            # run_in_executor lets the sync call run without blocking the event loop
            id_list = await loop.run_in_executor(
                None,  # use the default thread pool executor
pubmed_api.search_database,
query,
top_k,
search_type,
)
return id_list
except Exception as e:
logger.error(f"Error in PubMed search for query '{query}': {e}")
            raise
async def fetch_article_details(
self, articles_id_list: List[str]
) -> List[BaseBioDocument]:
"""根据文章ID从pubmed获取文章详细信息"""
if not articles_id_list:
return []
        # Deduplicate the article IDs
        articles_id_list = list(set(articles_id_list))
        # Split the IDs into batches of group_size
        group_size = 80
articles_id_groups = [
articles_id_list[i : i + group_size]
for i in range(0, len(articles_id_list), group_size)
]
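        # e.g. 170 deduplicated IDs split into batches of 80, 80, and 10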
try:
            # Fetch details for all batches concurrently
batch_tasks = []
for ids in articles_id_groups:
pubmed_async_api = PubMedAsyncApi()
task = pubmed_async_api.fetch_details(id_list=ids)
batch_tasks.append(task)
task_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
fetch_results = []
for result in task_results:
if isinstance(result, Exception):
logger.error(f"Error in fetch_details: {result}")
continue
fetch_results.extend(result)
except Exception as e:
logger.error(f"Error in concurrent fetch_details: {e}")
return []
        # Convert the raw records into BioDocument objects
all_results = [
create_bio_document(
title=result["title"],
abstract=result["abstract"],
authors=self.process_authors(result["authors"]),
doi=result["doi"],
source=self.data_source,
source_id=result["pmid"],
pub_date=result["pub_date"],
journal=result["journal"],
text=result["abstract"],
url=f'https://pubmed.ncbi.nlm.nih.gov/{result["pmid"]}',
)
for result in fetch_results
]
return all_results
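    # Each record returned by PubMedAsyncApi.fetch_details is expected to carry
    # at least the fields consumed above: "title", "abstract", "authors", "doi",
    # "pmid", "pub_date", and "journal". A minimal illustrative record (values
    # are hypothetical):
    #
    #   {
    #       "pmid": "12345678",
    #       "title": "Example title",
    #       "abstract": "Example abstract ...",
    #       "authors": [{"forename": "Ada", "lastname": "Lovelace"}],
    #       "doi": "10.1000/example",
    #       "pub_date": "2024-01-01",
    #       "journal": "Example Journal",
    #   }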
def process_authors(self, author_list: List[Dict]) -> str:
"""处理作者列表,将其转换为字符串"""
return ", ".join(
[f"{author['forename']} {author['lastname']}" for author in author_list]
)