增加L0训练阶段的MCTS部分
This commit is contained in:
575
training_data.py
Normal file
575
training_data.py
Normal file
@@ -0,0 +1,575 @@
|
||||
"""
|
||||
训练数据结构模块
|
||||
|
||||
实现2048游戏的训练数据结构,包括:
|
||||
1. 棋盘状态的对数变换
|
||||
2. 二面体群D4的8种变换(棋盘压缩)
|
||||
3. 训练数据的内存缓存和硬盘持久化
|
||||
4. 与PyTorch生态的集成
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import pickle
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Tuple, List, Dict, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# Configure logging.
# NOTE(review): calling basicConfig at import time configures the root
# logger as a module side effect; for library use, consider leaving log
# configuration to the application entry point — confirm intended usage.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
||||
|
||||
|
||||
@dataclass
class TrainingExample:
    """Data structure for a single training sample."""
    board_state: np.ndarray  # board state (H, W); stored log-transformed by the manager
    action: int  # move direction (0: up, 1: down, 2: left, 3: right)
    value: float  # value of this state-action pair
    canonical_hash: str  # hash of the board's canonical (D4-minimal) form
||||
|
||||
|
||||
class BoardTransform:
    """Board transform utilities: log encoding plus the 8 symmetries of the
    dihedral group D4 (rotations and reflections of the square board)."""

    @staticmethod
    def log_transform(board: np.ndarray) -> np.ndarray:
        """Log-encode a raw board.

        Args:
            board: raw board holding powers of two (0 marks an empty cell)

        Returns:
            int32 board with empty cells as 0 and log2(value) elsewhere
        """
        result = np.zeros_like(board, dtype=np.int32)
        occupied = board > 0
        result[occupied] = np.log2(board[occupied]).astype(np.int32)
        return result

    @staticmethod
    def inverse_log_transform(log_board: np.ndarray) -> np.ndarray:
        """Invert :meth:`log_transform`.

        Args:
            log_board: log-encoded board

        Returns:
            raw board with 2**v in occupied cells and 0 in empty ones
        """
        result = np.zeros_like(log_board, dtype=np.int32)
        occupied = log_board > 0
        result[occupied] = (2 ** log_board[occupied]).astype(np.int32)
        return result

    @staticmethod
    def rotate_90(matrix: np.ndarray) -> np.ndarray:
        """Rotate 90 degrees clockwise."""
        return np.rot90(matrix, k=-1)

    @staticmethod
    def flip_horizontal(matrix: np.ndarray) -> np.ndarray:
        """Mirror left-right."""
        return np.fliplr(matrix)

    @classmethod
    def get_all_transforms(cls, matrix: np.ndarray) -> List[np.ndarray]:
        """Return all 8 dihedral-group (D4) transforms of *matrix*.

        Order is fixed — R0, R90, R180, R270, F, F+R90, F+R180, F+R270 —
        so indices returned by :meth:`get_canonical_form` stay stable.

        Args:
            matrix: input matrix

        Returns:
            list with the 8 transformed matrices
        """
        transforms = [matrix.copy()]
        # Three successive clockwise rotations of the original (R90..R270).
        for _ in range(3):
            transforms.append(cls.rotate_90(transforms[-1]))
        # Mirror image followed by its three rotations (F, F+R90..F+R270).
        transforms.append(cls.flip_horizontal(transforms[0]))
        for _ in range(3):
            transforms.append(cls.rotate_90(transforms[-1]))
        return transforms

    @classmethod
    def get_canonical_form(cls, matrix: np.ndarray) -> Tuple[np.ndarray, int]:
        """Return the lexicographically smallest D4 transform of *matrix*.

        Args:
            matrix: input matrix

        Returns:
            (canonical matrix, index of the chosen transform)
        """
        transforms = cls.get_all_transforms(matrix)
        # Compare the row-major flattening of each transform; min() keeps the
        # first minimum, matching the original first-wins tie-breaking.
        keys = [t.flatten().tolist() for t in transforms]
        min_idx = min(range(len(transforms)), key=keys.__getitem__)
        return transforms[min_idx], min_idx

    @staticmethod
    def _is_lexicographically_smaller(a: np.ndarray, b: np.ndarray) -> bool:
        """True when *a* sorts strictly before *b* lexicographically.

        Python list comparison implements the element-wise rule, including
        "shorter array with an equal prefix is smaller". Boards here hold
        integers, for which this matches the original manual loop exactly.
        """
        return a.tolist() < b.tolist()

    @classmethod
    def compute_hash(cls, matrix: np.ndarray) -> str:
        """Hash the canonical form of *matrix*.

        All 8 symmetric variants of a board therefore map to the same key.
        MD5 is used as a fast fingerprint, not for security.

        Args:
            matrix: input matrix

        Returns:
            hex digest string
        """
        canonical, _ = cls.get_canonical_form(matrix)
        # Hash the C-order byte representation of the canonical form.
        return hashlib.md5(canonical.tobytes()).hexdigest()
|
||||
|
||||
|
||||
class ScoreCalculator:
    """Utilities for reconstructing the 2048 cumulative score from a board."""

    @staticmethod
    def calculate_tile_value(tile_log: int) -> int:
        """Cumulative score contributed by a single tile.

        A tile of value N = 2**tile_log accumulates (log2(N) - 1) * N score
        over the merges that created it.

        Args:
            tile_log: log2 of the tile value

        Returns:
            the cumulative score value
        """
        if tile_log <= 1:
            # Tiles of value 2 (or empty cells) never earned merge score.
            return 0
        tile_value = 2 ** tile_log
        return (tile_log - 1) * tile_value

    @classmethod
    def calculate_board_score(cls, log_board: np.ndarray) -> int:
        """Total cumulative score of an entire log-encoded board.

        Args:
            log_board: log-transformed board

        Returns:
            sum of per-tile cumulative scores over all occupied cells
        """
        return sum(
            cls.calculate_tile_value(tile_log)
            for tile_log in log_board.flatten()
            if tile_log > 0
        )
|
||||
|
||||
|
||||
class TrainingDataCache:
    """In-memory LRU cache of training examples.

    Replaces the original list-based LRU bookkeeping (an O(n)
    ``access_order.remove(key)`` on every hit) with the insertion order of
    a plain dict (guaranteed since Python 3.7): the first key is always
    the least recently used entry, giving O(1) get/put/evict.
    """

    def __init__(self, max_size: int = 1000000):
        """
        Args:
            max_size: maximum number of cached entries before LRU eviction
        """
        self.max_size = max_size
        # Keys ordered from least to most recently used.
        self.cache: Dict[str, "TrainingExample"] = {}

    def get(self, key: str) -> Optional["TrainingExample"]:
        """Return the example for *key* (marking it most recently used), or None."""
        if key not in self.cache:
            return None
        # Re-insert to move the entry to the most-recently-used position.
        example = self.cache.pop(key)
        self.cache[key] = example
        return example

    def put(self, key: str, example: "TrainingExample") -> None:
        """Insert or overwrite *key*, evicting the LRU entry when full."""
        if key in self.cache:
            # Remove first so the re-insert lands at the MRU position.
            del self.cache[key]
        elif len(self.cache) >= self.max_size:
            # Evict the least recently used entry (first key in order).
            del self.cache[next(iter(self.cache))]
        self.cache[key] = example

    def update_if_better(self, key: str, example: "TrainingExample") -> bool:
        """Store *example* only when new or higher-valued; report whether stored."""
        existing = self.get(key)
        if existing is None or example.value > existing.value:
            self.put(key, example)
            return True
        return False

    def size(self) -> int:
        """Number of cached entries."""
        return len(self.cache)

    def clear(self) -> None:
        """Drop every cached entry."""
        self.cache.clear()

    def get_all_examples(self) -> List["TrainingExample"]:
        """All cached training examples (iteration order is not part of the API)."""
        return list(self.cache.values())
|
||||
|
||||
|
||||
class TrainingDataPersistence:
    """On-disk persistence of training data as pickle files under *data_dir*."""

    def __init__(self, data_dir: str = "data/training"):
        """
        Args:
            data_dir: storage directory (created if missing)
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def save_cache(self, cache: "TrainingDataCache", filename: str) -> None:
        """Persist every example currently held by *cache*.

        Args:
            cache: cache to snapshot
            filename: target file name (".pkl" is appended)
        """
        # BUG FIX: the path previously used a hard-coded "(unknown).pkl"
        # instead of the filename argument, so every save clobbered one file
        # and loads by name could never find it.
        filepath = self.data_dir / f"{filename}.pkl"
        examples = cache.get_all_examples()

        with open(filepath, 'wb') as f:
            pickle.dump(examples, f)

        logger.info(f"Saved {len(examples)} examples to {filepath}")

    def load_cache(self, filename: str) -> List["TrainingExample"]:
        """Load previously saved training examples.

        Args:
            filename: file name without the ".pkl" suffix

        Returns:
            the stored example list, or [] when the file does not exist

        NOTE(review): pickle.load executes arbitrary code from the file;
        only load files this process wrote itself.
        """
        # BUG FIX: same hard-coded "(unknown).pkl" path as save_cache.
        filepath = self.data_dir / f"{filename}.pkl"

        if not filepath.exists():
            logger.warning(f"File {filepath} does not exist")
            return []

        with open(filepath, 'rb') as f:
            examples = pickle.load(f)

        logger.info(f"Loaded {len(examples)} examples from {filepath}")
        return examples

    def save_examples_batch(self, examples: List["TrainingExample"],
                            batch_name: str) -> None:
        """Persist a list of training examples under *batch_name*.

        Args:
            examples: training examples to store
            batch_name: batch file name (".pkl" is appended)
        """
        # The previous implementation first serialized an empty throwaway
        # TrainingDataCache before writing the real list; that redundant
        # write is removed.
        filepath = self.data_dir / f"{batch_name}.pkl"
        with open(filepath, 'wb') as f:
            pickle.dump(examples, f)

        logger.info(f"Saved batch {batch_name} with {len(examples)} examples")

    def list_saved_files(self) -> List[str]:
        """Return the stem names of all saved .pkl files."""
        return [f.stem for f in self.data_dir.glob("*.pkl")]
