"""Speaker-diarization model router.

Dispatches a diarization request to one of several implementations based on
the ``provider`` argument. The provider modules themselves are imported
eagerly at module import time; the router only caches the module object per
provider the first time that provider is used.
"""

import copy
import logging
from typing import Dict, Any, Optional, Callable
from pydub import AudioSegment
import spaces
from ..schemas import DiarizationResult
from . import diarization_pyannote_mlx
from . import diarization_pyannote_transformers

# Logger configuration
logger = logging.getLogger("diarization")

# Maps the "module_path" value of a provider config to the imported module.
_PROVIDER_MODULES = {
    "diarization_pyannote_mlx": diarization_pyannote_mlx,
    "diarization_pyannote_transformers": diarization_pyannote_transformers,
}


class DiarizerRouter:
    """Router offering one calling convention over several diarization providers."""

    def __init__(self):
        """Initialise the caches and the table of supported providers."""
        self._loaded_modules = {}  # provider name -> module object, filled on first use
        self._diarizers = {}  # reserved cache for instantiated diarizers (not used in this file)

        # Configuration of every supported provider.
        self._provider_configs = {
            "pyannote_mlx": {
                "module_path": "diarization_pyannote_mlx",
                "function_name": "diarize_audio",
                "default_model": "pyannote/speaker-diarization-3.1",
                "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
                "description": "基于pyannote.audio的原生MLX实现"
            },
            "pyannote_transformers": {
                "module_path": "diarization_pyannote_transformers",
                "function_name": "diarize_audio",
                "default_model": "pyannote/speaker-diarization-3.1",
                "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
                "description": "基于transformers库调用pyannote模型"
            }
        }

    def _lazy_load_module(self, provider: str):
        """Return the module implementing ``provider``, caching it on first use.

        Args:
            provider: Provider name.

        Returns:
            The module object registered for the provider.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
            ImportError: If the configured module path has no known module.
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")

        if provider not in self._loaded_modules:
            module_path = self._provider_configs[provider]["module_path"]
            logger.info(f"获取模块: {module_path}")

            # Resolve the configured module path to the imported module object.
            module = _PROVIDER_MODULES.get(module_path)
            if module is None:
                raise ImportError(f"未找到模块: {module_path}")

            self._loaded_modules[provider] = module
            logger.info(f"模块 {module_path} 获取成功")

        return self._loaded_modules[provider]

    def _get_diarize_function(self, provider: str) -> Callable:
        """Return the diarization callable exposed by the provider's module.

        Args:
            provider: Provider name.

        Returns:
            The function named by the provider's ``function_name`` setting.

        Raises:
            AttributeError: If the module does not define that function.
        """
        module = self._lazy_load_module(provider)
        function_name = self._provider_configs[provider]["function_name"]

        if not hasattr(module, function_name):
            raise AttributeError(f"模块中未找到函数: {function_name}")

        return getattr(module, function_name)

    def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only the parameters that ``provider`` declares as supported.

        Args:
            provider: Provider name.
            params: Raw keyword arguments supplied by the caller.

        Returns:
            A new dict restricted to the supported parameters. When
            ``model_name`` is supported but absent, the provider's default
            model is filled in.
        """
        config = self._provider_configs[provider]
        supported_params = config["supported_params"]
        filtered_params = {
            name: params[name] for name in supported_params if name in params
        }

        # Fall back to the provider's default model when none was given.
        if "model_name" in supported_params and "model_name" not in filtered_params:
            filtered_params["model_name"] = config["default_model"]

        return filtered_params

    def diarize(
        self,
        audio_segment: AudioSegment,
        provider: str,
        **kwargs
    ) -> DiarizationResult:
        """Run speaker diarization through the selected provider.

        Args:
            audio_segment: Input ``AudioSegment``.
            provider: Name of the diarization provider.
            **kwargs: Extra options such as ``model_name``, ``token``,
                ``device`` or ``segmentation_batch_size``. Options the
                provider does not declare as supported are dropped.

        Returns:
            The ``DiarizationResult`` returned by the provider function.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
            RuntimeError: If resolving or running the provider function
                fails; the original exception is attached as the cause.
        """
        logger.info(f"使用provider '{provider}' 进行说话人分离,音频长度: {len(audio_segment)/1000:.2f}秒")

        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs.keys())
            raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")

        try:
            # Resolve the provider's diarization function.
            diarize_func = self._get_diarize_function(provider)

            # Filter the keyword arguments down to what the provider accepts.
            filtered_kwargs = self._filter_params(provider, kwargs)

            # Never write the access token to the log; mask it if present.
            loggable_kwargs = {
                key: ("***" if key == "token" and value is not None else value)
                for key, value in filtered_kwargs.items()
            }
            logger.debug(f"调用 {provider} 说话人分离函数,参数: {loggable_kwargs}")

            # Run the diarization.
            result = diarize_func(audio_segment, **filtered_kwargs)

            logger.info(f"说话人分离完成,检测到 {result.num_speakers} 个说话人,生成 {len(result.segments)} 个分段")
            return result

        except Exception as e:
            logger.error(f"使用provider '{provider}' 进行说话人分离失败: {str(e)}", exc_info=True)
            # Chain explicitly so the original failure is kept as __cause__.
            raise RuntimeError(f"说话人分离失败: {str(e)}") from e

    def get_available_providers(self) -> Dict[str, str]:
        """Return a mapping of every provider name to its description."""
        return {
            provider: config["description"]
            for provider, config in self._provider_configs.items()
        }

    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """Return the configuration of one provider.

        Args:
            provider: Provider name.

        Returns:
            An independent copy of the provider's configuration dict.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")

        # Deep copy: the config holds a mutable "supported_params" list that
        # a shallow copy would still share with the router's own state.
        return copy.deepcopy(self._provider_configs[provider])


# Module-wide router instance
_router = DiarizerRouter()


@spaces.GPU(duration=180)
def diarize_audio(
    audio_segment: AudioSegment,
    provider: str = "pyannote_mlx",
    model_name: Optional[str] = None,
    token: Optional[str] = None,
    device: str = "cpu",
    segmentation_batch_size: int = 32,
    **kwargs
) -> DiarizationResult:
    """Unified entry point for speaker diarization of an audio segment.

    Args:
        audio_segment: Input ``AudioSegment``.
        provider: Diarization provider, one of:
            - "pyannote_mlx": native MLX implementation based on pyannote.audio
            - "pyannote_transformers": pyannote model called through transformers
        model_name: Model name; the provider's default model is used when omitted.
        token: Hugging Face token used to access the model.
        device: Inference device: 'cpu', 'cuda' or 'mps'.
        segmentation_batch_size: Segmentation batch size, 32 by default.
        **kwargs: Additional options forwarded to the router.

    Returns:
        A ``DiarizationResult`` object.

    Examples:
        # Default pyannote MLX implementation
        result = diarize_audio(audio_segment, provider="pyannote_mlx")

        # transformers implementation
        result = diarize_audio(
            audio_segment,
            provider="pyannote_transformers",
        )

        # Run on a GPU device
        result = diarize_audio(
            audio_segment,
            provider="pyannote_mlx",
            device="cuda"
        )

        # Custom batch size
        result = diarize_audio(
            audio_segment,
            provider="pyannote_mlx",
            segmentation_batch_size=64
        )
    """
    # Collect the options to forward.
    params = kwargs.copy()
    if model_name is not None:
        params["model_name"] = model_name
    if token is not None:
        params["token"] = token
    # NOTE(review): device and segmentation_batch_size are forwarded only when
    # they differ from this wrapper's defaults, so the provider's own default
    # applies otherwise. Confirm the provider defaults match "cpu" and 32.
    if device != "cpu":
        params["device"] = device
    if segmentation_batch_size != 32:
        params["segmentation_batch_size"] = segmentation_batch_size

    return _router.diarize(audio_segment, provider, **params)


def get_available_providers() -> Dict[str, str]:
    """Return a mapping of every available provider name to its description."""
    return _router.get_available_providers()


def get_provider_info(provider: str) -> Dict[str, Any]:
    """Return the configuration of ``provider``.

    Args:
        provider: Provider name.

    Returns:
        The provider's configuration dict.

    Raises:
        ValueError: If ``provider`` is not a configured provider.
    """
    return _router.get_provider_info(provider)