File size: 4,423 Bytes
79899c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""全局配置存储模块,提供配置文件的加载和缓存功能,API密钥和base_url从环境变量加载。"""

import os
from typing import Any, Dict, Optional

import yaml


class ConfigManager:
    """配置管理器,使用单例模式缓存配置,API密钥和base_url从环境变量加载。"""

    _instance = None
    _config: Optional[Dict[str, Any]] = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_config(self) -> Dict[str, Any]:
        """获取配置,如果未加载则自动加载。

        Returns:
            包含所有配置信息的字典
        """
        if self._config is None:
            self._config = self._load_config()
        return self._config

    def _get_environment(self) -> str:
        """获取当前环境类型。

        Returns:
            环境类型:'prod' 或 'dev'
        """
        return os.getenv("ENVIRONMENT", "dev").lower()

    def _get_config_path(self) -> str:
        """根据环境获取配置文件路径。

        Returns:
            配置文件路径
        """
        env = self._get_environment()
        if env == "prod":
            return "config/app_config_prod.yaml"

        return "config/app_config_dev.yaml"

    def _load_config(self) -> Dict[str, Any]:
        """加载配置文件,并覆盖API密钥和base_url为环境变量值。

        Returns:
            从YAML文件加载的配置字典,API密钥和base_url从环境变量覆盖
        """
        config_path = self._get_config_path()
        try:
            with open(config_path, "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)
                # 添加环境信息到配置中
                config["environment"] = self._get_environment()
                
                # 从环境变量覆盖API密钥和base_url
                self._override_api_configs(config)
                
                return config
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"配置文件未找到: {config_path}") from exc
        except yaml.YAMLError as exc:
            raise ValueError(f"配置文件格式错误: {exc}") from exc

    def _override_api_configs(self, config: Dict[str, Any]) -> None:
        """从环境变量覆盖API密钥和base_url配置。

        Args:
            config: 配置字典
        """
        # QA LLM 主模型
        if "qa-llm" in config and "main" in config["qa-llm"]:
            main_config = config["qa-llm"]["main"]
            if os.getenv("QA_LLM_MAIN_API_KEY"):
                main_config["api_key"] = os.getenv("QA_LLM_MAIN_API_KEY")
            if os.getenv("QA_LLM_MAIN_BASE_URL"):
                main_config["base_url"] = os.getenv("QA_LLM_MAIN_BASE_URL")
        
        # QA LLM 备用模型
        if "qa-llm" in config and "backup" in config["qa-llm"]:
            backup_config = config["qa-llm"]["backup"]
            if os.getenv("QA_LLM_BACKUP_API_KEY"):
                backup_config["api_key"] = os.getenv("QA_LLM_BACKUP_API_KEY")
            if os.getenv("QA_LLM_BACKUP_BASE_URL"):
                backup_config["base_url"] = os.getenv("QA_LLM_BACKUP_BASE_URL")
        
        # Rewrite LLM 备用模型 (GPT-4o)
        if "rewrite-llm" in config and "backup" in config["rewrite-llm"]:
            backup_config = config["rewrite-llm"]["backup"]
            if os.getenv("REWRITE_LLM_BACKUP_API_KEY"):
                backup_config["api_key"] = os.getenv("REWRITE_LLM_BACKUP_API_KEY")
            if os.getenv("REWRITE_LLM_BACKUP_BASE_URL"):
                backup_config["base_url"] = os.getenv("REWRITE_LLM_BACKUP_BASE_URL")
        
        # Rewrite LLM 主模型
        if "rewrite-llm" in config and "main" in config["rewrite-llm"]:
            main_config = config["rewrite-llm"]["main"]
            if os.getenv("REWRITE_LLM_MAIN_API_KEY"):
                main_config["api_key"] = os.getenv("REWRITE_LLM_MAIN_API_KEY")
            if os.getenv("REWRITE_LLM_MAIN_BASE_URL"):
                main_config["base_url"] = os.getenv("REWRITE_LLM_MAIN_BASE_URL")


# 全局配置管理器实例
_config_manager = ConfigManager()


def get_model_config() -> Dict[str, Any]:
    """获取模型配置。

    Returns:
        包含所有配置信息的字典
    """
    return _config_manager.get_config()