"""生物医学聊天服务模块,提供RAG问答和流式响应功能。""" | |
import datetime | |
import json | |
import time | |
from typing import Any, AsyncGenerator, List | |
from openai import AsyncOpenAI | |
from openai.types.chat import ChatCompletionMessageParam | |
from bio_requests.chat_request import ChatRequest | |
from bio_requests.rag_request import RagRequest | |
from config.global_storage import get_model_config | |
from search_service.pubmed_search import PubMedSearchService | |
from search_service.web_search import WebSearchService | |
from service.query_rewrite import QueryRewriteService | |
from service.rerank import RerankService | |
from utils.bio_logger import bio_logger as logger | |
from utils.i18n_util import get_error_message, get_label_message | |
from utils.token_util import num_tokens_from_messages, num_tokens_from_text | |
from utils.snowflake_id import snowflake_id_str | |

class ChatService:
    """Biomedical chat service providing RAG question answering and streaming responses."""

    def __init__(self):
        self.pubmed_search_service = PubMedSearchService()
        self.web_search_service = WebSearchService()
        self.query_rewrite_service = QueryRewriteService()
        self.rag_request = RagRequest()
        self.rerank_service = RerankService()
        self.model_config = get_model_config()

    def _initialize_rag_request(self, chat_request: ChatRequest) -> None:
        """Initialize the RAG request parameters."""
        self.rag_request.query = chat_request.query
    async def generate_stream(self, chat_request: ChatRequest):
        """
        Generate a stream of messages for the chat request.

        Args:
            chat_request: The chat request.
        """
        start_time = time.time()
        try:
            # Initialize the RAG request
            self._initialize_rag_request(chat_request)

            # PubMed search
            logger.info("QA-RAG: Start search pubmed...")
            pubmed_results = await self._search_pubmed(chat_request)
            pubmed_task_text = self._generate_pubmed_search_task_text(pubmed_results)
            yield pubmed_task_text
            logger.info(
                f"QA-RAG: Finished search pubmed, length: {len(pubmed_results)}"
            )

            # Web search
            logger.info("QA-RAG: Start search web...")
            web_urls, task_text = await self._search_web()
            logger.info("QA-RAG: Finished search web...")
            web_results = await self.web_search_service.enrich_url_results_with_contents(
                web_urls
            )
            yield task_text

            # Build the prompt messages
            messages, citation_list = self._create_messages(
                pubmed_results, web_results, chat_request
            )
            citation_text = self._generate_citation_text(citation_list)
            yield citation_text

            # Stream the chat completion
            async for content in self._stream_chat_completion(messages):
                yield content

            logger.info(
                f"Finished search and chat, query: [{chat_request.query}], "
                f"total time: {time.time() - start_time:.2f}s"
            )
        except Exception as e:
            logger.error(f"Error occurred: {e}")
            # Return the error message in the language from the request context
            error_msg = get_error_message("llm_service_error")
            yield f"data: {error_msg}\n\n"
            return
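
    # Usage sketch (illustrative, not part of this module): the stream above is
    # typically consumed by an async web handler. Assuming a FastAPI app and a
    # `/chat` route (both hypothetical here), it might look like:
    #
    #     from fastapi import FastAPI
    #     from fastapi.responses import StreamingResponse
    #
    #     app = FastAPI()
    #
    #     @app.post("/chat")
    #     async def chat(chat_request: ChatRequest):
    #         service = ChatService()
    #         return StreamingResponse(
    #             service.generate_stream(chat_request), media_type="text/plain"
    #         )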

    def _generate_citation_text(self, citation_list: List[Any]) -> str:
        """Generate the citation text block."""
        return f"""
```bdd-resource-lookup
{json.dumps(citation_list)}
```
"""

    async def _search_pubmed(self, chat_request: ChatRequest) -> List[Any]:
        """Run the PubMed search."""
        try:
            logger.info(f"query: {chat_request.query}, Using pubmed search...")
            self.rag_request.top_k = self.model_config["qa-topk"]["pubmed"]
            self.rag_request.pubmed_topk = self.model_config["qa-topk"]["pubmed"]

            start_search_time = time.time()
            pubmed_results = await self.pubmed_search_service.search(self.rag_request)
            end_search_time = time.time()
            logger.info(
                f"length of pubmed_results: {len(pubmed_results)}, "
                f"time used: {end_search_time - start_search_time:.2f}s"
            )

            pubmed_results = pubmed_results[: self.rag_request.top_k]
            logger.info(
                f"length of pubmed_results after rerank: {len(pubmed_results)}"
            )
            end_rerank_time = time.time()
            logger.info(
                f"Reranked {len(pubmed_results)} results, "
                f"time used: {end_rerank_time - end_search_time:.2f}s"
            )
            return pubmed_results
        except Exception as e:
            logger.error(f"error in search pubmed: {e}")
            return []

    async def _search_web(self) -> tuple[List[Any], str]:
        """Run the web search."""
        web_topk = self.model_config["qa-topk"]["web"]
        try:
            # Try to fetch rewritten queries
            query_list = await self.query_rewrite_service.query_split_for_web(
                self.rag_request.query
            )
            # Safely pick the rewritten query; None if query_list is empty
            serper_query = (
                query_list[0].get("query_item", "").strip() if query_list else None
            )
            # Fall back to the original query if the rewritten one is empty
            if not serper_query:
                serper_query = self.rag_request.query
            # Search with the final query
            url_results = await self.web_search_service.search_serper(
                query=serper_query, max_results=web_topk
            )
        except Exception as e:
            logger.error(f"error in query rewrite web or serper retrieval: {e}")
            # On error, search with the original query
            url_results = await self.web_search_service.search_serper(
                query=self.rag_request.query, max_results=web_topk
            )
        # Generate the task text
        task_text = self._generate_web_search_task_text(url_results)
        return url_results, task_text

    def _generate_pubmed_search_task_text(self, pubmed_results: List[Any]) -> str:
        """Generate the PubMed search task text."""
        docs = []
        for result in pubmed_results:
            journal = result.journal
            pages = (
                f"{journal.get('start_page', '')}-{journal.get('end_page', '')}."
                if journal.get("start_page") and journal.get("end_page")
                else ""
            )
            docs.append(
                {
                    "docId": result.bio_id,
                    "url": result.url,
                    "title": result.title,
                    "description": result.text,
                    "author": result.authors,
                    "JournalInfo": (
                        f"{journal.get('title', '')}.{journal.get('year', '')}."
                        f"{pages}doi:{result.doi}"
                    ),
                    "PMID": result.source_id,
                }
            )
        label = get_label_message("pubmed_search")
        return self._generate_task_text(label, "pubmed", docs)

    def _generate_web_search_task_text(self, url_results: List[Any]) -> str:
        """Generate the web search task text."""
        web_docs = [
            {
                "docId": snowflake_id_str(),
                "url": url_result.url,
                "title": url_result.title,
                "description": url_result.description,
            }
            for url_result in url_results
        ]
        logger.info(f"URL Results: {web_docs}")
        label = get_label_message("web_search")
        return self._generate_task_text(label, "webSearch", web_docs)

    def _generate_task_text(self, label, source, bio_docs: List[Any]) -> str:
        """Generate the task text."""
        task = {
            "type": "search",
            "label": label,
            "hoverable": True,
            "handler": "QASearch",
            "status": "running",
            "handlerParam": {"source": source, "bioDocs": bio_docs},
        }
        return f"""
```bdd-chat-agent-task
{json.dumps(task)}
```
"""

    def _build_document_texts(
        self, pubmed_results: List[Any], web_results: List[Any]
    ) -> tuple[str, str, List[Any]]:
        """Build the document texts."""
        citation_list = []
        temp_doc_list = []

        # PubMed results
        pubmed_offset = 0
        for idx, doc in enumerate(pubmed_results):
            _idx = idx + 1 + pubmed_offset
            temp_doc_list.append(
                "[document {idx} begin] title: {title}. content: {abstract} [document {idx} end]".format(
                    idx=_idx, title=doc.title, abstract=doc.abstract
                )
            )
            citation_list.append(
                {"source": "pubmed", "docId": doc.bio_id, "citation": _idx}
            )
        pubmed_texts = "\n".join(temp_doc_list)
        temp_doc_list = []

        # Web search results, numbered after the PubMed documents
        web_offset = pubmed_offset + len(pubmed_results)
        for idx, doc in enumerate(web_results):
            _idx = idx + 1 + web_offset
            temp_doc_list.append(
                "[document {idx} begin] title: {title}. content: {content} [document {idx} end]".format(
                    idx=_idx, title=doc.title, content=doc.text
                )
            )
            citation_list.append(
                {"source": "webSearch", "docId": doc.bio_id, "citation": _idx}
            )
        web_texts = "\n".join(temp_doc_list)
        return pubmed_texts, web_texts, citation_list
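
    # For illustration: two PubMed docs followed by one web doc yield
    #     [document 1 begin] title: T1. content: A1 [document 1 end]
    #     [document 2 begin] title: T2. content: A2 [document 2 end]
    # as pubmed_texts, and
    #     [document 3 begin] title: T3. content: C3 [document 3 end]
    # as web_texts, with citation entries such as
    #     {"source": "pubmed", "docId": "...", "citation": 1}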

    def _truncate_documents_to_token_limit(
        self,
        pubmed_texts: str,
        web_texts: str,
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], int]:
        """Truncate documents to fit the token limit."""
        pubmed_list = pubmed_texts.split("\n")
        web_list = web_texts.split("\n")
        today = datetime.date.today()
        openai_client_rag_prompt = self.model_config["chat"]["rag_prompt"]
        max_tokens = self.model_config["qa-prompt-max-token"]["max_tokens"]

        # Per-source token budgets, depending on which sources are enabled
        pubmed_token_limit = max_tokens
        web_token_limit = 60000
        personal_vector_token_limit = 80000
        if chat_request.is_pubmed and chat_request.is_web:
            personal_vector_token_limit = 40000
            pubmed_token_limit = 20000
            web_token_limit = 60000
        elif chat_request.is_pubmed and not chat_request.is_web:
            personal_vector_token_limit = 80000
            pubmed_token_limit = 40000
            web_token_limit = 0

        def calculate_num_tokens(
            pubmed_list: List[str], web_list: List[str]
        ) -> tuple[int, List[ChatCompletionMessageParam]]:
            # Merge the results into the prompt template
            docs_text = "\n".join(pubmed_list + web_list)
            pt = (
                openai_client_rag_prompt.replace("{search_results}", docs_text)
                .replace("{cur_date}", str(today))
                .replace("{question}", chat_request.query)
            )
            messages: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": pt}
            ]
            # Count tokens
            num_tokens = num_tokens_from_messages(messages)
            return num_tokens, messages

        while True:
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
            if num_tokens <= max_tokens:
                break
            # Over the token limit: truncate each source toward its budget
            logger.info(
                f"start truncate documents to token limit: max_tokens: {max_tokens}"
            )
            logger.info(
                f"pubmed_token_limit: {pubmed_token_limit}, web_token_limit: {web_token_limit}, "
                f"personal_vector_token_limit: {personal_vector_token_limit}"
            )
            lengths_before = (len(pubmed_list), len(web_list))
            while num_tokens_from_text("\n".join(pubmed_list)) > pubmed_token_limit:
                pubmed_list.pop()
            # After truncating PubMed docs, recount; stop if under max_tokens
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
            if num_tokens <= max_tokens:
                break
            while num_tokens_from_text("\n".join(web_list)) > web_token_limit:
                web_list.pop()
            # After truncating web docs, recount; stop if under max_tokens
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
            if num_tokens <= max_tokens:
                break
            # Nothing left to truncate; bail out to avoid an infinite loop
            if (len(pubmed_list), len(web_list)) == lengths_before:
                break

        logger.info(f"Final token count: {num_tokens}")
        return messages, num_tokens
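
    # Worked example (budgets taken from the branch values above): with both
    # PubMed and web enabled, the budgets are 20000 tokens for PubMed documents
    # and 60000 for web documents. If the PubMed block measures 30000 tokens,
    # the first inner loop pops documents from its tail until it fits within
    # 20000; the prompt is then recounted, and only if it still exceeds
    # max_tokens does the web block get truncated the same way.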

    def _create_messages(
        self,
        pubmed_results: List[Any],
        web_results: List[Any],
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], List[Any]]:
        """Build the chat messages."""
        if len(pubmed_results) == 0 and len(web_results) == 0:
            # No retrieval results: send the raw query without the RAG prompt
            logger.info(f"No results found for query: {chat_request.query}")
            pt = chat_request.query
            messages: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": pt}
            ]
            num_tokens = num_tokens_from_messages(messages)
            logger.info(f"Total tokens: {num_tokens}")
            return messages, []

        # Build the document texts
        pubmed_texts, web_texts, citation_list = self._build_document_texts(
            pubmed_results, web_results
        )
        # Truncate documents to fit the token limit
        messages, num_tokens = self._truncate_documents_to_token_limit(
            pubmed_texts, web_texts, chat_request
        )
        return messages, citation_list

    async def _stream_chat_completion(
        self, messages: List[ChatCompletionMessageParam]
    ) -> AsyncGenerator[bytes, None]:
        """Stream the chat completion, supporting the qa-llm main/backup configurations."""

        async def create_stream_with_config(
            qa_config: dict, config_name: str
        ) -> AsyncGenerator[bytes, None]:
            """Create a streaming response with the given configuration."""
            try:
                logger.info(f"Using qa-llm {config_name} configuration")
                client = AsyncOpenAI(
                    api_key=qa_config["api_key"],
                    base_url=qa_config["base_url"],
                )
                chat_start_time = time.time()
                # Create the chat completion stream
                stream = await client.chat.completions.create(
                    model=qa_config["model"],
                    messages=messages,
                    stream=True,
                    temperature=qa_config["temperature"],
                    max_tokens=qa_config["max_tokens"],
                )
                logger.info(
                    f"Created chat completion stream with {config_name} config, "
                    f"setup time: {time.time() - chat_start_time:.2f}s"
                )
                is_start_answer = False
                # Process the streaming response
                async for chunk in stream:
                    if chunk.choices and (content := chunk.choices[0].delta.content):
                        if not is_start_answer:
                            is_start_answer = True
                        yield content.encode("utf-8")
            except Exception as e:
                logger.info(f"qa-llm {config_name} configuration failed: {e}")
                raise

        async def with_fallback(main_func, backup_func):
            """Higher-order helper: try the main generator, falling back to the backup on failure."""
            try:
                async for content in main_func():
                    yield content
            except Exception as main_error:
                logger.info("Main config failed, falling back to backup configuration")
                try:
                    async for content in backup_func():
                        yield content
                except Exception as backup_error:
                    logger.error(
                        f"Both main and backup qa-llm configurations failed. "
                        f"Main error: {main_error}, Backup error: {backup_error}"
                    )
                    raise backup_error

        # Generator functions for the main and backup configurations
        async def main_stream():
            logger.info("Using main qa-llm configuration")
            async for content in create_stream_with_config(
                self.model_config["qa-llm"]["main"], "main"
            ):
                yield content

        async def backup_stream():
            logger.info("Using backup qa-llm configuration")
            async for content in create_stream_with_config(
                self.model_config["qa-llm"]["backup"], "backup"
            ):
                yield content

        # Apply the fallback logic
        async for content in with_fallback(main_stream, backup_stream):
            yield content
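

# Minimal local demo sketch (assumes the surrounding services and config are
# available; the ChatRequest constructor call is an assumption, as fields other
# than `query` are not shown in this module). Not part of the service itself.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        request = ChatRequest(query="What are the latest treatments for melanoma?")
        service = ChatService()
        # generate_stream yields str for task/citation text and bytes for
        # model output, so decode bytes before printing
        async for chunk in service.generate_stream(request):
            text = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
            print(text, end="", flush=True)

    asyncio.run(_demo())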