Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import re | |
import time | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional, Tuple | |
import gradio as gr | |
import modelscope_studio.components.antd as antd | |
import modelscope_studio.components.antdx as antdx | |
import modelscope_studio.components.base as ms | |
import modelscope_studio.components.pro as pro | |
from mem0 import Memory | |
from modelscope_studio.components.pro.chatbot import (ChatbotBotConfig, | |
ChatbotPromptsConfig, | |
ChatbotUserConfig, | |
ChatbotWelcomeConfig) | |
from openai import OpenAI | |
# Logging setup: records at INFO and above go to a log file and to the console.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler("xinyuan_chat.log"),
        logging.StreamHandler(),
    ],
)
# Module-level logger used by every class and callback in this file.
logger = logging.getLogger(__name__)
@dataclass
class AppConfig:
    """Application configuration.

    Declared as a dataclass so that individual settings can be overridden
    per instance (``AppConfig(temperature=0.2)``) while ``AppConfig()`` keeps
    returning the defaults below.
    """

    # Text file holding one registered user name per line.
    users_file: str = "users.txt"
    # Path handed to the FAISS vector store that backs the memory manager.
    memory_path: str = "./faiss_memories"
    # Model identifier sent with every chat completion request.
    model_name: str = "xinyuan-32b-v0609"
    # Sampling parameters forwarded to the chat completion request.
    max_tokens: int = 32768
    temperature: float = 0.6
    top_p: float = 0.95
    # Height passed to the chatbot component.
    chatbot_height: int = 1000
    # Minimum number of characters accepted in a user name.
    min_username_length: int = 3
    # Maximum number of memories requested per search.
    max_memory_results: int = 5
