388 lines
9.9 KiB
Python
388 lines
9.9 KiB
Python
"""
|
||
配置加载器模块
|
||
|
||
提供统一的配置加载机制,支持多数据源配置合并:
|
||
- 环境变量(优先级最高)
|
||
- YAML 配置文件
|
||
- 默认值(优先级最低)
|
||
|
||
设计原则:
|
||
- 单一职责:仅负责配置的加载和合并
|
||
- 依赖倒置:通过抽象接口解耦配置源
|
||
- 开闭原则:易于扩展新的配置源
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
from abc import ABC, abstractmethod
|
||
from functools import lru_cache
|
||
from pathlib import Path
|
||
from typing import Any, TypeVar
|
||
|
||
import yaml
|
||
|
||
T = TypeVar("T")
|
||
|
||
|
||
class ConfigSourceError(Exception):
|
||
"""配置源加载错误"""
|
||
|
||
def __init__(self, source: str, message: str) -> None:
|
||
self.source = source
|
||
self.message = message
|
||
super().__init__(f"[{source}] {message}")
|
||
|
||
|
||
class ConfigSource(ABC):
|
||
"""配置源抽象基类"""
|
||
|
||
@property
|
||
@abstractmethod
|
||
def name(self) -> str:
|
||
"""配置源名称"""
|
||
...
|
||
|
||
@property
|
||
@abstractmethod
|
||
def priority(self) -> int:
|
||
"""优先级,数值越大优先级越高"""
|
||
...
|
||
|
||
@abstractmethod
|
||
def load(self) -> dict[str, Any]:
|
||
"""加载配置,返回扁平化的键值对"""
|
||
...
|
||
|
||
|
||
class EnvConfigSource(ConfigSource):
|
||
"""环境变量配置源"""
|
||
|
||
def __init__(self, prefix: str = "") -> None:
|
||
"""
|
||
初始化环境变量配置源
|
||
|
||
Args:
|
||
prefix: 环境变量前缀,用于过滤相关配置
|
||
例如 prefix="SATONANO_" 时只读取以此为前缀的变量
|
||
"""
|
||
self._prefix = prefix.upper()
|
||
|
||
@property
|
||
def name(self) -> str:
|
||
return "environment"
|
||
|
||
@property
|
||
def priority(self) -> int:
|
||
return 100 # 最高优先级
|
||
|
||
def load(self) -> dict[str, Any]:
|
||
"""
|
||
加载环境变量
|
||
|
||
Returns:
|
||
配置字典,键名转换为小写并移除前缀
|
||
"""
|
||
config: dict[str, Any] = {}
|
||
prefix_len = len(self._prefix)
|
||
|
||
for key, value in os.environ.items():
|
||
if self._prefix and not key.upper().startswith(self._prefix):
|
||
continue
|
||
|
||
# 移除前缀并转为小写
|
||
config_key = key[prefix_len:].lower() if self._prefix else key.lower()
|
||
config[config_key] = self._parse_value(value)
|
||
|
||
return config
|
||
|
||
@staticmethod
|
||
def _parse_value(value: str) -> Any:
|
||
"""
|
||
解析环境变量值,自动转换类型
|
||
|
||
Args:
|
||
value: 原始字符串值
|
||
|
||
Returns:
|
||
转换后的值(bool/int/float/str)
|
||
"""
|
||
# 布尔值
|
||
if value.lower() in ("true", "yes", "1", "on"):
|
||
return True
|
||
if value.lower() in ("false", "no", "0", "off"):
|
||
return False
|
||
|
||
# 整数
|
||
try:
|
||
return int(value)
|
||
except ValueError:
|
||
pass
|
||
|
||
# 浮点数
|
||
try:
|
||
return float(value)
|
||
except ValueError:
|
||
pass
|
||
|
||
return value
|
||
|
||
|
||
class YamlConfigSource(ConfigSource):
|
||
"""YAML 配置文件源"""
|
||
|
||
def __init__(
|
||
self,
|
||
file_path: str | Path = "config.yaml",
|
||
required: bool = False,
|
||
) -> None:
|
||
"""
|
||
初始化 YAML 配置源
|
||
|
||
Args:
|
||
file_path: 配置文件路径
|
||
required: 是否必须存在,为 True 时文件不存在会抛出异常
|
||
"""
|
||
self._file_path = Path(file_path)
|
||
self._required = required
|
||
|
||
@property
|
||
def name(self) -> str:
|
||
return f"yaml:{self._file_path}"
|
||
|
||
@property
|
||
def priority(self) -> int:
|
||
return 50 # 中等优先级
|
||
|
||
def load(self) -> dict[str, Any]:
|
||
"""
|
||
加载 YAML 配置文件
|
||
|
||
Returns:
|
||
配置字典
|
||
|
||
Raises:
|
||
ConfigSourceError: 文件不存在(required=True)或解析错误
|
||
"""
|
||
if not self._file_path.exists():
|
||
if self._required:
|
||
raise ConfigSourceError(
|
||
self.name,
|
||
f"配置文件不存在: {self._file_path.absolute()}",
|
||
)
|
||
return {}
|
||
|
||
try:
|
||
with open(self._file_path, encoding="utf-8") as f:
|
||
data = yaml.safe_load(f)
|
||
except yaml.YAMLError as e:
|
||
raise ConfigSourceError(self.name, f"YAML 解析错误: {e}") from e
|
||
|
||
if data is None:
|
||
return {}
|
||
|
||
# 处理列表格式(兼容用户示例格式)
|
||
if isinstance(data, list):
|
||
return self._flatten_list(data)
|
||
|
||
# 标准字典格式
|
||
if isinstance(data, dict):
|
||
return self._normalize_keys(data)
|
||
|
||
raise ConfigSourceError(
|
||
self.name,
|
||
f"不支持的配置格式,期望 dict 或 list,得到 {type(data).__name__}",
|
||
)
|
||
|
||
@staticmethod
|
||
def _flatten_list(items: list) -> dict[str, Any]:
|
||
"""
|
||
将列表格式配置展平为字典
|
||
|
||
支持格式:
|
||
- key: value
|
||
- key: value
|
||
"""
|
||
result: dict[str, Any] = {}
|
||
for item in items:
|
||
if isinstance(item, dict):
|
||
for key, value in item.items():
|
||
result[key.lower()] = value
|
||
return result
|
||
|
||
@staticmethod
|
||
def _normalize_keys(data: dict) -> dict[str, Any]:
|
||
"""键名标准化为小写"""
|
||
return {k.lower(): v for k, v in data.items()}
|
||
|
||
|
||
class ConfigLoader:
|
||
"""
|
||
配置加载器
|
||
|
||
负责从多个配置源加载并合并配置,按优先级覆盖。
|
||
|
||
Example:
|
||
>>> loader = ConfigLoader()
|
||
>>> loader.add_source(YamlConfigSource("config.yaml"))
|
||
>>> loader.add_source(EnvConfigSource("SATONANO_"))
|
||
>>> config = loader.load()
|
||
>>> db_type = loader.get("database_type", "sqlite")
|
||
"""
|
||
|
||
def __init__(self) -> None:
|
||
self._sources: list[ConfigSource] = []
|
||
self._config: dict[str, Any] = {}
|
||
self._loaded = False
|
||
|
||
def add_source(self, source: ConfigSource) -> ConfigLoader:
|
||
"""
|
||
添加配置源
|
||
|
||
Args:
|
||
source: 配置源实例
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
self._sources.append(source)
|
||
self._loaded = False # 标记需要重新加载
|
||
return self
|
||
|
||
def load(self) -> dict[str, Any]:
|
||
"""
|
||
加载并合并所有配置源
|
||
|
||
Returns:
|
||
合并后的配置字典
|
||
"""
|
||
# 按优先级升序排序,后加载的覆盖先加载的
|
||
sorted_sources = sorted(self._sources, key=lambda s: s.priority)
|
||
|
||
self._config = {}
|
||
for source in sorted_sources:
|
||
source_config = source.load()
|
||
self._config.update(source_config)
|
||
|
||
self._loaded = True
|
||
return self._config.copy()
|
||
|
||
def get(self, key: str, default: T = None) -> T | Any:
|
||
"""
|
||
获取配置值
|
||
|
||
Args:
|
||
key: 配置键名(不区分大小写)
|
||
default: 默认值
|
||
|
||
Returns:
|
||
配置值或默认值
|
||
"""
|
||
if not self._loaded:
|
||
self.load()
|
||
return self._config.get(key.lower(), default)
|
||
|
||
def get_str(self, key: str, default: str = "") -> str:
|
||
"""获取字符串配置"""
|
||
value = self.get(key, default)
|
||
return str(value) if value is not None else default
|
||
|
||
def get_int(self, key: str, default: int = 0) -> int:
|
||
"""获取整数配置"""
|
||
value = self.get(key, default)
|
||
if isinstance(value, int):
|
||
return value
|
||
try:
|
||
return int(value)
|
||
except (ValueError, TypeError):
|
||
return default
|
||
|
||
def get_float(self, key: str, default: float = 0.0) -> float:
|
||
"""获取浮点数配置"""
|
||
value = self.get(key, default)
|
||
if isinstance(value, float):
|
||
return value
|
||
try:
|
||
return float(value)
|
||
except (ValueError, TypeError):
|
||
return default
|
||
|
||
def get_bool(self, key: str, default: bool = False) -> bool:
|
||
"""获取布尔配置"""
|
||
value = self.get(key, default)
|
||
if isinstance(value, bool):
|
||
return value
|
||
if isinstance(value, str):
|
||
return value.lower() in ("true", "yes", "1", "on")
|
||
return bool(value)
|
||
|
||
def require(self, key: str) -> Any:
|
||
"""
|
||
获取必需的配置值
|
||
|
||
Args:
|
||
key: 配置键名
|
||
|
||
Returns:
|
||
配置值
|
||
|
||
Raises:
|
||
KeyError: 配置不存在
|
||
"""
|
||
if not self._loaded:
|
||
self.load()
|
||
|
||
key_lower = key.lower()
|
||
if key_lower not in self._config:
|
||
raise KeyError(f"缺少必需的配置项: {key}")
|
||
return self._config[key_lower]
|
||
|
||
def all(self) -> dict[str, Any]:
|
||
"""获取所有配置的副本"""
|
||
if not self._loaded:
|
||
self.load()
|
||
return self._config.copy()
|
||
|
||
def __contains__(self, key: str) -> bool:
|
||
"""检查配置是否存在"""
|
||
if not self._loaded:
|
||
self.load()
|
||
return key.lower() in self._config
|
||
|
||
def __repr__(self) -> str:
|
||
sources = ", ".join(s.name for s in self._sources)
|
||
return f"ConfigLoader(sources=[{sources}], loaded={self._loaded})"
|
||
|
||
|
||
def create_config_loader(
|
||
yaml_path: str | Path = "config.yaml",
|
||
env_prefix: str = "SATONANO_",
|
||
yaml_required: bool = False,
|
||
) -> ConfigLoader:
|
||
"""
|
||
创建预配置的配置加载器
|
||
|
||
Args:
|
||
yaml_path: YAML 配置文件路径
|
||
env_prefix: 环境变量前缀
|
||
yaml_required: YAML 文件是否必须存在
|
||
|
||
Returns:
|
||
配置好的 ConfigLoader 实例
|
||
"""
|
||
loader = ConfigLoader()
|
||
loader.add_source(YamlConfigSource(yaml_path, required=yaml_required))
|
||
loader.add_source(EnvConfigSource(env_prefix))
|
||
return loader
|
||
|
||
|
||
@lru_cache
|
||
def get_config_loader() -> ConfigLoader:
|
||
"""获取全局配置加载器单例"""
|
||
return create_config_loader()
|
||
|
||
|
||
# 便捷访问的全局实例
|
||
config_loader = get_config_loader()
|
||
|