Files
deep2048/training_data.py
2025-07-23 07:04:10 +08:00

576 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
训练数据结构模块
实现2048游戏的训练数据结构包括
1. 棋盘状态的对数变换
2. 二面体群D4的8种变换棋盘压缩
3. 训练数据的内存缓存和硬盘持久化
4. 与PyTorch生态的集成
"""
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import hashlib
import os
from typing import Tuple, List, Dict, Optional, Union
from dataclasses import dataclass
from pathlib import Path
import logging
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class TrainingExample:
"""单个训练样本的数据结构"""
board_state: np.ndarray # 棋盘状态 (H, W)
action: int # 动作 (0:上, 1:下, 2:左, 3:右)
value: float # 该状态-动作对的价值
canonical_hash: str # 规范化后的哈希值
class BoardTransform:
"""棋盘变换工具类实现二面体群D4的8种变换"""
@staticmethod
def log_transform(board: np.ndarray) -> np.ndarray:
"""
对棋盘进行对数变换
Args:
board: 原始棋盘状态包含2的幂次数字
Returns:
对数变换后的棋盘空位为0其他位置为log2(value)
"""
result = np.zeros_like(board, dtype=np.int32)
mask = board > 0
result[mask] = np.log2(board[mask]).astype(np.int32)
return result
@staticmethod
def inverse_log_transform(log_board: np.ndarray) -> np.ndarray:
"""
对数变换的逆变换
Args:
log_board: 对数变换后的棋盘
Returns:
原始棋盘状态
"""
result = np.zeros_like(log_board, dtype=np.int32)
mask = log_board > 0
result[mask] = (2 ** log_board[mask]).astype(np.int32)
return result
@staticmethod
def rotate_90(matrix: np.ndarray) -> np.ndarray:
"""顺时针旋转90度"""
return np.rot90(matrix, k=-1)
@staticmethod
def flip_horizontal(matrix: np.ndarray) -> np.ndarray:
"""水平翻转"""
return np.fliplr(matrix)
@classmethod
def get_all_transforms(cls, matrix: np.ndarray) -> List[np.ndarray]:
"""
获取二面体群D4的所有8种变换
Args:
matrix: 输入矩阵
Returns:
包含8种变换结果的列表
"""
transforms = []
# 原始图像 (R0)
r0 = matrix.copy()
transforms.append(r0)
# 旋转90° (R90)
r90 = cls.rotate_90(r0)
transforms.append(r90)
# 旋转180° (R180)
r180 = cls.rotate_90(r90)
transforms.append(r180)
# 旋转270° (R270)
r270 = cls.rotate_90(r180)
transforms.append(r270)
# 水平翻转 (F)
f = cls.flip_horizontal(r0)
transforms.append(f)
# 翻转后旋转90° (F+R90)
fr90 = cls.rotate_90(f)
transforms.append(fr90)
# 翻转后旋转180° (F+R180)
fr180 = cls.rotate_90(fr90)
transforms.append(fr180)
# 翻转后旋转270° (F+R270)
fr270 = cls.rotate_90(fr180)
transforms.append(fr270)
return transforms
@classmethod
def get_canonical_form(cls, matrix: np.ndarray) -> Tuple[np.ndarray, int]:
"""
获取矩阵的规范形式(字典序最小的变换)
Args:
matrix: 输入矩阵
Returns:
(规范形式矩阵, 变换索引)
"""
transforms = cls.get_all_transforms(matrix)
# 将每个变换拉平为1D向量并比较字典序
flattened = [t.flatten() for t in transforms]
# 找到字典序最小的索引
min_idx = 0
min_flat = flattened[0]
for i, flat in enumerate(flattened[1:], 1):
# 逐元素比较字典序
if cls._is_lexicographically_smaller(flat, min_flat):
min_idx = i
min_flat = flat
return transforms[min_idx], min_idx
@staticmethod
def _is_lexicographically_smaller(a: np.ndarray, b: np.ndarray) -> bool:
"""
检查数组a是否在字典序上小于数组b
Args:
a, b: 要比较的数组
Returns:
如果a < b则返回True
"""
for i in range(min(len(a), len(b))):
if a[i] < b[i]:
return True
elif a[i] > b[i]:
return False
# 如果前面都相等,较短的数组更小
return len(a) < len(b)
@classmethod
def compute_hash(cls, matrix: np.ndarray) -> str:
"""
计算矩阵规范形式的哈希值
Args:
matrix: 输入矩阵
Returns:
哈希字符串
"""
canonical, _ = cls.get_canonical_form(matrix)
# 使用规范形式的字节表示计算哈希
return hashlib.md5(canonical.tobytes()).hexdigest()
class ScoreCalculator:
"""分数计算工具类"""
@staticmethod
def calculate_tile_value(tile_log: int) -> int:
"""
计算单个瓦片的累积分数价值
Args:
tile_log: 瓦片的对数值 (log2(tile_value))
Returns:
累积分数价值
"""
if tile_log <= 1: # 对应原始值2或更小
return 0
# V(N) = (log2(N) - 1) * N其中N = 2^tile_log
n = 2 ** tile_log
return (tile_log - 1) * n
@classmethod
def calculate_board_score(cls, log_board: np.ndarray) -> int:
"""
计算整个棋盘的累积分数
Args:
log_board: 对数变换后的棋盘
Returns:
总累积分数
"""
total_score = 0
for tile_log in log_board.flatten():
if tile_log > 0:
total_score += cls.calculate_tile_value(tile_log)
return total_score
class TrainingDataCache:
"""训练数据的内存缓存系统"""
def __init__(self, max_size: int = 1000000):
"""
初始化缓存
Args:
max_size: 最大缓存条目数
"""
self.max_size = max_size
self.cache: Dict[str, TrainingExample] = {}
self.access_order: List[str] = [] # 用于LRU淘汰
def get(self, key: str) -> Optional[TrainingExample]:
"""获取缓存项"""
if key in self.cache:
# 更新访问顺序
self.access_order.remove(key)
self.access_order.append(key)
return self.cache[key]
return None
def put(self, key: str, example: TrainingExample) -> None:
"""添加或更新缓存项"""
if key in self.cache:
# 更新现有项
self.cache[key] = example
self.access_order.remove(key)
self.access_order.append(key)
else:
# 添加新项
if len(self.cache) >= self.max_size:
# LRU淘汰
oldest_key = self.access_order.pop(0)
del self.cache[oldest_key]
self.cache[key] = example
self.access_order.append(key)
def update_if_better(self, key: str, example: TrainingExample) -> bool:
"""如果新样本的价值更高,则更新缓存"""
existing = self.get(key)
if existing is None or example.value > existing.value:
self.put(key, example)
return True
return False
def size(self) -> int:
"""返回缓存大小"""
return len(self.cache)
def clear(self) -> None:
"""清空缓存"""
self.cache.clear()
self.access_order.clear()
def get_all_examples(self) -> List[TrainingExample]:
"""获取所有缓存的训练样本"""
return list(self.cache.values())
class TrainingDataPersistence:
"""训练数据的硬盘持久化系统"""
def __init__(self, data_dir: str = "data/training"):
"""
初始化持久化系统
Args:
data_dir: 数据存储目录
"""
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
def save_cache(self, cache: TrainingDataCache, filename: str) -> None:
"""
保存缓存到硬盘
Args:
cache: 要保存的缓存
filename: 文件名
"""
filepath = self.data_dir / f"{filename}.pkl"
examples = cache.get_all_examples()
with open(filepath, 'wb') as f:
pickle.dump(examples, f)
logger.info(f"Saved {len(examples)} examples to {filepath}")
def load_cache(self, filename: str) -> List[TrainingExample]:
"""
从硬盘加载训练数据
Args:
filename: 文件名
Returns:
训练样本列表
"""
filepath = self.data_dir / f"{filename}.pkl"
if not filepath.exists():
logger.warning(f"File {filepath} does not exist")
return []
with open(filepath, 'rb') as f:
examples = pickle.load(f)
logger.info(f"Loaded {len(examples)} examples from {filepath}")
return examples
def save_examples_batch(self, examples: List[TrainingExample],
batch_name: str) -> None:
"""
批量保存训练样本
Args:
examples: 训练样本列表
batch_name: 批次名称
"""
self.save_cache(TrainingDataCache(), batch_name)
# 直接保存examples列表
filepath = self.data_dir / f"{batch_name}.pkl"
with open(filepath, 'wb') as f:
pickle.dump(examples, f)
logger.info(f"Saved batch {batch_name} with {len(examples)} examples")
def list_saved_files(self) -> List[str]:
"""列出所有保存的数据文件"""
return [f.stem for f in self.data_dir.glob("*.pkl")]
class Game2048Dataset(Dataset):
"""PyTorch Dataset for 2048 training data"""
def __init__(self, examples: List[TrainingExample],
board_size: Tuple[int, int] = (4, 4),
max_tile_value: int = 17):
"""
初始化数据集
Args:
examples: 训练样本列表
board_size: 棋盘大小 (height, width)
max_tile_value: 最大瓦片值的对数 (log2)
"""
self.examples = examples
self.board_size = board_size
self.max_tile_value = max_tile_value
def __len__(self) -> int:
return len(self.examples)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
获取单个训练样本
Args:
idx: 样本索引
Returns:
(board_tensor, action_tensor, value_tensor)
"""
example = self.examples[idx]
# 将棋盘状态转换为one-hot编码
board_tensor = self._encode_board(example.board_state)
# 动作标签
action_tensor = torch.tensor(example.action, dtype=torch.long)
# 价值标签
value_tensor = torch.tensor(example.value, dtype=torch.float32)
return board_tensor, action_tensor, value_tensor
def _encode_board(self, board: np.ndarray) -> torch.Tensor:
"""
将棋盘状态编码为one-hot张量
Args:
board: 对数变换后的棋盘状态
Returns:
形状为 (max_tile_value + 1, height, width) 的张量
"""
# 创建one-hot编码
# 通道0: 空位 (值为0)
# 通道1: 值为1 (对应原始值2)
# 通道2: 值为2 (对应原始值4)
# ...
channels = self.max_tile_value + 1
height, width = self.board_size
encoded = torch.zeros(channels, height, width, dtype=torch.float32)
for i in range(height):
for j in range(width):
tile_value = int(board[i, j])
if 0 <= tile_value <= self.max_tile_value:
encoded[tile_value, i, j] = 1.0
return encoded
class TrainingDataManager:
"""训练数据管理器整合缓存、持久化和PyTorch集成"""
def __init__(self, data_dir: str = "data/training",
cache_size: int = 1000000,
board_size: Tuple[int, int] = (4, 4)):
"""
初始化数据管理器
Args:
data_dir: 数据存储目录
cache_size: 内存缓存大小
board_size: 棋盘大小
"""
self.cache = TrainingDataCache(cache_size)
self.persistence = TrainingDataPersistence(data_dir)
self.board_size = board_size
self.transform = BoardTransform()
self.score_calc = ScoreCalculator()
def add_training_example(self, board_state: np.ndarray,
action: int, value: float) -> str:
"""
添加训练样本
Args:
board_state: 原始棋盘状态
action: 动作
value: 价值
Returns:
样本的哈希键
"""
# 对数变换
log_board = self.transform.log_transform(board_state)
# 计算规范哈希
canonical_hash = self.transform.compute_hash(log_board)
# 创建训练样本
example = TrainingExample(
board_state=log_board,
action=action,
value=value,
canonical_hash=canonical_hash
)
# 构造缓存键 (状态哈希 + 动作)
cache_key = f"{canonical_hash}_{action}"
# 更新缓存(如果新价值更高)
self.cache.update_if_better(cache_key, example)
return cache_key
def get_pytorch_dataset(self, filter_min_value: float = 0.0) -> Game2048Dataset:
"""
获取PyTorch数据集
Args:
filter_min_value: 最小价值过滤阈值
Returns:
PyTorch数据集
"""
examples = [ex for ex in self.cache.get_all_examples()
if ex.value >= filter_min_value]
return Game2048Dataset(examples, self.board_size)
def get_dataloader(self, batch_size: int = 32, shuffle: bool = True,
filter_min_value: float = 0.0, **kwargs) -> DataLoader:
"""
获取PyTorch DataLoader
Args:
batch_size: 批次大小
shuffle: 是否打乱数据
filter_min_value: 最小价值过滤阈值
**kwargs: 其他DataLoader参数
Returns:
PyTorch DataLoader
"""
dataset = self.get_pytorch_dataset(filter_min_value)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
def save_current_cache(self, filename: str) -> None:
"""保存当前缓存到硬盘"""
self.persistence.save_cache(self.cache, filename)
def load_from_file(self, filename: str) -> int:
"""
从文件加载训练数据到缓存
Args:
filename: 文件名
Returns:
加载的样本数量
"""
examples = self.persistence.load_cache(filename)
loaded_count = 0
for example in examples:
cache_key = f"{example.canonical_hash}_{example.action}"
if self.cache.update_if_better(cache_key, example):
loaded_count += 1
logger.info(f"Loaded {loaded_count} examples into cache")
return loaded_count
def get_cache_stats(self) -> Dict[str, int]:
"""获取缓存统计信息"""
return {
"cache_size": self.cache.size(),
"max_cache_size": self.cache.max_size,
"saved_files": len(self.persistence.list_saved_files())
}
def merge_caches(self, other_manager: 'TrainingDataManager') -> int:
"""
合并另一个数据管理器的缓存
Args:
other_manager: 另一个数据管理器
Returns:
合并的样本数量
"""
merged_count = 0
for example in other_manager.cache.get_all_examples():
cache_key = f"{example.canonical_hash}_{example.action}"
if self.cache.update_if_better(cache_key, example):
merged_count += 1
logger.info(f"Merged {merged_count} examples from other cache")
return merged_count