class MemoryManager:
    """Wrapper around a mem0 ``Memory`` store configured with a FAISS vector store."""

    def __init__(self, config_path: str):
        """Create the underlying memory store.

        Args:
            config_path: Value used as the ``path`` of the FAISS vector store.

        Raises:
            Exception: Any error raised while building the store is logged
                and re-raised unchanged.
        """
        vector_store_settings = {
            "collection_name": "xinyuan_memories",
            "path": config_path,
            "distance_strategy": "euclidean",
        }
        self.config = {
            "vector_store": {
                "provider": "faiss",
                "config": vector_store_settings,
            }
        }
        try:
            self.memory = Memory.from_config(self.config)
            logger.info(f"Memory manager initialized with path: {config_path}")
        except Exception as err:
            logger.error(f"Failed to initialize memory manager: {err}")
            raise

    def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Return the memories matching ``query`` for ``user_id``.

        Args:
            query: Search text; an empty value yields an empty list.
            user_id: Owner of the memories; an empty value yields an empty list.
            limit: Maximum number of results requested from the store.

        Returns:
            The entries under the ``'results'`` key of the store's answer,
            ordered by their ``'score'`` field from highest to lowest (a
            missing score counts as 0). An empty list is returned when there
            is nothing to return or when the search raises.
        """
        try:
            if not (query and user_id):
                return []
            response = self.memory.search(query=query, user_id=user_id, limit=limit)
            if not response or 'results' not in response:
                return []
            hits = list(response['results'])
            hits.sort(key=lambda hit: hit.get('score', 0), reverse=True)
            return hits
        except Exception as err:
            logger.error(f"Error searching memories for user {user_id}: {err}")
            return []

    def add_memory(self, messages: List[Dict[str, str]], user_id: str) -> bool:
        """Store ``messages`` as memory for ``user_id``.

        Args:
            messages: Chat messages (``role``/``content`` dicts) to remember.
            user_id: Owner of the new memory.

        Returns:
            True when the store accepted the messages; False when either
            argument is empty or the store raised.
        """
        try:
            if not (messages and user_id):
                return False
            self.memory.add(messages, user_id=user_id)
            logger.info(f"Memory added for user {user_id}")
            return True
        except Exception as err:
            logger.error(f"Error adding memory for user {user_id}: {err}")
            return False
class UserManager:
    """Registry of user names persisted in a text file, one name per line."""

    # A letter followed by letters, digits or underscores. Used with
    # ``fullmatch`` so that a trailing newline is rejected ("$" would allow it).
    _USERNAME_PATTERN = re.compile(r'[a-zA-Z][a-zA-Z0-9_]*')

    def __init__(self, users_file: str, min_username_length: int = 3):
        """Remember the storage location and make sure the file exists.

        Args:
            users_file: Path of the text file that stores the user names.
            min_username_length: Minimum number of characters in a valid name.
        """
        self.users_file = Path(users_file)
        self.min_username_length = min_username_length
        self._ensure_users_file_exists()

    def _ensure_users_file_exists(self):
        """Create an empty users file (and its parent directory) if missing."""
        if not self.users_file.exists():
            self.users_file.parent.mkdir(parents=True, exist_ok=True)
            self.users_file.touch()
            logger.info(f"Created users file: {self.users_file}")

    def load_users(self) -> set:
        """Return the set of registered user names.

        Blank lines are ignored and surrounding whitespace is stripped.
        An empty set is returned when the file cannot be read.
        """
        try:
            with open(self.users_file, 'r', encoding='utf-8') as f:
                users = {line.strip() for line in f if line.strip()}
            logger.debug(f"Loaded {len(users)} users")
            return users
        except Exception as e:
            logger.error(f"Error loading users: {e}")
            return set()

    def save_user(self, username: str) -> bool:
        """Append ``username`` to the users file.

        Returns:
            True on success, False when the write failed.
        """
        try:
            with open(self.users_file, 'a', encoding='utf-8') as f:
                f.write(f"{username}\n")
            logger.info(f"User {username} saved to file")
            return True
        except Exception as e:
            logger.error(f"Error saving user {username}: {e}")
            return False

    def is_valid_username(self, username: str) -> bool:
        """Tell whether ``username`` is acceptable.

        A valid name is a string of at least ``min_username_length``
        characters that starts with an ASCII letter and otherwise contains
        only ASCII letters, digits and underscores.
        """
        if not username or not isinstance(username, str):
            return False
        if len(username) < self.min_username_length:
            return False
        return self._USERNAME_PATTERN.fullmatch(username) is not None

    def _invalid_username_message(self) -> str:
        """Build the user-facing message shown for a rejected user name."""
        return (
            "用户名无效!用户名必须以英文字母开头,只能包含英文字母、数字和下划线,"
            f"且长度至少{self.min_username_length}位。"
        )

    def login_user(self, username: str) -> Tuple[bool, str]:
        """Check that ``username`` is valid and already registered.

        Returns:
            ``(success, message)`` where ``message`` is meant for the user.
        """
        if not self.is_valid_username(username):
            return False, self._invalid_username_message()
        if username in self.load_users():
            logger.info(f"User {username} logged in successfully")
            return True, f"欢迎回来,{username}!"
        logger.warning(f"Login attempt for unregistered user: {username}")
        return False, f"用户 {username} 未注册,请先注册。"

    def register_user(self, username: str) -> Tuple[bool, str]:
        """Register ``username`` unless it is invalid or already taken.

        Returns:
            ``(success, message)`` where ``message`` is meant for the user.
        """
        if not self.is_valid_username(username):
            return False, self._invalid_username_message()
        if username in self.load_users():
            logger.warning(f"Registration attempt for existing user: {username}")
            return False, f"用户名 {username} 已存在,请直接登录。"
        if self.save_user(username):
            logger.info(f"User {username} registered successfully")
            return True, f"注册成功!欢迎,{username}!"
        return False, "注册失败,请稍后重试。"
class ChatManager:
    """Builds the prompt (system prompt, memories, history) and calls the chat API."""

    def __init__(self, config: "AppConfig", memory_manager: "MemoryManager"):
        """Store the collaborators and create the API client.

        Args:
            config: Application settings (model name, sampling parameters,
                maximum number of memories per search).
            memory_manager: Store queried for memories related to a message.
        """
        self.config = config
        self.memory_manager = memory_manager
        self.client = self._initialize_openai_client()

    def _initialize_openai_client(self) -> "OpenAI":
        """Create the OpenAI-compatible client.

        The API key is read from the ``GW_API_KEY`` environment variable.

        Raises:
            Exception: Any error raised by the client constructor is logged
                and re-raised unchanged.
        """
        try:
            # The API key and base URL can be adjusted here if needed.
            gw_api_key = os.getenv("GW_API_KEY")
            client = OpenAI(
                base_url='https://api.geniuworks.com/v2',
                api_key=gw_api_key,
            )
            logger.info("OpenAI client initialized successfully")
            return client
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}")
            raise

    @staticmethod
    def _assistant_text(content: Any) -> str:
        """Extract the reply text from an assistant message's ``content``.

        ``content`` is either a plain string (for example an error notice) or
        a list of content parts; for a list, the last part whose type is
        ``"text"`` (the default when no type is given) is used, so the
        "thinking" tool part is never sent back to the model.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            for part in reversed(content):
                if isinstance(part, dict) and part.get("type", "text") == "text":
                    return part.get("content", "") or ""
        return ""

    @staticmethod
    def _last_user_text(history: List[Dict]) -> str:
        """Return the text of the most recent user message, or ``""``."""
        for item in reversed(history):
            if item.get("role") == "user" and isinstance(item.get("content"), str):
                return item["content"]
        return ""

    def format_history(self, sender_value: str, history: List[Dict], username: Optional[str] = None) -> List[Dict[str, str]]:
        """Convert the chatbot history into chat-completion messages.

        Args:
            sender_value: Text just submitted by the user. May be None (for
                example on retry); the latest user message in ``history`` is
                then used as the memory search query.
            history: Chatbot messages; user messages carry a string
                ``content``, assistant messages a list of content parts or a
                plain string.
            username: Logged-in user. When empty, no system prompt is added
                and no memory search is performed.

        Returns:
            A list of ``{"role", "content"}`` dicts: an optional system
            message followed by the user/assistant turns. Assistant turns
            without any reply text are omitted.
        """
        history = history or []
        messages = []

        if username:
            system_prompt = (
                "You are Xinyuan, a large language model trained by Cylingo Group. "
                f"You are a helpful assistant. 目前和你聊天的用户是{username}."
            )
            # On retry there is no fresh input, so search with the message
            # that is being answered again.
            query = sender_value or self._last_user_text(history)
            if query:
                related_memories = self.memory_manager.search_memories(
                    query=query,
                    user_id=username,
                    limit=self.config.max_memory_results,
                )
                if related_memories:
                    memory_lines = [
                        # "or 0" keeps the float format working if the score is None.
                        f"记忆{idx}:{memory.get('memory', '')} (相关度: {(memory.get('score') or 0):.3f})\n"
                        for idx, memory in enumerate(related_memories, start=1)
                    ]
                    system_prompt += "\n相关记忆:\n" + "".join(memory_lines)
            messages.append({"role": "system", "content": system_prompt})

        for item in history:
            role = item.get("role")
            if role == "user":
                messages.append({"role": "user", "content": item.get("content", "")})
            elif role == "assistant":
                assistant_content = self._assistant_text(item.get("content"))
                if assistant_content:
                    messages.append({"role": "assistant", "content": assistant_content})
        return messages

    def create_chat_completion(self, messages: List[Dict[str, str]]) -> Any:
        """Start a streaming chat completion.

        Args:
            messages: Messages as produced by ``format_history``.

        Returns:
            The streaming response object returned by the client.

        Raises:
            Exception: Any error raised by the client is logged and re-raised.
        """
        try:
            return self.client.chat.completions.create(
                model=self.config.model_name,
                messages=messages,
                stream=True,
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
            )
        except Exception as e:
            logger.error(f"Error creating chat completion: {e}")
            raise