|
||||
|
||||
|
||||
class Game2048Dataset(Dataset):
    """PyTorch Dataset exposing 2048 training examples as tensors."""

    def __init__(self, examples: List["TrainingExample"],
                 board_size: Tuple[int, int] = (4, 4),
                 max_tile_value: int = 17):
        """
        Args:
            examples: training examples (board_state holds log-encoded tiles)
            board_size: board dimensions (height, width)
            max_tile_value: largest encoded tile exponent (log2 of max tile)
        """
        self.examples = examples
        self.board_size = board_size
        self.max_tile_value = max_tile_value

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return one training sample.

        Args:
            idx: sample index

        Returns:
            (board_tensor, action_tensor, value_tensor)
        """
        example = self.examples[idx]
        # One-hot encode the board state.
        board_tensor = self._encode_board(example.board_state)
        # Action label (long) and value label (float32).
        action_tensor = torch.tensor(example.action, dtype=torch.long)
        value_tensor = torch.tensor(example.value, dtype=torch.float32)
        return board_tensor, action_tensor, value_tensor

    def _encode_board(self, board: np.ndarray) -> torch.Tensor:
        """One-hot encode a log-valued board.

        Channel c is 1 exactly where the board holds value c (channel 0 =
        empty cell). Vectorized replacement for the original per-cell
        double loop; cells whose value falls outside [0, max_tile_value]
        are left all-zero, matching the previous behavior.

        Args:
            board: log-transformed board state

        Returns:
            float32 tensor of shape (max_tile_value + 1, height, width)
        """
        channels = self.max_tile_value + 1
        height, width = self.board_size
        encoded = torch.zeros(channels, height, width, dtype=torch.float32)

        # Clip to the configured board area (the original indexed exactly
        # this region) and scatter 1.0 at (value, row, col) for valid cells.
        tiles = torch.as_tensor(np.asarray(board)[:height, :width], dtype=torch.long)
        valid = (tiles >= 0) & (tiles <= self.max_tile_value)
        rows, cols = valid.nonzero(as_tuple=True)
        encoded[tiles[rows, cols], rows, cols] = 1.0
        return encoded
|
||||
|
||||
|
||||
class TrainingDataManager:
    """Facade tying together the LRU cache, disk persistence, board
    transforms and the PyTorch dataset/dataloader views."""

    def __init__(self, data_dir: str = "data/training",
                 cache_size: int = 1000000,
                 board_size: Tuple[int, int] = (4, 4)):
        """
        Args:
            data_dir: directory used for persisted data
            cache_size: capacity of the in-memory cache
            board_size: board dimensions (height, width)
        """
        self.cache = TrainingDataCache(cache_size)
        self.persistence = TrainingDataPersistence(data_dir)
        self.board_size = board_size
        self.transform = BoardTransform()
        self.score_calc = ScoreCalculator()

    def add_training_example(self, board_state: np.ndarray,
                             action: int, value: float) -> str:
        """Register one (state, action, value) sample.

        The board is log-encoded and keyed by its canonical (D4-minimal)
        hash plus the action, so symmetric positions share one cache slot.

        Args:
            board_state: raw board state
            action: move taken
            value: value of the state-action pair

        Returns:
            the cache key the sample was stored under
        """
        log_board = self.transform.log_transform(board_state)
        canonical_hash = self.transform.compute_hash(log_board)

        sample = TrainingExample(
            board_state=log_board,
            action=action,
            value=value,
            canonical_hash=canonical_hash,
        )

        cache_key = f"{canonical_hash}_{action}"
        # Only kept when it beats the value already cached for this key.
        self.cache.update_if_better(cache_key, sample)
        return cache_key

    def get_pytorch_dataset(self, filter_min_value: float = 0.0) -> Game2048Dataset:
        """Build a dataset from cached samples with value >= *filter_min_value*.

        Args:
            filter_min_value: minimum value a sample needs to be included

        Returns:
            a Game2048Dataset over the selected samples
        """
        selected = [
            sample
            for sample in self.cache.get_all_examples()
            if sample.value >= filter_min_value
        ]
        return Game2048Dataset(selected, self.board_size)

    def get_dataloader(self, batch_size: int = 32, shuffle: bool = True,
                       filter_min_value: float = 0.0, **kwargs) -> DataLoader:
        """Build a DataLoader over :meth:`get_pytorch_dataset`.

        Args:
            batch_size: samples per batch
            shuffle: whether to reshuffle each epoch
            filter_min_value: minimum sample value to include
            **kwargs: forwarded to DataLoader

        Returns:
            a PyTorch DataLoader
        """
        return DataLoader(self.get_pytorch_dataset(filter_min_value),
                          batch_size=batch_size, shuffle=shuffle, **kwargs)

    def save_current_cache(self, filename: str) -> None:
        """Persist the current cache contents under *filename*."""
        self.persistence.save_cache(self.cache, filename)

    def load_from_file(self, filename: str) -> int:
        """Load persisted samples into the cache.

        Args:
            filename: file name to load from

        Returns:
            how many samples were accepted by update_if_better
        """
        loaded_count = sum(
            1
            for sample in self.persistence.load_cache(filename)
            if self.cache.update_if_better(
                f"{sample.canonical_hash}_{sample.action}", sample)
        )
        logger.info(f"Loaded {loaded_count} examples into cache")
        return loaded_count

    def get_cache_stats(self) -> Dict[str, int]:
        """Basic cache and persistence statistics."""
        return {
            "cache_size": self.cache.size(),
            "max_cache_size": self.cache.max_size,
            "saved_files": len(self.persistence.list_saved_files()),
        }

    def merge_caches(self, other_manager: 'TrainingDataManager') -> int:
        """Merge another manager's cached samples into this one.

        Args:
            other_manager: the manager whose cache is merged in

        Returns:
            how many samples were accepted by update_if_better
        """
        merged_count = sum(
            1
            for sample in other_manager.cache.get_all_examples()
            if self.cache.update_if_better(
                f"{sample.canonical_hash}_{sample.action}", sample)
        )
        logger.info(f"Merged {merged_count} examples from other cache")
        return merged_count
|
||||
Reference in New Issue
Block a user