589 lines
20 KiB
Python
589 lines
20 KiB
Python
"""
|
||
L0 阶段的纯蒙特卡洛树搜索数据生成
|
||
|
||
根据论文要求,在3x3小棋盘上使用纯MCTS生成高质量的初始训练数据。
|
||
特点:
|
||
1. 使用3x3棋盘快速收敛
|
||
2. 大量MCTS模拟生成高质量数据
|
||
3. 多阶段数据生成和保存
|
||
4. 支持断点续传和增量训练
|
||
5. 自动数据质量评估和过滤
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import json
|
||
import argparse
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Dict, List, Tuple, Optional
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import torch
|
||
|
||
from game import Game2048
|
||
from torch_mcts import TorchMCTS
|
||
from training_data import TrainingDataManager, TrainingExample
|
||
|
||
|
||
class L0DataGenerator:
    """Training-data generator for the L0 stage.

    Plays games of ``Game2048`` with ``TorchMCTS`` choosing every move, keeps
    running statistics in ``self.stats`` and persists checkpoints / reports
    as JSON. Training examples themselves are stored through the
    ``TrainingDataManager`` that is handed to the MCTS instance.
    """

    def __init__(self, config: Dict):
        """
        Initialise the generator.

        Args:
            config: Configuration dictionary. Keys read here: ``data_dir``,
                ``log_dir``, ``checkpoint_dir``, ``cache_size``,
                ``board_height``, ``board_width``, ``mcts_c_param``,
                ``max_simulation_depth`` and, optionally, ``batch_size``
                and ``device``.
        """
        self.config = config
        self.setup_logging()
        self.setup_directories()

        # Training-data manager shared with the MCTS instance below.
        self.training_manager = TrainingDataManager(
            data_dir=self.config['data_dir'],
            cache_size=self.config['cache_size'],
            board_size=(self.config['board_height'], self.config['board_width'])
        )

        # MCTS searcher.
        self.mcts = TorchMCTS(
            c_param=self.config['mcts_c_param'],
            max_simulation_depth=self.config['max_simulation_depth'],
            batch_size=self.config.get('batch_size', None),
            device=self.config.get('device', 'auto'),
            training_manager=self.training_manager
        )

        # Running statistics; written into every checkpoint.
        self.stats = {
            'total_games': 0,
            'total_moves': 0,
            'total_simulations': 0,
            'total_training_examples': 0,
            'start_time': time.time(),
            'stage_start_time': time.time(),
            'best_scores': [],
            'average_scores': [],
            'stage_stats': []
        }

        self.logger.info("L0数据生成器初始化完成")
        # ensure_ascii=False keeps the Chinese stage descriptions readable in
        # the log; _to_native guards against numpy values inside the config.
        config_text = json.dumps(self._to_native(config), indent=2, ensure_ascii=False)
        self.logger.info(f"配置: {config_text}")

    @staticmethod
    def _to_native(obj):
        """Recursively convert numpy / torch values into JSON-serialisable
        Python objects.

        Containers are checked before the generic ``.item()`` fallback:
        ``np.ndarray`` also exposes ``item()``, which raises ``ValueError``
        for arrays holding more than one element.
        """
        convert = L0DataGenerator._to_native
        if isinstance(obj, dict):
            return {key: convert(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert(value) for value in obj]
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        # Fallback kept from the original helper: any other scalar-like
        # object that exposes item().
        if callable(getattr(obj, 'item', None)):
            return obj.item()
        return obj

    def setup_logging(self):
        """Create the log directory and configure file + console logging.

        Note: ``logging.basicConfig`` does nothing when the root logger
        already has handlers configured.
        """
        log_dir = Path(self.config['log_dir'])
        log_dir.mkdir(parents=True, exist_ok=True)

        log_file = log_dir / f"l0_generation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                # Log messages are Chinese; do not depend on the locale's
                # default encoding for the log file.
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )

        self.logger = logging.getLogger(__name__)

    def setup_directories(self):
        """Create the data, log and checkpoint directories if missing."""
        for dir_key in ('data_dir', 'log_dir', 'checkpoint_dir'):
            Path(self.config[dir_key]).mkdir(parents=True, exist_ok=True)

    def save_checkpoint(self, stage: int, batch: int):
        """Write a JSON checkpoint and save the current training cache.

        Args:
            stage: Stage index recorded in the checkpoint and file name.
            batch: Batch index recorded in the checkpoint and file name.
        """
        checkpoint = {
            'stage': int(stage),
            'batch': int(batch),
            'stats': self._to_native(self.stats),
            'config': self._to_native(self.config),
            'timestamp': datetime.now().isoformat()
        }

        checkpoint_file = Path(self.config['checkpoint_dir']) / f"checkpoint_stage{stage}_batch{batch}.json"

        with open(checkpoint_file, 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f, indent=2, ensure_ascii=False)

        # Save the training data under the matching name.
        data_file = f"l0_stage{stage}_batch{batch}"
        self.training_manager.save_current_cache(data_file)

        self.logger.info(f"检查点已保存: {checkpoint_file}")

    def load_checkpoint(self, checkpoint_file: str) -> Tuple[int, int]:
        """Restore statistics and training data from a checkpoint.

        Args:
            checkpoint_file: Path of a JSON file written by
                ``save_checkpoint``.

        Returns:
            The ``(stage, batch)`` pair stored in the checkpoint.
        """
        with open(checkpoint_file, 'r', encoding='utf-8') as f:
            checkpoint = json.load(f)

        # NOTE(review): this also restores the previous run's 'start_time',
        # so elapsed-time figures in the final report span the downtime.
        self.stats = checkpoint['stats']
        stage = checkpoint['stage']
        batch = checkpoint['batch']

        # Load the training data saved next to the checkpoint.
        data_file = f"l0_stage{stage}_batch{batch}"
        self.training_manager.load_from_file(data_file)

        self.logger.info(f"检查点已加载: {checkpoint_file}")
        return stage, batch

    def play_single_game(self, game_id: int, simulations_per_move: int) -> Dict:
        """
        Play one game, letting MCTS pick every move.

        Args:
            game_id: Game id; added to ``base_seed`` to seed the game.
            simulations_per_move: Simulation count passed to each search.

        Returns:
            Per-game statistics: ``game_id``, ``moves``, ``final_score``,
            ``max_tile``, ``simulations``, ``search_times``, ``move_scores``.
        """
        game = Game2048(
            height=self.config['board_height'],
            width=self.config['board_width'],
            seed=self.config['base_seed'] + game_id
        )

        game_stats = {
            'game_id': game_id,
            'moves': 0,
            'final_score': 0,
            'max_tile': 0,
            'simulations': 0,
            'search_times': [],
            'move_scores': []
        }

        move_count = 0
        max_moves = self.config['max_moves_per_game']
        # Tolerate configuration files that omit the optional 'verbose' flag.
        verbose = self.config.get('verbose', False)

        while not game.is_over and move_count < max_moves:
            # MCTS search (timed).
            start_time = time.time()
            best_action, _root = self.mcts.search(game, simulations_per_move)
            search_time = time.time() - start_time

            # -1 ends the game loop (presumably the "no action available"
            # sentinel of TorchMCTS.search -- confirm against its API).
            if best_action == -1:
                break

            # Apply the move; a rejected move also ends the game loop.
            old_score = game.score
            if not game.move(best_action):
                break

            move_count += 1
            score_gain = game.score - old_score

            game_stats['search_times'].append(search_time)
            game_stats['move_scores'].append(score_gain)
            game_stats['simulations'] += simulations_per_move

            # Optional detailed progress, every 10th move.
            if verbose and move_count % 10 == 0:
                self.logger.info(f"游戏{game_id} 第{move_count}步: "
                                 f"动作={best_action}, 分数={game.score}, "
                                 f"最大数字={game.get_max_tile()}")

        # Final per-game statistics.
        game_stats['moves'] = move_count
        game_stats['final_score'] = game.score
        game_stats['max_tile'] = game.get_max_tile()

        return game_stats

    def generate_batch_data(self, stage: int, batch: int,
                            games_per_batch: int, simulations_per_move: int) -> Dict:
        """
        Play one batch of games and fold the results into ``self.stats``.

        Args:
            stage: Stage index (used for logging and the returned stats).
            batch: Batch index (used for logging and the returned stats).
            games_per_batch: Number of games to play.
            simulations_per_move: Simulation count per move.

        Returns:
            Batch statistics, including the per-game stats under ``games``.
        """
        self.logger.info(f"开始生成阶段{stage}批次{batch}数据 "
                         f"({games_per_batch}局游戏, {simulations_per_move}模拟/步)")

        batch_start_time = time.time()
        batch_stats = {
            'stage': stage,
            'batch': batch,
            'games': [],
            'total_games': games_per_batch,
            'total_moves': 0,
            'total_simulations': 0,
            'avg_score': 0,
            'max_score': 0,
            'avg_moves': 0,
            'generation_time': 0
        }

        # Log progress roughly every 10% of the batch.
        report_every = max(1, games_per_batch // 10)

        for game_id in range(games_per_batch):
            # Offset by the games already played so every game id is unique.
            global_game_id = self.stats['total_games'] + game_id
            game_stats = self.play_single_game(global_game_id, simulations_per_move)

            batch_stats['games'].append(game_stats)
            batch_stats['total_moves'] += game_stats['moves']
            batch_stats['total_simulations'] += game_stats['simulations']

            if (game_id + 1) % report_every == 0:
                progress = (game_id + 1) / games_per_batch * 100
                self.logger.info(f"批次进度: {progress:.1f}% "
                                 f"({game_id + 1}/{games_per_batch})")

        # Batch aggregates.
        scores = [g['final_score'] for g in batch_stats['games']]
        moves = [g['moves'] for g in batch_stats['games']]

        batch_stats['avg_score'] = sum(scores) / len(scores) if scores else 0
        batch_stats['max_score'] = max(scores) if scores else 0
        batch_stats['avg_moves'] = sum(moves) / len(moves) if moves else 0
        batch_stats['generation_time'] = time.time() - batch_start_time

        # Global aggregates. 'best_scores' receives every final score of the
        # batch, not only the maximum.
        self.stats['total_games'] += games_per_batch
        self.stats['total_moves'] += batch_stats['total_moves']
        self.stats['total_simulations'] += batch_stats['total_simulations']
        self.stats['best_scores'].extend(scores)
        self.stats['average_scores'].append(batch_stats['avg_score'])

        self.logger.info(f"批次{batch}完成: 平均分数={batch_stats['avg_score']:.1f}, "
                         f"最高分数={batch_stats['max_score']}, "
                         f"平均步数={batch_stats['avg_moves']:.1f}, "
                         f"用时={batch_stats['generation_time']:.1f}秒")

        return batch_stats

    def run_stage(self, stage: int, stage_config: Dict) -> Dict:
        """
        Run every batch of one stage.

        Args:
            stage: Stage index.
            stage_config: Stage settings: ``description``, ``num_batches``,
                ``games_per_batch``, ``simulations_per_move``.

        Returns:
            Stage statistics (also appended to ``self.stats['stage_stats']``).
        """
        self.logger.info(f"开始阶段{stage}: {stage_config['description']}")

        stage_start_time = time.time()
        stage_stats = {
            'stage': stage,
            'description': stage_config['description'],
            'batches': [],
            'total_games': 0,
            'total_moves': 0,
            'total_simulations': 0,
            'stage_time': 0,
            'data_quality': {}
        }

        for batch in range(stage_config['num_batches']):
            batch_stats = self.generate_batch_data(
                stage=stage,
                batch=batch,
                games_per_batch=stage_config['games_per_batch'],
                simulations_per_move=stage_config['simulations_per_move']
            )

            stage_stats['batches'].append(batch_stats)
            stage_stats['total_games'] += batch_stats['total_games']
            stage_stats['total_moves'] += batch_stats['total_moves']
            stage_stats['total_simulations'] += batch_stats['total_simulations']

            # Periodic checkpoint.
            if (batch + 1) % self.config['checkpoint_interval'] == 0:
                self.save_checkpoint(stage, batch)

            # Periodic data-quality evaluation (logged only).
            if (batch + 1) % self.config['quality_check_interval'] == 0:
                quality_stats = self.evaluate_data_quality()
                self.logger.info(f"数据质量评估: {quality_stats}")

        stage_stats['stage_time'] = time.time() - stage_start_time

        # Save the data of the completed stage.
        stage_data_file = f"l0_stage{stage}_complete"
        self.training_manager.save_current_cache(stage_data_file)

        # Record the stage statistics.
        self.stats['stage_stats'].append(stage_stats)

        self.logger.info(f"阶段{stage}完成: "
                         f"游戏数={stage_stats['total_games']}, "
                         f"移动数={stage_stats['total_moves']}, "
                         f"模拟数={stage_stats['total_simulations']}, "
                         f"用时={stage_stats['stage_time']:.1f}秒")

        return stage_stats

    def evaluate_data_quality(self) -> Dict:
        """Summarise the examples currently held in the training cache.

        Returns:
            ``{'error': ...}`` when the cache reports size 0; otherwise the
            example count, value statistics, the count of actions 0-3 and
            the number of distinct ``canonical_hash`` values.
        """
        cache_stats = self.training_manager.get_cache_stats()

        if cache_stats['cache_size'] == 0:
            return {'error': 'No data to evaluate'}

        # All cached training examples.
        examples = self.training_manager.cache.get_all_examples()

        values = [ex.value for ex in examples]
        actions = [ex.action for ex in examples]

        quality_stats = {
            'total_examples': len(examples),
            'value_stats': {
                'mean': float(sum(values) / len(values)) if values else 0,
                'min': float(min(values)) if values else 0,
                'max': float(max(values)) if values else 0,
                # std needs at least two samples.
                'std': float(torch.tensor(values).std().item()) if len(values) > 1 else 0
            },
            'action_distribution': {
                i: actions.count(i) for i in range(4)
            },
            'unique_states': len(set(ex.canonical_hash for ex in examples))
        }

        return quality_stats

    def generate_final_report(self) -> Dict:
        """Build the final report, write it to ``log_dir`` as JSON and
        return it.

        Returns:
            Report with ``generation_summary``, ``data_quality``,
            ``stage_summary`` and ``config`` sections, already converted to
            plain Python types.
        """
        total_time = time.time() - self.stats['start_time']

        report = {
            'generation_summary': {
                'total_time': total_time,
                'total_games': self.stats['total_games'],
                'total_moves': self.stats['total_moves'],
                'total_simulations': self.stats['total_simulations'],
                'games_per_hour': self.stats['total_games'] / (total_time / 3600) if total_time > 0 else 0,
                'simulations_per_second': self.stats['total_simulations'] / total_time if total_time > 0 else 0
            },
            'data_quality': self.evaluate_data_quality(),
            'stage_summary': self.stats['stage_stats'],
            'config': self.config
        }

        # Make the whole report JSON-serialisable.
        report = self._to_native(report)

        report_file = Path(self.config['log_dir']) / f"l0_final_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)

        self.logger.info(f"最终报告已保存: {report_file}")
        return report
|
||
|
||
|
||
def get_default_config() -> Dict:
    """Return a fresh copy of the default configuration.

    A new dictionary is built on every call, so callers may mutate the
    result freely.

    Returns:
        Configuration dictionary with game, MCTS, data-management, stage and
        control settings.
    """
    return {
        # Game settings
        'board_height': 3,
        'board_width': 3,
        'base_seed': 42,
        'max_moves_per_game': 100,

        # MCTS settings
        'mcts_c_param': 1.414,
        'max_simulation_depth': 80,
        'num_threads': 1,
        # Same values L0DataGenerator.__init__ falls back to when these keys
        # are absent; listed here so every key it reads is discoverable.
        'batch_size': None,
        'device': 'auto',

        # Data management
        'cache_size': 50000,
        'data_dir': 'data/l0_training',
        'log_dir': 'logs/l0_generation',
        'checkpoint_dir': 'checkpoints/l0',

        # Stage configuration
        'stages': [
            {
                'description': '初始探索阶段 - 少量模拟快速生成基础数据',
                'num_batches': 5,
                'games_per_batch': 20,
                'simulations_per_move': 50
            },
            {
                'description': '深度搜索阶段 - 中等模拟生成质量数据',
                'num_batches': 10,
                'games_per_batch': 30,
                'simulations_per_move': 100
            },
            {
                'description': '精细优化阶段 - 大量模拟生成高质量数据',
                'num_batches': 15,
                'games_per_batch': 40,
                'simulations_per_move': 200
            },
            {
                'description': '最终收集阶段 - 超大模拟生成顶级数据',
                'num_batches': 10,
                'games_per_batch': 50,
                'simulations_per_move': 300
            }
        ],

        # Control settings
        'checkpoint_interval': 2,  # save a checkpoint every 2 batches
        'quality_check_interval': 5,  # evaluate data quality every 5 batches
        'verbose': True
    }
|
||
|
||
|
||
def run_l0_generation(config: Dict = None, resume_from: str = None):
    """
    Run the L0 training-data generation.

    Args:
        config: Configuration dictionary; ``None`` selects the defaults from
            ``get_default_config()``.
        resume_from: Path of a checkpoint file to resume from.

    Returns:
        The final report dictionary, or ``None`` when the run was interrupted
        with Ctrl-C (an emergency checkpoint is saved in that case).

    Raises:
        Exception: Any error raised during generation is logged and
            re-raised.
    """
    if config is None:
        config = get_default_config()

    # Create the data generator.
    generator = L0DataGenerator(config)

    start_stage = 0
    start_batch = 0

    # Resume from a checkpoint.
    if resume_from:
        if Path(resume_from).exists():
            start_stage, start_batch = generator.load_checkpoint(resume_from)
            generator.logger.info(f"从检查点恢复: 阶段{start_stage}, 批次{start_batch}")
        else:
            # Do not silently start from scratch when the caller explicitly
            # asked to resume.
            generator.logger.warning(f"检查点文件不存在: {resume_from},将从头开始生成")

    try:
        # Run the stages. Stages before start_stage are skipped.
        # NOTE: run_stage() is called without a batch offset, so the resumed
        # stage starts again from its first batch; start_batch is only logged.
        for stage_idx, stage_config in enumerate(config['stages']):
            if stage_idx < start_stage:
                continue

            generator.run_stage(stage_idx, stage_config)

        # Final report.
        final_report = generator.generate_final_report()
        summary = final_report['generation_summary']
        # data_quality is {'error': ...} when the cache is empty, so the
        # example count may be absent.
        total_examples = final_report['data_quality'].get('total_examples', 0)

        generator.logger.info("L0训练数据生成完成!")
        generator.logger.info(f"总游戏数: {summary['total_games']}")
        generator.logger.info(f"总训练样本: {total_examples}")
        generator.logger.info(f"生成速度: {summary['games_per_hour']:.1f} 游戏/小时")

        return final_report

    except KeyboardInterrupt:
        generator.logger.info("用户中断,保存当前进度...")
        # Emergency checkpoint; (-1, -1) is the special marker, and
        # save_checkpoint logs the resulting file path.
        generator.save_checkpoint(-1, -1)
        generator.logger.info("紧急检查点已保存")
        return None

    except Exception as e:
        # logger.exception keeps the traceback in the log file.
        generator.logger.exception(f"生成过程中出现错误: {e}")
        raise
|
||
|
||
|
||
def main():
    """Command-line entry point.

    Options: ``--config`` (JSON configuration file), ``--resume``
    (checkpoint file), ``--quick`` (single small test stage written to
    separate test directories) and ``--stages`` (run only the first N
    stages).
    """
    import argparse

    parser = argparse.ArgumentParser(description='L0阶段MCTS训练数据生成')
    parser.add_argument('--config', type=str, help='配置文件路径')
    parser.add_argument('--resume', type=str, help='检查点文件路径')
    parser.add_argument('--quick', action='store_true', help='快速测试模式')
    parser.add_argument('--stages', type=int, default=None, help='运行的阶段数')

    args = parser.parse_args()

    # Load the configuration: start from the defaults and overlay the user's
    # file, so a partial configuration file still provides every key.
    config = get_default_config()
    if args.config:
        if Path(args.config).exists():
            with open(args.config, 'r', encoding='utf-8-sig') as f:
                config.update(json.load(f))
        else:
            # Keep the fall-back to defaults, but make it visible.
            print(f"警告: 配置文件不存在: {args.config},使用默认配置")

    # Quick test mode.
    if args.quick:
        config['stages'] = [
            {
                'description': '快速测试阶段',
                'num_batches': 2,
                'games_per_batch': 5,
                'simulations_per_move': 20
            }
        ]
        config['data_dir'] = 'data/l0_test'
        config['log_dir'] = 'logs/l0_test'
        config['checkpoint_dir'] = 'checkpoints/l0_test'

    # Limit the number of stages.
    if args.stages is not None:
        config['stages'] = config['stages'][:args.stages]

    print("L0训练数据生成器")
    print("=" * 50)
    print(f"棋盘大小: {config['board_height']}x{config['board_width']}")
    print(f"训练阶段数: {len(config['stages'])}")
    print(f"预计总游戏数: {sum(s['num_batches'] * s['games_per_batch'] for s in config['stages'])}")
    print(f"数据目录: {config['data_dir']}")

    if not args.quick:
        input("按Enter键开始生成...")

    # Run the generation.
    final_report = run_l0_generation(config, args.resume)

    print("\n" + "=" * 50)

    # run_l0_generation returns None after a Ctrl-C interruption.
    if final_report is None:
        print("生成已中断,未生成最终报告")
        return

    # data_quality is {'error': ...} when no examples were cached.
    quality = final_report.get('data_quality', {})
    mean_value = quality.get('value_stats', {}).get('mean', 0.0)

    print("生成完成!")
    print(f"总游戏数: {final_report['generation_summary']['total_games']}")
    print(f"总训练样本: {quality.get('total_examples', 0)}")
    print(f"平均分数: {mean_value:.1f}")
|
||
|
||
|
||
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()