# Global configuration and manager instances shared by the Gradio callbacks below.
# They are created at import time, in this order.
config = AppConfig()
# Builds the memory store from config.memory_path; initialisation errors propagate.
memory_manager = MemoryManager(config.memory_path)
# Creates the users file if it does not exist yet.
user_manager = UserManager(config.users_file, config.min_username_length)
# Creates the API client used for chat completions.
chat_manager = ChatManager(config, memory_manager)
# Gradio UI callbacks
def handle_auth(username: str, is_register: bool) -> Tuple:
    """Log in or register ``username`` and compute the resulting UI updates.

    Args:
        username: Name typed into the login form.
        is_register: True to register a new user, False to log in.

    Returns:
        A 4-tuple ``(login container update, chat container update,
        alert update, current user)``. On success the login view is hidden,
        the chat view is shown and the user name is returned; otherwise the
        login view stays visible and the user name is ``""``.
    """
    try:
        authenticate = user_manager.register_user if is_register else user_manager.login_user
        success, message = authenticate(username)
        if not success:
            return (
                gr.update(visible=True),   # keep the login view visible
                gr.update(visible=False),  # keep the chat view hidden
                gr.update(message=message, type="error", visible=True),
                "",
            )
        return (
            gr.update(visible=False),  # hide the login view
            gr.update(visible=True),   # show the chat view
            gr.update(message=message, type="success", visible=True),
            username,
        )
    except Exception as err:
        logger.error(f"Error in handle_auth: {err}")
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(message="系统错误,请稍后重试。", type="error", visible=True),
            "",
        )
