"""Biomedical chat service module providing RAG question answering and streaming responses."""

import datetime
import json
import time
from typing import Any, AsyncGenerator, List

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from bio_requests.chat_request import ChatRequest
from bio_requests.rag_request import RagRequest
from config.global_storage import get_model_config
from search_service.pubmed_search import PubMedSearchService
from search_service.web_search import WebSearchService
from service.query_rewrite import QueryRewriteService
from service.rerank import RerankService
from utils.bio_logger import bio_logger as logger
from utils.i18n_util import get_error_message, get_label_message
from utils.token_util import num_tokens_from_messages, num_tokens_from_text
from utils.snowflake_id import snowflake_id_str


class ChatService:
    """Biomedical chat service providing RAG question answering and streaming responses.

    NOTE(review): ``self.rag_request`` is a single ``RagRequest`` that is mutated on
    every request (``query``, ``top_k``, ``pubmed_topk``). If one ChatService
    instance serves concurrent requests, they would overwrite each other's
    parameters -- confirm how callers scope instances.
    """

    def __init__(self):
        self.pubmed_search_service = PubMedSearchService()
        self.web_search_service = WebSearchService()
        self.query_rewrite_service = QueryRewriteService()
        self.rag_request = RagRequest()
        self.rerank_service = RerankService()
        self.model_config = get_model_config()

    def _initialize_rag_request(self, chat_request: ChatRequest) -> None:
        """Copy the user query from the chat request into the shared RAG request."""
        self.rag_request.query = chat_request.query

    async def generate_stream(self, chat_request: ChatRequest):
        """Generate the stream of chunks for a chat request.

        Yields, in order: the PubMed search task block, the web search task block
        and the citation block (all ``str``), followed by the LLM answer chunks
        (``bytes``). If anything raises, a localized error message is yielded as an
        SSE ``data:`` line and the stream ends.

        Args:
            chat_request: The chat request.
        """
        start_time = time.time()
        try:
            self._initialize_rag_request(chat_request)

            # PubMed search (failures are swallowed inside and produce []).
            logger.info("QA-RAG: Start search pubmed...")
            pubmed_results = await self._search_pubmed(chat_request)
            pubmed_task_text = self._generate_pubmed_search_task_text(pubmed_results)
            yield pubmed_task_text
            logger.info(
                f"QA-RAG: Finished search pubmed, length: {len(pubmed_results)}"
            )

            # Web search: URLs first, then fetch their contents.
            logger.info("QA-RAG: Start search web...")
            web_urls, task_text = await self._search_web()
            logger.info("QA-RAG: Finished search web...")
            web_results = (
                await self.web_search_service.enrich_url_results_with_contents(web_urls)
            )
            yield task_text

            # Build the prompt and the citation metadata.
            messages, citation_list = self._create_messages(
                pubmed_results, web_results, chat_request
            )
            yield self._generate_citation_text(citation_list)

            # Stream the LLM answer.
            async for content in self._stream_chat_completion(messages):
                yield content

            logger.info(
                f"Finished search and chat, query: [{chat_request.query}], total time: {time.time() - start_time:.2f}s"
            )
        except Exception as e:
            logger.error(f"Error occurred: {e}")
            # Return the error message in the language taken from the request context.
            error_msg = get_error_message("llm_service_error")
            yield f"data: {error_msg}\n\n"

    def _generate_citation_text(self, citation_list: List[Any]) -> str:
        """Wrap the citation list as JSON in a ```bdd-resource-lookup fenced block."""
        return f"\n```bdd-resource-lookup\n{json.dumps(citation_list)}\n```\n"

    async def _search_pubmed(self, chat_request: ChatRequest) -> List[Any]:
        """Run the PubMed search and keep at most ``qa-topk.pubmed`` results.

        Returns:
            The (truncated) PubMed results, or ``[]`` if the search raised.
        """
        try:
            logger.info(f"query: {chat_request.query}, Using pubmed search...")
            self.rag_request.top_k = self.model_config["qa-topk"]["pubmed"]
            self.rag_request.pubmed_topk = self.model_config["qa-topk"]["pubmed"]

            start_search_time = time.time()
            pubmed_results = await self.pubmed_search_service.search(self.rag_request)
            end_search_time = time.time()
            logger.info(
                f"length of pubmed_results: {len(pubmed_results)},time used:{end_search_time - start_search_time:.2f}s"
            )

            # NOTE(review): no reranking happens here (self.rerank_service is never
            # called); the "rerank" log lines below only measure this slicing step.
            pubmed_results = pubmed_results[0 : self.rag_request.top_k]
            logger.info(f"length of pubmed_results after rerank: {len(pubmed_results)}")
            end_rerank_time = time.time()
            logger.info(
                f"Reranked {len(pubmed_results)} results,time used:{end_rerank_time - end_search_time:.2f}s"
            )
            return pubmed_results
        except Exception as e:
            logger.error(f"error in search pubmed: {e}")
            return []

    async def _search_web(self) -> tuple[List[Any], str]:
        """Run the web (Serper) search, preferring the first rewritten query.

        Falls back to the original query when the rewrite is empty or when the
        rewrite/search raises. An exception from the fallback search itself is
        not caught here and propagates to the caller.

        Returns:
            ``(url_results, task_text)``.
        """
        web_topk = self.model_config["qa-topk"]["web"]
        try:
            query_list = await self.query_rewrite_service.query_split_for_web(
                self.rag_request.query
            )
            # Take the first rewritten query if there is one.
            serper_query = (
                query_list[0].get("query_item", "").strip() if query_list else None
            )
            # An empty rewrite falls back to the original query.
            if not serper_query:
                serper_query = self.rag_request.query

            url_results = await self.web_search_service.search_serper(
                query=serper_query, max_results=web_topk
            )
        except Exception as e:
            logger.error(f"error in query rewrite web or serper retrieval: {e}")
            # Retry once with the original query.
            url_results = await self.web_search_service.search_serper(
                query=self.rag_request.query, max_results=web_topk
            )

        task_text = self._generate_web_search_task_text(url_results)
        return url_results, task_text

    @staticmethod
    def _format_journal_info(result: Any) -> str:
        """Build the ``JournalInfo`` string ``title.year.[start-end.]doi:<doi>``.

        Values are formatted with ``str()`` and missing/``None`` values become
        ``""``. Plain ``+`` concatenation would raise ``TypeError`` (failing the
        whole stream) when e.g. ``year`` is not a ``str`` or ``doi`` is ``None``.
        For all-string inputs the output is unchanged.
        """
        journal = result.journal or {}
        title = journal.get("title") or ""
        year = journal.get("year") or ""
        start_page = journal.get("start_page")
        end_page = journal.get("end_page")
        # The page range is included only when both ends are present.
        pages = f"{start_page}-{end_page}." if start_page and end_page else ""
        return f"{title}.{year}.{pages}doi:{result.doi or ''}"

    def _generate_pubmed_search_task_text(self, pubmed_results: List[Any]) -> str:
        """Build the PubMed search task block shown to the client."""
        docs = [
            {
                "docId": result.bio_id,
                "url": result.url,
                "title": result.title,
                "description": result.text,
                "author": result.authors,
                "JournalInfo": self._format_journal_info(result),
                "PMID": result.source_id,
            }
            for result in pubmed_results
        ]
        label = get_label_message("pubmed_search")
        return self._generate_task_text(label, "pubmed", docs)

    def _generate_web_search_task_text(self, url_results: List[Any]) -> str:
        """Build the web search task block shown to the client.

        NOTE(review): each web doc gets a fresh ``snowflake_id_str()`` here, whereas
        ``_build_document_texts`` cites web docs by ``doc.bio_id`` of the enriched
        results, so the two IDs will not match unless the client correlates by
        another field -- confirm intended behavior.
        """
        web_docs = [
            {
                "docId": snowflake_id_str(),
                "url": url_result.url,
                "title": url_result.title,
                "description": url_result.description,
            }
            for url_result in url_results
        ]
        logger.info(f"URL Results: {web_docs}")
        label = get_label_message("web_search")
        return self._generate_task_text(label, "webSearch", web_docs)

    def _generate_task_text(self, label, source, bio_docs: List[Any]):
        """Wrap a search task description as JSON in a ```bdd-chat-agent-task block."""
        task = {
            "type": "search",
            "label": label,
            "hoverable": True,
            "handler": "QASearch",
            "status": "running",
            "handlerParam": {"source": source, "bioDocs": bio_docs},
        }
        return f"\n```bdd-chat-agent-task\n{json.dumps(task)}\n```\n"

    def _build_document_texts(
        self, pubmed_results: List[Any], web_results: List[Any]
    ) -> tuple[str, str, List[Any]]:
        """Render search results as numbered ``[document N begin] ... [document N end]`` lines.

        PubMed documents are numbered from 1. Web documents continue the numbering
        after them.

        Returns:
            ``(pubmed_texts, web_texts, citation_list)``. The two texts are
            newline-joined document lines; ``citation_list`` maps each citation
            number to its source and ``docId``.
        """
        citation_list = []

        pubmed_lines = []
        for citation, doc in enumerate(pubmed_results, start=1):
            pubmed_lines.append(
                f"[document {citation} begin] title: {doc.title}. content: {doc.abstract} [document {citation} end]"
            )
            citation_list.append(
                {"source": "pubmed", "docId": doc.bio_id, "citation": citation}
            )

        web_lines = []
        for citation, doc in enumerate(web_results, start=len(pubmed_results) + 1):
            web_lines.append(
                f"[document {citation} begin] title: {doc.title}. content: {doc.text} [document {citation} end]"
            )
            citation_list.append(
                {"source": "webSearch", "docId": doc.bio_id, "citation": citation}
            )

        return "\n".join(pubmed_lines), "\n".join(web_lines), citation_list

    def _truncate_documents_to_token_limit(
        self,
        pubmed_texts: str,
        web_texts: str,
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], int]:
        """Build the RAG prompt, dropping trailing documents until it fits ``max_tokens``.

        The prompt is built from ``chat.rag_prompt`` by substituting
        ``{search_results}``, ``{cur_date}`` and ``{question}``. When it is too
        long, PubMed and then web lines are first trimmed to their per-source
        budgets. If it is still too long, trailing web lines and then PubMed lines
        are dropped one at a time. If the prompt exceeds ``max_tokens`` even with
        no documents, it is returned as-is and a warning is logged. Previously this
        case looped forever.

        Returns:
            ``(messages, num_tokens)`` for the final prompt.
        """
        # NOTE(review): documents whose text contains "\n" are split into several
        # entries here, so trimming may cut a document in half.
        pubmed_list = pubmed_texts.split("\n")
        web_list = web_texts.split("\n")
        today = datetime.date.today()
        openai_client_rag_prompt = self.model_config["chat"]["rag_prompt"]
        max_tokens = self.model_config["qa-prompt-max-token"]["max_tokens"]

        # Per-source token budgets. personal_vector_token_limit is only logged.
        pubmed_token_limit = max_tokens
        web_token_limit = 60000
        personal_vector_token_limit = 80000
        if chat_request.is_pubmed and chat_request.is_web:
            personal_vector_token_limit = 40000
            pubmed_token_limit = 20000
            web_token_limit = 60000
        elif chat_request.is_pubmed and not chat_request.is_web:
            personal_vector_token_limit = 80000
            pubmed_token_limit = 40000
            web_token_limit = 0
        # NOTE(review): two further branches (budgets 0/60000/60000 and 0/120000/0)
        # repeated exactly these two conditions and were unreachable, so they were
        # removed. They were probably meant for a different flag (e.g. a
        # personal-vector toggle) -- confirm the intended conditions.

        def calculate_num_tokens(
            pubmed_list: List[str], web_list: List[str]
        ) -> tuple[int, List[ChatCompletionMessageParam]]:
            # Merge the results and fill in the prompt template.
            docs_text = "\n".join(pubmed_list + web_list)
            pt = (
                openai_client_rag_prompt.replace("{search_results}", docs_text)
                .replace("{cur_date}", str(today))
                .replace("{question}", chat_request.query)
            )
            messages: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": pt}
            ]
            return num_tokens_from_messages(messages), messages

        num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
        if num_tokens > max_tokens:
            logger.info(
                f"start truncate documents to token limit: max_tokens: {max_tokens}"
            )
            logger.info(
                f"pubmed_token_limit: {pubmed_token_limit}, web_token_limit: {web_token_limit}, personal_vector_token_limit: {personal_vector_token_limit}"
            )

            # Pass 1: trim each source to its own budget, PubMed first.
            while (
                pubmed_list
                and num_tokens_from_text("\n".join(pubmed_list)) > pubmed_token_limit
            ):
                pubmed_list.pop()
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)

            if num_tokens > max_tokens:
                while (
                    web_list
                    and num_tokens_from_text("\n".join(web_list)) > web_token_limit
                ):
                    web_list.pop()
                num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)

            # Pass 2: the budgets alone were not enough (they can sum above
            # max_tokens). Drop trailing documents, web first, until the prompt
            # fits or nothing is left to drop.
            while num_tokens > max_tokens and (web_list or pubmed_list):
                if web_list:
                    web_list.pop()
                else:
                    pubmed_list.pop()
                num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)

            if num_tokens > max_tokens:
                logger.warning(
                    f"Prompt still exceeds max_tokens ({num_tokens} > {max_tokens}) with no documents left"
                )

        logger.info(f"Final token count: {num_tokens}")
        return messages, num_tokens

    def _create_messages(
        self,
        pubmed_results: List[Any],
        web_results: List[Any],
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], List[Any]]:
        """Create the chat messages and the citation list.

        With no search results at all, the bare query is sent and the citation
        list is empty. Otherwise the RAG prompt is built and truncated to the
        token limit.

        NOTE(review): ``citation_list`` is not pruned when truncation drops
        documents, so it may cite documents the model never saw.
        """
        if not pubmed_results and not web_results:
            logger.info(f"No results found for query: {chat_request.query}")
            pt = chat_request.query
            messages: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": pt}
            ]
            num_tokens = num_tokens_from_messages(messages)
            logger.info(f"Total tokens: {num_tokens}")
            return messages, []

        pubmed_texts, web_texts, citation_list = self._build_document_texts(
            pubmed_results, web_results
        )
        messages, _num_tokens = self._truncate_documents_to_token_limit(
            pubmed_texts, web_texts, chat_request
        )
        return messages, citation_list

    async def _stream_chat_completion(
        self, messages: List[ChatCompletionMessageParam]
    ) -> AsyncGenerator[bytes, None]:
        """Stream the chat completion, using the qa-llm ``main`` config with ``backup`` fallback.

        The backup config is used only if the main config fails before any content
        has been yielded. A mid-stream failure is re-raised, because restarting on
        the backup would append a second answer after the partial one.

        Yields:
            UTF-8 encoded content deltas.
        """

        async def create_stream_with_config(
            qa_config: dict, config_name: str
        ) -> AsyncGenerator[bytes, None]:
            """Stream a completion using one qa-llm configuration."""
            try:
                logger.info(f"Using qa-llm {config_name} configuration")
                # `async with` closes the client's HTTP connection pool when the
                # stream ends or fails; previously each call leaked a client.
                async with AsyncOpenAI(
                    api_key=qa_config["api_key"],
                    base_url=qa_config["base_url"],
                ) as client:
                    chat_start_time = time.time()
                    stream = await client.chat.completions.create(
                        model=qa_config["model"],
                        messages=messages,
                        stream=True,
                        temperature=qa_config["temperature"],
                        max_tokens=qa_config["max_tokens"],
                    )
                    # Note: this measures time until the stream is opened, not
                    # until the answer is complete.
                    logger.info(
                        f"Finished chat completion with {config_name} config, total time: {time.time() - chat_start_time:.2f}s"
                    )
                    async for chunk in stream:
                        if chunk.choices and (content := chunk.choices[0].delta.content):
                            yield content.encode("utf-8")
            except Exception as e:
                logger.info(f"qa-llm {config_name} configuration failed: {e}")
                raise

        async def with_fallback(main_func, backup_func):
            """Yield from ``main_func``; switch to ``backup_func`` only if main failed before yielding."""
            has_yielded = False
            try:
                async for content in main_func():
                    has_yielded = True
                    yield content
            except Exception as main_error:
                if has_yielded:
                    # Part of the main answer was already sent to the client.
                    logger.error(
                        f"Main qa-llm configuration failed mid-stream, not falling back: {main_error}"
                    )
                    raise
                logger.info("Main config failed, falling back to backup configuration")
                try:
                    async for content in backup_func():
                        yield content
                except Exception as backup_error:
                    logger.error(
                        f"Both main and backup qa-llm configurations failed. "
                        f"Main error: {main_error}, Backup error: {backup_error}"
                    )
                    raise

        async def main_stream():
            logger.info("Using main qa-llm configuration")
            async for content in create_stream_with_config(
                self.model_config["qa-llm"]["main"], "main"
            ):
                yield content

        async def backup_stream():
            logger.info("Using backup qa-llm configuration")
            async for content in create_stream_with_config(
                self.model_config["qa-llm"]["backup"], "backup"
            ):
                yield content

        async for content in with_fallback(main_stream, backup_stream):
            yield content