File size: 17,633 Bytes
79899c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
"""生物医学聊天服务模块,提供RAG问答和流式响应功能。"""

import datetime
import json
import time
from typing import Any, AsyncGenerator, List

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from bio_requests.chat_request import ChatRequest
from bio_requests.rag_request import RagRequest
from config.global_storage import get_model_config
from search_service.pubmed_search import PubMedSearchService

from search_service.web_search import WebSearchService
from service.query_rewrite import QueryRewriteService
from service.rerank import RerankService
from utils.bio_logger import bio_logger as logger
from utils.i18n_util import get_error_message, get_label_message
from utils.token_util import num_tokens_from_messages, num_tokens_from_text
from utils.snowflake_id import snowflake_id_str


class ChatService:
    """生物医学聊天服务,提供RAG问答和流式响应功能。"""

    def __init__(self) -> None:
        """Create the collaborating services and load the model configuration."""
        # Retrieval back ends and the query rewriter used ahead of web search.
        self.pubmed_search_service = PubMedSearchService()
        self.web_search_service = WebSearchService()
        self.query_rewrite_service = QueryRewriteService()

        # A single RagRequest is kept on the service and mutated by the
        # request-handling methods (query, top_k, pubmed_topk).
        # NOTE(review): if one ChatService instance serves concurrent
        # requests, this state is shared between them — confirm how the
        # service is instantiated.
        self.rag_request = RagRequest()
        # NOTE(review): no other method of this class references
        # rerank_service — confirm whether it is still needed.
        self.rerank_service = RerankService()
        # Configuration mapping; the methods of this class read the keys
        # "qa-topk", "chat", "qa-prompt-max-token" and "qa-llm".
        self.model_config = get_model_config()

    def _initialize_rag_request(self, chat_request: ChatRequest) -> None:
        """初始化RAG请求参数"""
        self.rag_request.query = chat_request.query

    async def generate_stream(self, chat_request: ChatRequest):
        """Run the retrieval pipeline and stream the answer for one chat request.

        Yields, in order:
          1. the PubMed search task block (str),
          2. the web search task block (str),
          3. the citation lookup block (str),
          4. the answer chunks produced by ``_stream_chat_completion``.

        If any step raises, the error is logged and a single
        ``data: <message>\\n\\n`` string is yielded instead of the remainder.

        Args:
            chat_request: The incoming chat request; its ``query`` drives
                both searches and the final prompt.
        """
        started_at = time.time()

        try:
            # Seed the RAG request kept on the service with the user's query.
            self._initialize_rag_request(chat_request)

            # --- PubMed retrieval ---
            logger.info("QA-RAG: Start search pubmed...")
            pubmed_docs = await self._search_pubmed(chat_request)
            yield self._generate_pubmed_search_task_text(pubmed_docs)
            logger.info(
                f"QA-RAG: Finished search pubmed, length: {len(pubmed_docs)}"
            )

            # --- Web retrieval ---
            logger.info("QA-RAG: Start search web...")
            url_hits, web_task_text = await self._search_web()
            logger.info("QA-RAG: Finished search web...")

            # The task block is only emitted once the hits have been enriched
            # with page contents.
            web_docs = (
                await self.web_search_service.enrich_url_results_with_contents(url_hits)
            )
            yield web_task_text

            # --- Prompt assembly and citations ---
            prompt_messages, citations = self._create_messages(
                pubmed_docs, web_docs, chat_request
            )
            yield self._generate_citation_text(citations)

            # --- Streamed LLM answer ---
            async for piece in self._stream_chat_completion(prompt_messages):
                yield piece

            elapsed = time.time() - started_at
            logger.info(
                f"Finished search and chat, query: [{chat_request.query}], total time: {elapsed:.2f}s"
            )
        except Exception as exc:
            logger.error(f"Error occurred: {exc}")
            # The error text is looked up by key via get_error_message.
            failure_text = get_error_message("llm_service_error")
            yield f"data: {failure_text}\n\n"

    def _generate_citation_text(self, citation_list: List[Any]) -> str:
        """生成引用文本"""

        return f"""
```bdd-resource-lookup
{json.dumps(citation_list)}
```
    """

    async def _search_pubmed(self, chat_request: ChatRequest) -> List[Any]:
        """Search PubMed with the service's RAG request and keep the first ``top_k`` hits.

        Any exception is logged and turned into an empty list, so a PubMed
        failure does not abort the overall request.

        Args:
            chat_request: Only used for logging the query.

        Returns:
            The first ``top_k`` search results, or ``[]`` on failure.
        """
        try:
            logger.info(f"query: {chat_request.query}, Using pubmed search...")

            # Both limits on the RAG request come from the same config entry.
            pubmed_topk = self.model_config["qa-topk"]["pubmed"]
            self.rag_request.top_k = pubmed_topk
            self.rag_request.pubmed_topk = pubmed_topk

            search_started = time.time()
            hits = await self.pubmed_search_service.search(self.rag_request)
            search_finished = time.time()
            logger.info(
                f"length of pubmed_results: {len(hits)},time used:{search_finished - search_started:.2f}s"
            )

            # NOTE(review): the log lines below say "rerank", but this method
            # only slices the results; no reranker is invoked here.
            top_hits = hits[0 : self.rag_request.top_k]
            logger.info(f"length of pubmed_results after rerank: {len(top_hits)}")

            slice_finished = time.time()
            logger.info(
                f"Reranked {len(top_hits)} results,time used:{slice_finished - search_finished:.2f}s"
            )
        except Exception as exc:
            logger.error(f"error in search pubmed: {exc}")
            return []

        return top_hits

    async def _search_web(self) -> tuple[List[Any], str]:
        """Run the web search, preferring a rewritten query over the raw one.

        The first item returned by the query rewriter is used when it has a
        non-blank ``query_item``; otherwise the original query is searched.
        If rewriting or searching raises, the error is logged and the search
        is retried once with the original query (a failure of that retry
        propagates to the caller).

        Returns:
            A ``(url_results, task_text)`` pair, where ``task_text`` is the
            web search task block describing the results.
        """
        max_hits = self.model_config["qa-topk"]["web"]
        try:
            sub_queries = await self.query_rewrite_service.query_split_for_web(
                self.rag_request.query
            )

            # Only the first rewritten query is considered.
            rewritten = None
            if sub_queries:
                rewritten = sub_queries[0].get("query_item", "").strip()

            # A missing or blank rewrite falls back to the user's query.
            search_query = rewritten or self.rag_request.query
            hits = await self.web_search_service.search_serper(
                query=search_query, max_results=max_hits
            )
        except Exception as exc:
            logger.error(f"error in query rewrite web or serper retrieval: {exc}")
            # Retry with the untouched query.
            hits = await self.web_search_service.search_serper(
                query=self.rag_request.query, max_results=max_hits
            )

        return hits, self._generate_web_search_task_text(hits)


    def _generate_pubmed_search_task_text(self, pubmed_results: List[Any]) -> str:
        """Describe the PubMed hits as a search task block.

        Each result becomes a dict with docId, url, title, description,
        author, JournalInfo and PMID. JournalInfo has the form
        ``<title>.<year>.[<start_page>-<end_page>.]doi:<doi>``; the page range
        is included only when both pages are present.

        Args:
            pubmed_results: PubMed search results.

        Returns:
            The task block produced by ``_generate_task_text``.
        """

        def journal_info(entry: Any) -> str:
            # Title and year are always emitted (empty when missing).
            info = (
                entry.journal.get("title", "")
                + "."
                + entry.journal.get("year", "")
                + "."
            )
            # The page range is optional and needs both ends.
            if entry.journal.get("start_page") and entry.journal.get("end_page"):
                info += (
                    entry.journal.get("start_page", "")
                    + "-"
                    + entry.journal.get("end_page", "")
                    + "."
                )
            return info + "doi:" + entry.doi

        docs = []
        for entry in pubmed_results:
            docs.append(
                {
                    "docId": entry.bio_id,
                    "url": entry.url,
                    "title": entry.title,
                    "description": entry.text,
                    "author": entry.authors,
                    "JournalInfo": journal_info(entry),
                    "PMID": entry.source_id,
                }
            )

        return self._generate_task_text(
            get_label_message("pubmed_search"), "pubmed", docs
        )

    def _generate_web_search_task_text(self, url_results: List[Any]) -> str:
        """Describe the web hits as a search task block.

        Every hit gets a freshly generated snowflake id as its ``docId``,
        alongside its url, title and description.

        Args:
            url_results: Web search results.

        Returns:
            The task block produced by ``_generate_task_text``.
        """
        web_docs = []
        for hit in url_results:
            web_docs.append(
                {
                    "docId": snowflake_id_str(),
                    "url": hit.url,
                    "title": hit.title,
                    "description": hit.description,
                }
            )

        logger.info(f"URL Results: {web_docs}")

        return self._generate_task_text(
            get_label_message("web_search"), "webSearch", web_docs
        )

    def _generate_task_text(self, label, source, bio_docs: List[Any]):
        """生成任务文本"""
        task = {
            "type": "search",
            "label": label,
            "hoverable": True,
            "handler": "QASearch",
            "status": "running",
            "handlerParam": {"source": source, "bioDocs": bio_docs},
        }
        return f"""
```bdd-chat-agent-task
{json.dumps(task)}
``` 
"""

    def _build_document_texts(
        self, pubmed_results: List[Any], web_results: List[Any]
    ) -> tuple[str, str, List[Any]]:
        """构建文档文本"""
        # 个人向量搜索结果
        citation_list = []
        temp_doc_list = []

        # pubmed结果
        pubmed_offset = 0
        for idx, doc in enumerate(pubmed_results):
            _idx = idx + 1 + pubmed_offset
            temp_doc_list.append(
                "[document {idx} begin] title: {title}. content: {abstract} [document {idx} end]".format(
                    idx=_idx, title=doc.title, abstract=doc.abstract
                )
            )
            citation_list.append(
                {"source": "pubmed", "docId": doc.bio_id, "citation": _idx}
            )
        pubmed_texts = "\n".join(temp_doc_list)

        temp_doc_list = []
        # 联网搜索结果
        web_offset = pubmed_offset + len(pubmed_results)
        for idx, doc in enumerate(web_results):
            _idx = idx + 1 + web_offset
            temp_doc_list.append(
                "[document {idx} begin] title: {title}. content: {content} [document {idx} end]".format(
                    idx=_idx, title=doc.title, content=doc.text
                )
            )
            citation_list.append(
                {"source": "webSearch", "docId": doc.bio_id, "citation": _idx}
            )
        web_texts = "\n".join(temp_doc_list)

        return pubmed_texts, web_texts, citation_list

    def _truncate_documents_to_token_limit(
        self,
        pubmed_texts: str,
        web_texts: str,
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], int]:
        """Build the RAG prompt, dropping trailing documents until it fits the token budget.

        The prompt is rendered from the configured ``rag_prompt`` template by
        substituting ``{search_results}``, ``{cur_date}`` and ``{question}``.
        While the rendered prompt exceeds ``max_tokens``:

          1. PubMed lines are dropped from the end until they fit the PubMed budget.
          2. Web lines are dropped from the end until they fit the web budget.
          3. If neither step removed anything (both sources are within their
             own budgets but the total is still too large), one more line is
             dropped — web first, then PubMed — so the loop always makes
             progress and terminates.

        If no lines remain and the prompt is still over budget, the
        over-budget prompt is returned and the condition is logged.

        Args:
            pubmed_texts: Newline-joined PubMed snippets.
            web_texts: Newline-joined web snippets.
            chat_request: Supplies ``query`` and the ``is_pubmed`` / ``is_web`` flags.

        Returns:
            ``(messages, num_tokens)`` for the final prompt.
        """
        # An empty text must yield no lines; "".split("\n") would give [""],
        # i.e. a phantom blank document and a stray newline in the prompt.
        # NOTE(review): documents are split on newlines here, so a snippet
        # containing a newline spans several entries — confirm upstream text
        # is newline-free.
        pubmed_list = pubmed_texts.split("\n") if pubmed_texts else []
        web_list = web_texts.split("\n") if web_texts else []

        today = datetime.date.today()
        openai_client_rag_prompt = self.model_config["chat"]["rag_prompt"]
        max_tokens = self.model_config["qa-prompt-max-token"]["max_tokens"]

        # Per-source budgets. personal_vector_token_limit is only reported in
        # the log line below; no personal-vector documents are handled here.
        pubmed_token_limit = max_tokens
        web_token_limit = 60000
        personal_vector_token_limit = 80000
        if chat_request.is_pubmed and chat_request.is_web:
            personal_vector_token_limit = 40000
            pubmed_token_limit = 20000
            web_token_limit = 60000
        elif chat_request.is_pubmed and not chat_request.is_web:
            personal_vector_token_limit = 80000
            pubmed_token_limit = 40000
            web_token_limit = 0
        # The two further branches the code used to have repeated the
        # conditions above and could never run, so they were removed.

        def build_prompt() -> tuple[int, List[ChatCompletionMessageParam]]:
            # Render the prompt from the documents that are still kept.
            docs_text = "\n".join(pubmed_list + web_list)
            pt = (
                openai_client_rag_prompt.replace("{search_results}", docs_text)
                .replace("{cur_date}", str(today))
                .replace("{question}", chat_request.query)
            )
            prompt_messages: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": pt}
            ]
            return num_tokens_from_messages(prompt_messages), prompt_messages

        def trim_to_budget(docs: List[str], budget: int) -> None:
            # Drop trailing lines until the source fits its own budget; the
            # emptiness guard prevents popping from an exhausted list.
            while docs and num_tokens_from_text("\n".join(docs)) > budget:
                docs.pop()

        num_tokens, messages = build_prompt()
        while num_tokens > max_tokens:
            logger.info(
                f"start truncate documents to token limit: max_tokens: {max_tokens}"
            )
            logger.info(
                f"pubmed_token_limit: {pubmed_token_limit}, web_token_limit: {web_token_limit}, personal_vector_token_limit: {personal_vector_token_limit}"
            )
            docs_before = len(pubmed_list) + len(web_list)

            trim_to_budget(pubmed_list, pubmed_token_limit)
            num_tokens, messages = build_prompt()
            if num_tokens <= max_tokens:
                break

            trim_to_budget(web_list, web_token_limit)
            num_tokens, messages = build_prompt()
            if num_tokens <= max_tokens:
                break

            if len(pubmed_list) + len(web_list) < docs_before:
                # Something was removed; re-evaluate on the next pass.
                continue

            # Both sources already fit their budgets yet the total is too
            # large: force progress instead of spinning forever.
            if web_list:
                web_list.pop()
            elif pubmed_list:
                pubmed_list.pop()
            else:
                logger.error(
                    f"Prompt exceeds max_tokens even without documents: {num_tokens} > {max_tokens}"
                )
                break
            num_tokens, messages = build_prompt()

        logger.info(f"Final token count: {num_tokens}")
        return messages, num_tokens

    def _create_messages(
        self,
        pubmed_results: List[Any],
        web_results: List[Any],
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], List[Any]]:
        """Build the chat messages and the citation list for the LLM call.

        With at least one retrieved document, the documents are rendered,
        fitted to the token budget and embedded in the RAG prompt. With no
        documents at all, the bare query is sent and no citations are returned.

        Args:
            pubmed_results: PubMed documents.
            web_results: Enriched web documents.
            chat_request: Supplies the user's query.

        Returns:
            ``(messages, citation_list)``.
        """
        has_documents = len(pubmed_results) != 0 or len(web_results) != 0

        if has_documents:
            pubmed_texts, web_texts, citation_list = self._build_document_texts(
                pubmed_results, web_results
            )
            # NOTE(review): truncation may drop documents, but the citation
            # list is returned as built before truncation.
            messages, _ = self._truncate_documents_to_token_limit(
                pubmed_texts, web_texts, chat_request
            )
            return messages, citation_list

        # Nothing retrieved: fall back to asking the model the raw question.
        logger.info(f"No results found for query: {chat_request.query}")
        fallback_messages: List[ChatCompletionMessageParam] = [
            {"role": "user", "content": chat_request.query}
        ]
        logger.info(f"Total tokens: {num_tokens_from_messages(fallback_messages)}")
        return fallback_messages, []

    async def _stream_chat_completion(
        self, messages: List[ChatCompletionMessageParam]
    ) -> AsyncGenerator[bytes, None]:
        """Stream the LLM answer as UTF-8 bytes, trying the ``main`` then the ``backup`` qa-llm config.

        Args:
            messages: Chat messages sent to the model.

        Yields:
            Each non-empty content delta of the streamed completion, UTF-8 encoded.

        Raises:
            Exception: The backup configuration's error when both
                configurations fail.
        """

        async def create_stream_with_config(
            qa_config: dict, config_name: str
        ) -> AsyncGenerator[bytes, None]:
            """Stream one completion using the given qa-llm configuration."""
            try:
                logger.info(f"Using qa-llm {config_name} configuration")
                chat_start_time = time.time()

                # The context manager closes the client's HTTP resources as
                # soon as the stream ends or fails, instead of waiting for
                # garbage collection.
                async with AsyncOpenAI(
                    api_key=qa_config["api_key"],
                    base_url=qa_config["base_url"],
                ) as client:
                    stream = await client.chat.completions.create(
                        model=qa_config["model"],
                        messages=messages,
                        stream=True,
                        temperature=qa_config["temperature"],
                        max_tokens=qa_config["max_tokens"],
                    )

                    async for chunk in stream:
                        # Skip chunks without choices or with empty content.
                        if chunk.choices and (
                            content := chunk.choices[0].delta.content
                        ):
                            yield content.encode("utf-8")

                # Logged once the stream is exhausted, so the reported time
                # covers the whole completion.
                logger.info(
                    f"Finished chat completion with {config_name} config, total time: {time.time() - chat_start_time:.2f}s"
                )
            except Exception as e:
                logger.info(f"qa-llm {config_name} configuration failed: {e}")
                raise

        try:
            logger.info("Using main qa-llm configuration")
            async for content in create_stream_with_config(
                self.model_config["qa-llm"]["main"], "main"
            ):
                yield content
        except Exception as main_error:
            # NOTE(review): if the main stream fails after some chunks were
            # already yielded, the backup answer is appended to that partial
            # output — confirm this is acceptable for consumers.
            logger.info("Main config failed, falling back to backup configuration")
            try:
                logger.info("Using backup qa-llm configuration")
                # The backup config is looked up only when it is needed.
                async for content in create_stream_with_config(
                    self.model_config["qa-llm"]["backup"], "backup"
                ):
                    yield content
            except Exception as backup_error:
                logger.error(
                    f"Both main and backup qa-llm configurations failed. "
                    f"Main error: {main_error}, Backup error: {backup_error}"
                )
                raise