def prompt_select(e: gr.EventData) -> gr.update:
    """Copy the description of the selected welcome prompt into the sender.

    Args:
        e: Event data; the description is read from
            ``e._data["payload"][0]["value"]["description"]``.

    Returns:
        An update setting the sender value to that description, or to ``""``
        when the payload cannot be read.
    """
    try:
        selected_prompt = e._data["payload"][0]["value"]
        description = selected_prompt["description"]
        return gr.update(value=description)
    except Exception as err:
        logger.error(f"Error in prompt_select: {err}")
        return gr.update(value="")
def clear() -> gr.update:
    """Return an update that resets the chatbot value to None."""
    reset_chatbot = gr.update(value=None)
    return reset_chatbot
def retry(chatbot_value: List, e: gr.EventData, username: Optional[str] = None):
    """Regenerate a reply: drop the history from the retried message on and resubmit.

    Args:
        chatbot_value: Current chatbot messages.
        e: Event data; the index of the retried message is read from
            ``e._data["payload"][0]["index"]``.
        username: Logged-in user, forwarded to ``submit``.

    Yields:
        Tuples of updates for ``(sender, chatbot, clear_btn)``.
    """
    try:
        retried_index = e._data["payload"][0]["index"]
        # Keep only the messages that precede the retried one.
        chatbot_value = chatbot_value[:retried_index]
        loading_state = (
            gr.update(value=None, loading=True),
            gr.update(value=chatbot_value),
            gr.update(disabled=True),
        )
        yield loading_state
        # sender_value is None: nothing new is appended, the history is replayed.
        for ui_update in submit(None, chatbot_value, username):
            yield ui_update
    except Exception as err:
        logger.error(f"Error in retry: {err}")
        yield gr.update(value=None, loading=False), gr.update(value=chatbot_value), gr.update(disabled=False)
def cancel(chatbot_value: List) -> Tuple:
    """Mark the last chatbot message as finished after the user cancels.

    Args:
        chatbot_value: Current chatbot messages; the last one is modified
            in place when the list is not empty.

    Returns:
        Updates for ``(chatbot, sender, clear_btn)``: the chatbot value, the
        sender leaving its loading state and the clear button re-enabled.
    """
    def ui_state() -> Tuple:
        # Same result whether or not marking the message succeeded.
        return (
            gr.update(value=chatbot_value),
            gr.update(loading=False),
            gr.update(disabled=False),
        )

    try:
        if chatbot_value:
            last_message = chatbot_value[-1]
            last_message["loading"] = False
            last_message["status"] = "done"
            last_message["footer"] = "Chat completion paused"
        return ui_state()
    except Exception as err:
        logger.error(f"Error in cancel: {err}")
        return ui_state()
