""" 训练数据结构模块 实现2048游戏的训练数据结构,包括: 1. 棋盘状态的对数变换 2. 二面体群D4的8种变换(棋盘压缩) 3. 训练数据的内存缓存和硬盘持久化 4. 与PyTorch生态的集成 """ import numpy as np import torch from torch.utils.data import Dataset, DataLoader import pickle import hashlib import os from typing import Tuple, List, Dict, Optional, Union from dataclasses import dataclass from pathlib import Path import logging # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @dataclass class TrainingExample: """单个训练样本的数据结构""" board_state: np.ndarray # 棋盘状态 (H, W) action: int # 动作 (0:上, 1:下, 2:左, 3:右) value: float # 该状态-动作对的价值 canonical_hash: str # 规范化后的哈希值 class BoardTransform: """棋盘变换工具类,实现二面体群D4的8种变换""" @staticmethod def log_transform(board: np.ndarray) -> np.ndarray: """ 对棋盘进行对数变换 Args: board: 原始棋盘状态,包含2的幂次数字 Returns: 对数变换后的棋盘,空位为0,其他位置为log2(value) """ result = np.zeros_like(board, dtype=np.int32) mask = board > 0 result[mask] = np.log2(board[mask]).astype(np.int32) return result @staticmethod def inverse_log_transform(log_board: np.ndarray) -> np.ndarray: """ 对数变换的逆变换 Args: log_board: 对数变换后的棋盘 Returns: 原始棋盘状态 """ result = np.zeros_like(log_board, dtype=np.int32) mask = log_board > 0 result[mask] = (2 ** log_board[mask]).astype(np.int32) return result @staticmethod def rotate_90(matrix: np.ndarray) -> np.ndarray: """顺时针旋转90度""" return np.rot90(matrix, k=-1) @staticmethod def flip_horizontal(matrix: np.ndarray) -> np.ndarray: """水平翻转""" return np.fliplr(matrix) @classmethod def get_all_transforms(cls, matrix: np.ndarray) -> List[np.ndarray]: """ 获取二面体群D4的所有8种变换 Args: matrix: 输入矩阵 Returns: 包含8种变换结果的列表 """ transforms = [] # 原始图像 (R0) r0 = matrix.copy() transforms.append(r0) # 旋转90° (R90) r90 = cls.rotate_90(r0) transforms.append(r90) # 旋转180° (R180) r180 = cls.rotate_90(r90) transforms.append(r180) # 旋转270° (R270) r270 = cls.rotate_90(r180) transforms.append(r270) # 水平翻转 (F) f = cls.flip_horizontal(r0) transforms.append(f) # 翻转后旋转90° (F+R90) fr90 = cls.rotate_90(f) transforms.append(fr90) # 翻转后旋转180° (F+R180) fr180 = cls.rotate_90(fr90) transforms.append(fr180) # 翻转后旋转270° (F+R270) fr270 = cls.rotate_90(fr180) transforms.append(fr270) return transforms @classmethod def get_canonical_form(cls, matrix: np.ndarray) -> Tuple[np.ndarray, int]: """ 获取矩阵的规范形式(字典序最小的变换) Args: matrix: 输入矩阵 Returns: (规范形式矩阵, 变换索引) """ transforms = cls.get_all_transforms(matrix) # 将每个变换拉平为1D向量并比较字典序 flattened = [t.flatten() for t in transforms] # 找到字典序最小的索引 min_idx = 0 min_flat = flattened[0] for i, flat in enumerate(flattened[1:], 1): # 逐元素比较字典序 if cls._is_lexicographically_smaller(flat, min_flat): min_idx = i min_flat = flat return transforms[min_idx], min_idx @staticmethod def _is_lexicographically_smaller(a: np.ndarray, b: np.ndarray) -> bool: """ 检查数组a是否在字典序上小于数组b Args: a, b: 要比较的数组 Returns: 如果a < b则返回True """ for i in range(min(len(a), len(b))): if a[i] < b[i]: return True elif a[i] > b[i]: return False # 如果前面都相等,较短的数组更小 return len(a) < len(b) @classmethod def compute_hash(cls, matrix: np.ndarray) -> str: """ 计算矩阵规范形式的哈希值 Args: matrix: 输入矩阵 Returns: 哈希字符串 """ canonical, _ = cls.get_canonical_form(matrix) # 使用规范形式的字节表示计算哈希 return hashlib.md5(canonical.tobytes()).hexdigest() class ScoreCalculator: """分数计算工具类""" @staticmethod def calculate_tile_value(tile_log: int) -> int: """ 计算单个瓦片的累积分数价值 Args: tile_log: 瓦片的对数值 (log2(tile_value)) Returns: 累积分数价值 """ if tile_log <= 1: # 对应原始值2或更小 return 0 # V(N) = (log2(N) - 1) * N,其中N = 2^tile_log n = 2 ** tile_log return (tile_log - 1) * n @classmethod def calculate_board_score(cls, log_board: np.ndarray) -> int: """ 计算整个棋盘的累积分数 Args: log_board: 对数变换后的棋盘 Returns: 总累积分数 """ total_score = 0 for tile_log in log_board.flatten(): if tile_log > 0: total_score += cls.calculate_tile_value(tile_log) return total_score class TrainingDataCache: """训练数据的内存缓存系统""" def __init__(self, max_size: int = 1000000): """ 初始化缓存 Args: max_size: 最大缓存条目数 """ self.max_size = max_size self.cache: Dict[str, TrainingExample] = {} self.access_order: List[str] = [] # 用于LRU淘汰 def get(self, key: str) -> Optional[TrainingExample]: """获取缓存项""" if key in self.cache: # 更新访问顺序 self.access_order.remove(key) self.access_order.append(key) return self.cache[key] return None def put(self, key: str, example: TrainingExample) -> None: """添加或更新缓存项""" if key in self.cache: # 更新现有项 self.cache[key] = example self.access_order.remove(key) self.access_order.append(key) else: # 添加新项 if len(self.cache) >= self.max_size: # LRU淘汰 oldest_key = self.access_order.pop(0) del self.cache[oldest_key] self.cache[key] = example self.access_order.append(key) def update_if_better(self, key: str, example: TrainingExample) -> bool: """如果新样本的价值更高,则更新缓存""" existing = self.get(key) if existing is None or example.value > existing.value: self.put(key, example) return True return False def size(self) -> int: """返回缓存大小""" return len(self.cache) def clear(self) -> None: """清空缓存""" self.cache.clear() self.access_order.clear() def get_all_examples(self) -> List[TrainingExample]: """获取所有缓存的训练样本""" return list(self.cache.values()) class TrainingDataPersistence: """训练数据的硬盘持久化系统""" def __init__(self, data_dir: str = "data/training"): """ 初始化持久化系统 Args: data_dir: 数据存储目录 """ self.data_dir = Path(data_dir) self.data_dir.mkdir(parents=True, exist_ok=True) def save_cache(self, cache: TrainingDataCache, filename: str) -> None: """ 保存缓存到硬盘 Args: cache: 要保存的缓存 filename: 文件名 """ filepath = self.data_dir / f"{filename}.pkl" examples = cache.get_all_examples() with open(filepath, 'wb') as f: pickle.dump(examples, f) logger.info(f"Saved {len(examples)} examples to {filepath}") def load_cache(self, filename: str) -> List[TrainingExample]: """ 从硬盘加载训练数据 Args: filename: 文件名 Returns: 训练样本列表 """ filepath = self.data_dir / f"{filename}.pkl" if not filepath.exists(): logger.warning(f"File {filepath} does not exist") return [] with open(filepath, 'rb') as f: examples = pickle.load(f) logger.info(f"Loaded {len(examples)} examples from {filepath}") return examples def save_examples_batch(self, examples: List[TrainingExample], batch_name: str) -> None: """ 批量保存训练样本 Args: examples: 训练样本列表 batch_name: 批次名称 """ self.save_cache(TrainingDataCache(), batch_name) # 直接保存examples列表 filepath = self.data_dir / f"{batch_name}.pkl" with open(filepath, 'wb') as f: pickle.dump(examples, f) logger.info(f"Saved batch {batch_name} with {len(examples)} examples") def list_saved_files(self) -> List[str]: """列出所有保存的数据文件""" return [f.stem for f in self.data_dir.glob("*.pkl")] class Game2048Dataset(Dataset): """PyTorch Dataset for 2048 training data""" def __init__(self, examples: List[TrainingExample], board_size: Tuple[int, int] = (4, 4), max_tile_value: int = 17): """ 初始化数据集 Args: examples: 训练样本列表 board_size: 棋盘大小 (height, width) max_tile_value: 最大瓦片值的对数 (log2) """ self.examples = examples self.board_size = board_size self.max_tile_value = max_tile_value def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ 获取单个训练样本 Args: idx: 样本索引 Returns: (board_tensor, action_tensor, value_tensor) """ example = self.examples[idx] # 将棋盘状态转换为one-hot编码 board_tensor = self._encode_board(example.board_state) # 动作标签 action_tensor = torch.tensor(example.action, dtype=torch.long) # 价值标签 value_tensor = torch.tensor(example.value, dtype=torch.float32) return board_tensor, action_tensor, value_tensor def _encode_board(self, board: np.ndarray) -> torch.Tensor: """ 将棋盘状态编码为one-hot张量 Args: board: 对数变换后的棋盘状态 Returns: 形状为 (max_tile_value + 1, height, width) 的张量 """ # 创建one-hot编码 # 通道0: 空位 (值为0) # 通道1: 值为1 (对应原始值2) # 通道2: 值为2 (对应原始值4) # ... channels = self.max_tile_value + 1 height, width = self.board_size encoded = torch.zeros(channels, height, width, dtype=torch.float32) for i in range(height): for j in range(width): tile_value = int(board[i, j]) if 0 <= tile_value <= self.max_tile_value: encoded[tile_value, i, j] = 1.0 return encoded class TrainingDataManager: """训练数据管理器,整合缓存、持久化和PyTorch集成""" def __init__(self, data_dir: str = "data/training", cache_size: int = 1000000, board_size: Tuple[int, int] = (4, 4)): """ 初始化数据管理器 Args: data_dir: 数据存储目录 cache_size: 内存缓存大小 board_size: 棋盘大小 """ self.cache = TrainingDataCache(cache_size) self.persistence = TrainingDataPersistence(data_dir) self.board_size = board_size self.transform = BoardTransform() self.score_calc = ScoreCalculator() def add_training_example(self, board_state: np.ndarray, action: int, value: float) -> str: """ 添加训练样本 Args: board_state: 原始棋盘状态 action: 动作 value: 价值 Returns: 样本的哈希键 """ # 对数变换 log_board = self.transform.log_transform(board_state) # 计算规范哈希 canonical_hash = self.transform.compute_hash(log_board) # 创建训练样本 example = TrainingExample( board_state=log_board, action=action, value=value, canonical_hash=canonical_hash ) # 构造缓存键 (状态哈希 + 动作) cache_key = f"{canonical_hash}_{action}" # 更新缓存(如果新价值更高) self.cache.update_if_better(cache_key, example) return cache_key def get_pytorch_dataset(self, filter_min_value: float = 0.0) -> Game2048Dataset: """ 获取PyTorch数据集 Args: filter_min_value: 最小价值过滤阈值 Returns: PyTorch数据集 """ examples = [ex for ex in self.cache.get_all_examples() if ex.value >= filter_min_value] return Game2048Dataset(examples, self.board_size) def get_dataloader(self, batch_size: int = 32, shuffle: bool = True, filter_min_value: float = 0.0, **kwargs) -> DataLoader: """ 获取PyTorch DataLoader Args: batch_size: 批次大小 shuffle: 是否打乱数据 filter_min_value: 最小价值过滤阈值 **kwargs: 其他DataLoader参数 Returns: PyTorch DataLoader """ dataset = self.get_pytorch_dataset(filter_min_value) return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs) def save_current_cache(self, filename: str) -> None: """保存当前缓存到硬盘""" self.persistence.save_cache(self.cache, filename) def load_from_file(self, filename: str) -> int: """ 从文件加载训练数据到缓存 Args: filename: 文件名 Returns: 加载的样本数量 """ examples = self.persistence.load_cache(filename) loaded_count = 0 for example in examples: cache_key = f"{example.canonical_hash}_{example.action}" if self.cache.update_if_better(cache_key, example): loaded_count += 1 logger.info(f"Loaded {loaded_count} examples into cache") return loaded_count def get_cache_stats(self) -> Dict[str, int]: """获取缓存统计信息""" return { "cache_size": self.cache.size(), "max_cache_size": self.cache.max_size, "saved_files": len(self.persistence.list_saved_files()) } def merge_caches(self, other_manager: 'TrainingDataManager') -> int: """ 合并另一个数据管理器的缓存 Args: other_manager: 另一个数据管理器 Returns: 合并的样本数量 """ merged_count = 0 for example in other_manager.cache.get_all_examples(): cache_key = f"{example.canonical_hash}_{example.action}" if self.cache.update_if_better(cache_key, example): merged_count += 1 logger.info(f"Merged {merged_count} examples from other cache") return merged_count