""" LLM模型调用路由器 根据传递的provider参数调用不同的LLM实现,支持延迟加载 """ import logging import torch from typing import Dict, Any, Optional, List, Union import os # 禁用 PyTorch 编译以避免在 Gradio Spaces 中的兼容性问题 os.environ["PYTORCH_DISABLE_DYNAMO"] = "1" os.environ["TORCH_COMPILE_DISABLE"] = "1" # 如果 torch._dynamo 可用,禁用它 try: import torch._dynamo torch._dynamo.config.disable = True torch._dynamo.config.suppress_errors = True except ImportError: pass import spaces from .llm_base import BaseChatCompletion from . import llm_gemma_mlx from . import llm_gemma_transfomers # 配置日志 logger = logging.getLogger("llm") class LLMRouter: """LLM模型调用路由器,支持多种实现的统一调用""" def __init__(self): """初始化路由器""" self._loaded_modules = {} # 用于缓存已加载的模块 self._llm_instances = {} # 用于缓存已实例化的LLM实例 # 定义支持的provider配置 self._provider_configs = { "gemma-mlx": { "module_path": "llm_gemma_mlx", "class_name": "GemmaMLXChatCompletion", "default_model": "mlx-community/gemma-3-12b-it-4bit-DWQ", "supported_params": ["model_name"], "description": "基于MLX库的Gemma聊天完成实现" }, "gemma-transformers": { "module_path": "llm_gemma_transfomers", "class_name": "GemmaTransformersChatCompletion", "default_model": "google/gemma-3-4b-it", "supported_params": [ "model_name", "device_map", "device" ], "description": "基于Transformers库的Gemma聊天完成实现" } } def _lazy_load_module(self, provider: str): """ 获取指定provider的模块 参数: provider: provider名称 返回: 对应的模块 """ if provider not in self._provider_configs: raise ValueError(f"不支持的provider: {provider}") if provider not in self._loaded_modules: module_path = self._provider_configs[provider]["module_path"] logger.info(f"获取模块: {module_path}") # 根据module_path返回对应的模块 if module_path == "llm_gemma_mlx": module = llm_gemma_mlx elif module_path == "llm_gemma_transfomers": module = llm_gemma_transfomers else: raise ImportError(f"未找到模块: {module_path}") self._loaded_modules[provider] = module logger.info(f"模块 {module_path} 获取成功") return self._loaded_modules[provider] def _get_llm_class(self, provider: str): """ 获取指定provider的LLM类 参数: provider: provider名称 返回: LLM类 """ module = self._lazy_load_module(provider) class_name = self._provider_configs[provider]["class_name"] if not hasattr(module, class_name): raise AttributeError(f"模块中未找到类: {class_name}") return getattr(module, class_name) def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]: """ 过滤参数,只保留指定provider支持的参数 参数: provider: provider名称 params: 原始参数字典 返回: 过滤后的参数字典 """ supported_params = self._provider_configs[provider]["supported_params"] filtered_params = {} for param in supported_params: if param in params: filtered_params[param] = params[param] # 如果没有指定model_name,使用默认模型 if "model_name" not in filtered_params and "model_name" in supported_params: filtered_params["model_name"] = self._provider_configs[provider]["default_model"] return filtered_params def _get_instance_key(self, provider: str, params: Dict[str, Any]) -> str: """ 生成LLM实例的缓存键 参数: provider: provider名称 params: 参数字典 返回: 实例缓存键 """ # 将参数转换为可哈希的字符串 param_str = "_".join([f"{k}={v}" for k, v in sorted(params.items())]) return f"{provider}_{param_str}" def _get_or_create_instance(self, provider: str, **kwargs) -> BaseChatCompletion: """ 获取或创建LLM实例(支持缓存复用) 参数: provider: provider名称 **kwargs: 构造函数参数 返回: LLM实例 """ # 过滤并准备参数 filtered_kwargs = self._filter_params(provider, kwargs) # 生成实例缓存键 instance_key = self._get_instance_key(provider, filtered_kwargs) # 检查是否已有缓存实例 if instance_key not in self._llm_instances: try: # 获取LLM类 llm_class = self._get_llm_class(provider) logger.debug(f"创建 {provider} LLM实例,参数: {filtered_kwargs}") # 创建实例 
                instance = llm_class(**filtered_kwargs)

                # Cache the instance
                self._llm_instances[instance_key] = instance
                logger.info(f"LLM instance created successfully: {provider} ({instance.model_name})")
            except Exception as e:
                logger.error(f"Failed to create {provider} LLM instance: {str(e)}", exc_info=True)
                raise RuntimeError(f"Failed to create LLM instance: {str(e)}") from e

        return self._llm_instances[instance_key]

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        provider: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Unified chat completion interface.

        Args:
            messages: List of messages, each containing role and content.
            provider: LLM provider name.
            temperature: Temperature; controls generation randomness.
            max_tokens: Maximum number of tokens to generate.
            top_p: Nucleus sampling parameter.
            model: Optional model name; overrides the default model_name if given.
            **kwargs: Additional parameters, e.g. device.

        Returns:
            Chat completion response dict.
        """
        logger.info(f"Running chat completion with provider '{provider}', message count: {len(messages)}")

        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs.keys())
            raise ValueError(f"Unsupported provider: {provider}. Supported providers: {available_providers}")

        try:
            # If a model was given, add it to kwargs
            if model is not None:
                kwargs["model_name"] = model

            # Get or create the LLM instance
            llm_instance = self._get_or_create_instance(provider, **kwargs)

            # Run the chat completion
            result = llm_instance.create(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                model=model,
                **kwargs
            )

            logger.info(f"Chat completion succeeded, tokens used: {result.get('usage', {}).get('total_tokens', 'unknown')}")
            return result

        except Exception as e:
            logger.error(f"Chat completion with provider '{provider}' failed: {str(e)}", exc_info=True)
            raise RuntimeError(f"Chat completion failed: {str(e)}") from e

    def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
        """
        Get model information.

        Args:
            provider: Provider name.
            **kwargs: Constructor arguments.

        Returns:
            Model info dict.
        """
        try:
            llm_instance = self._get_or_create_instance(provider, **kwargs)
            return llm_instance.get_model_info()
        except Exception as e:
            logger.error(f"Failed to get model info: {str(e)}")
            raise RuntimeError(f"Failed to get model info: {str(e)}") from e

    def get_available_providers(self) -> Dict[str, str]:
        """
        Get all available providers and their descriptions.

        Returns:
            Mapping from provider name to description.
        """
        return {
            provider: config["description"]
            for provider, config in self._provider_configs.items()
        }

    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """
        Get detailed information about the given provider.

        Args:
            provider: Provider name.

        Returns:
            The provider's configuration.
        """
        if provider not in self._provider_configs:
            raise ValueError(f"Unsupported provider: {provider}")

        return self._provider_configs[provider].copy()

    def clear_cache(self):
        """Clear cached instances."""
        # Clear each instance's GPU cache
        for instance in self._llm_instances.values():
            if hasattr(instance, 'clear_cache'):
                instance.clear_cache()

        # Clear the instance cache
        self._llm_instances.clear()
        logger.info("LLM instance cache cleared")


# Create the global router instance
_router = LLMRouter()


@spaces.GPU(duration=180)
def chat_completion(
    messages: List[Dict[str, str]],
    provider: str = "gemma-transformers",
    temperature: float = 0.7,
    max_tokens: int = 2048,
    top_p: float = 1.0,
    model: Optional[str] = None,
    device: Optional[str] = None,
    device_map: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Unified chat completion entry point.

    Args:
        messages: List of messages, each with role and content fields.
        provider: LLM provider; one of:
            - "gemma-mlx": Gemma chat completion based on the MLX library
            - "gemma-transformers": Gemma chat completion based on the Transformers library
        temperature: Temperature; controls generation randomness (0.0-2.0).
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus sampling parameter (0.0-1.0).
        model: Model name; the provider's default model is used if not given.
        device: Inference device: 'cpu', 'cuda', or 'mps' (transformers provider only).
        device_map: Device map configuration (transformers provider only).
        **kwargs: Additional parameters.

    Returns:
        Chat completion response dict containing the generated message and usage statistics.

    Examples:
        # Use the MLX implementation
        response = chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
            provider="gemma-mlx"
        )

        # Use the Gemma transformers implementation
        response = chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
provider="gemma-transformers", model="google/gemma-3-4b-it", device="cuda", ) # 自定义参数 response = chat_completion( messages=[ {"role": "system", "content": "你是一个有用的助手"}, {"role": "user", "content": "请介绍自己"} ], provider="gemma-mlx", temperature=0.8, max_tokens=1024 ) """ # 准备参数 params = kwargs.copy() if model is not None: params["model_name"] = model if device is not None: params["device"] = device if device_map: params["device_map"] = device_map return _router.chat_completion( messages=messages, provider=provider, temperature=temperature, max_tokens=max_tokens, top_p=top_p, model=model, **params ) def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]: """ 获取模型信息 参数: provider: provider名称 **kwargs: 构造函数参数 返回: 模型信息字典 """ return _router.get_model_info(provider, **kwargs) def get_available_providers() -> Dict[str, str]: """ 获取所有可用的LLM提供者 返回: provider名称到描述的映射 """ return _router.get_available_providers() def get_provider_info(provider: str) -> Dict[str, Any]: """ 获取指定provider的详细信息 参数: provider: provider名称 返回: provider的配置信息 """ return _router.get_provider_info(provider) def clear_cache(): """清理缓存的LLM实例""" _router.clear_cache()