def submit(sender_value: Optional[str], chatbot_value: List, username: Optional[str] = None):
    """Send the conversation to the model and stream the reply into the chatbot.

    Generator used as a Gradio event handler.

    Args:
        sender_value: Text submitted by the user, or None to answer the
            existing history again without adding a user message.
        chatbot_value: Chatbot messages; extended and modified in place.
        username: Logged-in user. When set, the exchange is saved to memory
            after a non-empty reply.

    Yields:
        Tuples of updates for ``(sender, chatbot, clear_btn)``.

    On failure the error notice is written to the assistant message only
    (one is appended if the failure happened before the placeholder existed),
    so the user's own message is never overwritten. The notice is stored as
    a text content part, the same shape as a normal reply.
    """
    start_time = time.time()
    # clear() resets the chatbot value to None; tolerate that value here in
    # case it is passed through unchanged.
    if chatbot_value is None:
        chatbot_value = []
    placeholder_added = False

    try:
        # Append the user message.
        if sender_value is not None:
            chatbot_value.append({
                "role": "user",
                "content": sender_value,
            })

        # Build the model input before the assistant placeholder is added.
        history_messages = chat_manager.format_history(sender_value, chatbot_value, username)

        # Append the assistant placeholder.
        chatbot_value.append({
            "role": "assistant",
            "content": [],
            "loading": True,
            "status": "pending",
        })
        placeholder_added = True

        # Put the UI into its "busy" state.
        yield (
            gr.update(value=None, loading=True),  # sender
            gr.update(value=chatbot_value),       # chatbot
            gr.update(disabled=True),             # clear_btn
        )

        response = chat_manager.create_chat_completion(history_messages)

        thought_done = False
        message_content = chatbot_value[-1]["content"]
        # First part collects the reasoning text, second part the reply text.
        message_content.append({
            "copyable": False,
            "editable": False,
            "type": "tool",
            "content": "",
            "options": {"title": "Thinking..."},
        })
        message_content.append({
            "type": "text",
            "content": "",
        })

        full_assistant_content = ""

        # Consume the streamed chunks.
        for chunk in response:
            try:
                # Skip chunks that carry no choice instead of failing on [0].
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                reasoning_content = getattr(delta, 'reasoning_content', None) or ""
                content = getattr(delta, 'content', None) or ""

                chatbot_value[-1]["loading"] = False
                message_content[-2]["content"] += reasoning_content
                message_content[-1]["content"] += content

                if content:
                    full_assistant_content += content
                    # The first reply text marks the end of the reasoning phase.
                    if not thought_done:
                        thought_done = True
                        thought_cost_time = f"{time.time() - start_time:.2f}"
                        message_content[-2]["options"]["title"] = f"End of Thought ({thought_cost_time}s)"
                        message_content[-2]["options"]["status"] = "done"

                yield (
                    gr.update(),                     # sender
                    gr.update(value=chatbot_value),  # chatbot
                    gr.update(),                     # clear_btn
                )
            except Exception as chunk_error:
                # A malformed chunk must not abort the whole reply.
                logger.error(f"Error processing chunk: {chunk_error}")
                continue

        # Save the exchange to memory.
        if username and sender_value and full_assistant_content:
            memory_messages = [
                {'role': 'user', 'content': sender_value},
                {'role': 'assistant', 'content': full_assistant_content},
            ]
            memory_manager.add_memory(memory_messages, username)

        # Finish the reply.
        total_time = f"{time.time() - start_time:.2f}s"
        chatbot_value[-1]["footer"] = total_time
        chatbot_value[-1]["status"] = "done"

        yield (
            gr.update(loading=False),        # sender
            gr.update(value=chatbot_value),  # chatbot
            gr.update(disabled=False),       # clear_btn
        )

    except Exception as e:
        logger.error(f"Error in submit: {e}")
        if placeholder_added:
            bot_message = chatbot_value[-1]
        else:
            # The failure happened before the placeholder was appended; the
            # last message (if any) belongs to the user and must stay intact.
            bot_message = {"role": "assistant"}
            chatbot_value.append(bot_message)
        bot_message["loading"] = False
        bot_message["status"] = "done"
        bot_message["content"] = [{
            "type": "text",
            "content": "抱歉,处理您的请求时出现错误,请稍后重试。",
        }]

        yield (
            gr.update(loading=False),        # sender
            gr.update(value=chatbot_value),  # chatbot
            gr.update(disabled=False),       # clear_btn
        )
# Build the Gradio interface
def create_interface():
    """Build the Gradio app: a login/registration view and a chat view.

    The login view is shown first. A successful login or registration hides
    it, shows the chat view and stores the user name in a ``gr.State``.
    Logging out switches back to the login view and also clears the chatbot,
    so that the next user does not see the previous user's conversation.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="Xinyuan 聊天助手") as demo, ms.Application(), antdx.XProvider():
        # Name of the logged-in user ("" when nobody is logged in).
        current_user = gr.State("")

        # Login view
        with antd.Flex(vertical=True, gap="large", elem_id="login_container") as login_container:
            with antd.Card(title="欢迎使用 Xinyuan 聊天助手"):
                with antd.Flex(vertical=True, gap="middle"):
                    antd.Typography.Title("用户登录/注册", level=3)
                    antd.Typography.Text("请输入您的英文用户名(3位以上,仅支持英文字母、数字和下划线)")
                    username_input = antd.Input(
                        placeholder="请输入用户名(如:john_doe)",
                        size="large"
                    )
                    with antd.Flex(gap="small"):
                        login_btn = antd.Button("登录", type="primary", size="large")
                        register_btn = antd.Button("注册", size="large")
                    # NOTE(review): nesting reconstructed from a flattened
                    # source — confirm the alert belongs inside the card.
                    auth_message = antd.Alert(
                        message="请输入用户名",
                        type="info",
                        visible=False
                    )

        # Chat view
        with antd.Flex(vertical=True, gap="middle", visible=False) as chat_container:
            # User info bar
            with antd.Flex(justify="space-between", align="center"):
                user_info = gr.Markdown("")
                logout_btn = antd.Button("退出登录", size="small")

            # Chatbot component
            chatbot = pro.Chatbot(
                height=config.chatbot_height,
                welcome_config=ChatbotWelcomeConfig(
                    variant="borderless",
                    icon="./xinyuan.png",
                    title="Hello, I'm Xinyuan👋",
                    description="You can input text to get started.",
                    prompts=ChatbotPromptsConfig(
                        title="How can I help you today?",
                        styles={
                            "list": {"width": '100%'},
                            "item": {"flex": 1},
                        },
                        items=[
                            {
                                "label": "💝 心理学与实际应用",
                                "children": [
                                    {"description": "课题分离是什么意思?"},
                                    {"description": "回避型依恋和焦虑型依恋有什么区别?还有其他依恋类型吗?"},
                                    {"description": "为什么我背单词的时候总是只记得开头和结尾,中间全忘了?"}
                                ]
                            },
                            {
                                "label": "👪 儿童教育与发展",
                                "children": [
                                    {"description": "什么是正念养育?"},
                                    {"description": "2岁孩子分离焦虑严重,送托育中心天天哭闹怎么办?"},
                                    {"description": "4岁娃说话不清还爱打人,是心理问题还是欠管教?"}
                                ]
                            }
                        ]
                    )
                ),
                user_config=ChatbotUserConfig(
                    avatar="https://api.dicebear.com/7.x/miniavs/svg?seed=3",
                    variant="shadow"
                ),
                bot_config=ChatbotBotConfig(
                    header='Xinyuan',
                    avatar="./xinyuan.png",
                    actions=["copy", "retry"],
                    variant="shadow"
                ),
            )

            # Sender component with a "clear" button in its prefix slot
            with antdx.Sender() as sender:
                with ms.Slot("prefix"):
                    with antd.Button(value=None, color="default", variant="text") as clear_btn:
                        with ms.Slot("icon"):
                            antd.Icon("ClearOutlined")

        # Event handler functions
        def handle_login(username: str):
            return handle_auth(username, False)

        def handle_register(username: str):
            return handle_auth(username, True)

        def handle_logout():
            return (
                gr.update(visible=True),   # show the login view
                gr.update(visible=False),  # hide the chat view
                gr.update(message="已退出登录", type="info", visible=True),
                gr.update(value=""),       # clear the user name input
                "",                        # clear the user info display
                "",                        # clear the current user state
                gr.update(value=None),     # drop the conversation of the user who logged out
            )

        def update_user_info(username: str) -> str:
            return f"**当前用户: {username}**" if username else ""

        # Authentication events
        login_btn.click(
            fn=handle_login,
            inputs=[username_input],
            outputs=[login_container, chat_container, auth_message, current_user]
        ).then(
            fn=update_user_info,
            inputs=[current_user],
            outputs=[user_info]
        )
        register_btn.click(
            fn=handle_register,
            inputs=[username_input],
            outputs=[login_container, chat_container, auth_message, current_user]
        ).then(
            fn=update_user_info,
            inputs=[current_user],
            outputs=[user_info]
        )
        logout_btn.click(
            fn=handle_logout,
            outputs=[login_container, chat_container, auth_message, username_input,
                     user_info, current_user, chatbot]
        )

        # Chat events
        clear_btn.click(fn=clear, outputs=[chatbot])
        submit_event = sender.submit(
            fn=submit,
            inputs=[sender, chatbot, current_user],
            outputs=[sender, chatbot, clear_btn]
        )
        # Cancelling stops the running submit event and finalises the last message.
        sender.cancel(
            fn=cancel,
            inputs=[chatbot],
            outputs=[chatbot, sender, clear_btn],
            cancels=[submit_event],
            queue=False
        )
        chatbot.retry(
            fn=retry,
            inputs=[chatbot, current_user],
            outputs=[sender, chatbot, clear_btn]
        )
        chatbot.welcome_prompt_select(fn=prompt_select, outputs=[sender])

    return demo
def main():
    """Build the interface and launch it with queuing enabled.

    Raises:
        Exception: Any start-up error is logged and re-raised.
    """
    try:
        logger.info("Starting Xinyuan Chat Application")
        app = create_interface()
        queued_app = app.queue()
        queued_app.launch()
    except Exception as err:
        logger.error(f"Failed to start application: {err}")
        raise


# Run the application only when this file is executed as a script.
if __name__ == "__main__":
    main()