"""Pure Monte Carlo tree search (MCTS) data generation for the L0 stage.

Following the paper's requirements, this script generates the initial
training data with pure MCTS on a small board (3x3 in the default
configuration).

Features:
1. A 3x3 board is used so that play converges quickly.
2. Many MCTS simulations per move are run to produce the data.
3. Data is generated and saved in multiple stages.
4. Generation can be resumed from a checkpoint.
5. Data-quality statistics are computed periodically.
   NOTE(review): the original description also mentions automatic
   filtering; no filtering step is implemented in this module itself.
"""

import os
import time
import json
import argparse
import logging
from datetime import datetime
from typing import Dict, List, Tuple, Optional
from pathlib import Path

import numpy as np
import torch

from game import Game2048
from torch_mcts import TorchMCTS
from training_data import TrainingDataManager, TrainingExample


def _to_builtin(obj):
    """Recursively convert numpy/torch values into JSON-friendly builtins.

    Containers (dict, list, tuple) are rebuilt with converted members;
    tuples come back as lists, which is also how ``json`` writes them.
    Arrays and tensors become (nested) lists, numpy/torch scalars become
    plain Python numbers. Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            # json rejects numpy scalar keys, so unwrap those as well.
            (key.item() if isinstance(key, np.generic) else key): _to_builtin(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_to_builtin(value) for value in obj]
    # Arrays/tensors must be tested BEFORE the generic ``.item()`` check:
    # they also expose ``.item()``, which raises for more than one element.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if torch.is_tensor(obj):
        return obj.tolist()
    if hasattr(obj, 'item'):  # numpy scalar (or any scalar-like wrapper)
        return obj.item()
    return obj


class L0DataGenerator:
    """Generator of L0-stage training data driven by pure MCTS self-play."""

    def __init__(self, config: Dict):
        """Set up logging, directories, the data manager and the MCTS engine.

        Args:
            config: Configuration dictionary (see ``get_default_config`` for
                the expected keys). ``batch_size`` and ``device`` are
                optional; the other keys read here are required.
        """
        self.config = config
        self.setup_logging()
        self.setup_directories()

        # Training-data manager that stores the generated examples.
        self.training_manager = TrainingDataManager(
            data_dir=self.config['data_dir'],
            cache_size=self.config['cache_size'],
            board_size=(self.config['board_height'], self.config['board_width'])
        )

        # MCTS engine; it is handed the data manager at construction.
        self.mcts = TorchMCTS(
            c_param=self.config['mcts_c_param'],
            max_simulation_depth=self.config['max_simulation_depth'],
            batch_size=self.config.get('batch_size', None),
            device=self.config.get('device', 'auto'),
            training_manager=self.training_manager
        )

        # Running statistics; persisted in every checkpoint.
        self.stats = {
            'total_games': 0,
            'total_moves': 0,
            'total_simulations': 0,
            'total_training_examples': 0,
            'start_time': time.time(),
            'stage_start_time': time.time(),
            'best_scores': [],
            'average_scores': [],
            'stage_stats': []
        }

        self.logger.info("L0数据生成器初始化完成")
        self.logger.info(f"配置: {json.dumps(config, indent=2)}")

    def setup_logging(self):
        """Configure logging to a timestamped file in ``log_dir`` and to the console."""
        log_dir = Path(self.config['log_dir'])
        log_dir.mkdir(parents=True, exist_ok=True)

        log_file = log_dir / f"l0_generation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

        # NOTE: ``basicConfig`` does nothing if the root logger already has
        # handlers (e.g. when a second generator is created in one process).
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                # Log messages are non-ASCII; fix the encoding explicitly so
                # the file does not depend on the platform default.
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def setup_directories(self):
        """Create the data, log and checkpoint directories if they are missing."""
        for dir_key in ('data_dir', 'log_dir', 'checkpoint_dir'):
            Path(self.config[dir_key]).mkdir(parents=True, exist_ok=True)

    def save_checkpoint(self, stage: int, batch: int):
        """Write a JSON checkpoint and save the current training-data cache.

        Args:
            stage: Index of the stage the checkpoint belongs to.
            batch: Index of the last completed batch in that stage.
        """
        checkpoint = {
            'stage': int(stage),
            'batch': int(batch),
            'stats': _to_builtin(self.stats),
            'config': _to_builtin(self.config),
            # Wall-clock time spent so far; lets a resumed run exclude the
            # downtime between runs from its throughput figures.
            'elapsed_time': time.time() - self.stats['start_time'],
            'timestamp': datetime.now().isoformat()
        }

        checkpoint_file = Path(self.config['checkpoint_dir']) / f"checkpoint_stage{stage}_batch{batch}.json"
        with open(checkpoint_file, 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f, indent=2, ensure_ascii=False)

        # Save the training data under a name derived from stage and batch.
        data_file = f"l0_stage{stage}_batch{batch}"
        self.training_manager.save_current_cache(data_file)

        self.logger.info(f"检查点已保存: {checkpoint_file}")

    def load_checkpoint(self, checkpoint_file: str) -> Tuple[int, int]:
        """Restore statistics and training data from a checkpoint file.

        Args:
            checkpoint_file: Path of a JSON file written by ``save_checkpoint``.

        Returns:
            ``(stage, batch)`` as stored in the checkpoint, i.e. the stage
            and the last completed batch at the time it was written.
        """
        with open(checkpoint_file, 'r', encoding='utf-8') as f:
            checkpoint = json.load(f)

        self.stats = checkpoint['stats']
        stage = checkpoint['stage']
        batch = checkpoint['batch']

        # Rebase the clock so that total time only covers time actually
        # spent generating. Older checkpoints have no 'elapsed_time'; for
        # those the stored start time is kept unchanged.
        elapsed = checkpoint.get('elapsed_time')
        if elapsed is not None:
            self.stats['start_time'] = time.time() - elapsed

        # Load the training data that was saved alongside the checkpoint.
        data_file = f"l0_stage{stage}_batch{batch}"
        self.training_manager.load_from_file(data_file)

        self.logger.info(f"检查点已加载: {checkpoint_file}")
        return stage, batch

    def play_single_game(self, game_id: int, simulations_per_move: int) -> Dict:
        """Play one game, choosing every move with an MCTS search.

        Args:
            game_id: Game identifier; added to ``base_seed`` to seed the game.
            simulations_per_move: Number of simulations requested per search.

        Returns:
            Dictionary with the game's move count, final score, max tile,
            simulation count and per-move search times and score gains.
        """
        game = Game2048(
            height=self.config['board_height'],
            width=self.config['board_width'],
            seed=self.config['base_seed'] + game_id
        )

        game_stats = {
            'game_id': game_id,
            'moves': 0,
            'final_score': 0,
            'max_tile': 0,
            'simulations': 0,
            'search_times': [],
            'move_scores': []
        }

        move_count = 0
        max_moves = self.config['max_moves_per_game']
        verbose = self.config.get('verbose', False)

        while not game.is_over and move_count < max_moves:
            # MCTS search for the next move; the returned root is not needed.
            start_time = time.time()
            best_action, _root = self.mcts.search(game, simulations_per_move)
            search_time = time.time() - start_time

            # -1 is treated as "no action available".
            if best_action == -1:
                break

            old_score = game.score
            # Stop as soon as the chosen move is rejected by the game.
            if not game.move(best_action):
                break

            move_count += 1
            game_stats['search_times'].append(search_time)
            game_stats['move_scores'].append(game.score - old_score)
            game_stats['simulations'] += simulations_per_move

            # Optional progress line every 10 moves.
            if verbose and move_count % 10 == 0:
                self.logger.info(f"游戏{game_id} 第{move_count}步: "
                                 f"动作={best_action}, 分数={game.score}, "
                                 f"最大数字={game.get_max_tile()}")

        game_stats['moves'] = move_count
        game_stats['final_score'] = game.score
        game_stats['max_tile'] = game.get_max_tile()

        return game_stats

    def generate_batch_data(self, stage: int, batch: int,
                            games_per_batch: int, simulations_per_move: int) -> Dict:
        """Play one batch of games and fold the results into the statistics.

        Args:
            stage: Stage index (used for logging and in the result).
            batch: Batch index (used for logging and in the result).
            games_per_batch: Number of games to play.
            simulations_per_move: Simulations per MCTS search.

        Returns:
            Batch statistics, including the per-game statistics.
        """
        self.logger.info(f"开始生成阶段{stage}批次{batch}数据 "
                         f"({games_per_batch}局游戏, {simulations_per_move}模拟/步)")

        batch_start_time = time.time()
        batch_stats = {
            'stage': stage,
            'batch': batch,
            'games': [],
            'total_games': games_per_batch,
            'total_moves': 0,
            'total_simulations': 0,
            'avg_score': 0,
            'max_score': 0,
            'avg_moves': 0,
            'generation_time': 0
        }

        # Report progress roughly every 10% of the batch.
        report_every = max(1, games_per_batch // 10)

        for game_id in range(games_per_batch):
            # Game ids continue across batches so every game gets its own seed.
            global_game_id = self.stats['total_games'] + game_id
            game_stats = self.play_single_game(global_game_id, simulations_per_move)

            batch_stats['games'].append(game_stats)
            batch_stats['total_moves'] += game_stats['moves']
            batch_stats['total_simulations'] += game_stats['simulations']

            if (game_id + 1) % report_every == 0:
                progress = (game_id + 1) / games_per_batch * 100
                self.logger.info(f"批次进度: {progress:.1f}% "
                                 f"({game_id + 1}/{games_per_batch})")

        # Batch-level aggregates (all zero for an empty batch).
        scores = [g['final_score'] for g in batch_stats['games']]
        moves = [g['moves'] for g in batch_stats['games']]

        batch_stats['avg_score'] = sum(scores) / len(scores) if scores else 0
        batch_stats['max_score'] = max(scores) if scores else 0
        batch_stats['avg_moves'] = sum(moves) / len(moves) if moves else 0
        batch_stats['generation_time'] = time.time() - batch_start_time

        # Fold the batch into the global statistics.
        self.stats['total_games'] += games_per_batch
        self.stats['total_moves'] += batch_stats['total_moves']
        self.stats['total_simulations'] += batch_stats['total_simulations']
        self.stats['best_scores'].extend(scores)
        self.stats['average_scores'].append(batch_stats['avg_score'])

        self.logger.info(f"批次{batch}完成: 平均分数={batch_stats['avg_score']:.1f}, "
                         f"最高分数={batch_stats['max_score']}, "
                         f"平均步数={batch_stats['avg_moves']:.1f}, "
                         f"用时={batch_stats['generation_time']:.1f}秒")

        return batch_stats

    def run_stage(self, stage: int, stage_config: Dict, start_batch: int = 0) -> Dict:
        """Run the batches of one stage, checkpointing along the way.

        Args:
            stage: Stage index.
            stage_config: Stage settings with the keys ``description``,
                ``num_batches``, ``games_per_batch`` and
                ``simulations_per_move``.
            start_batch: Index of the first batch to run. Non-zero only when
                resuming a partially completed stage; earlier batches are
                skipped and do not appear in the returned statistics.

        Returns:
            Stage statistics, including the statistics of each batch run.
        """
        self.logger.info(f"开始阶段{stage}: {stage_config['description']}")

        stage_start_time = time.time()
        stage_stats = {
            'stage': stage,
            'description': stage_config['description'],
            'batches': [],
            'total_games': 0,
            'total_moves': 0,
            'total_simulations': 0,
            'stage_time': 0,
            'data_quality': {}
        }

        checkpoint_interval = self.config['checkpoint_interval']
        quality_check_interval = self.config['quality_check_interval']

        for batch in range(max(0, start_batch), stage_config['num_batches']):
            batch_stats = self.generate_batch_data(
                stage=stage,
                batch=batch,
                games_per_batch=stage_config['games_per_batch'],
                simulations_per_move=stage_config['simulations_per_move']
            )

            stage_stats['batches'].append(batch_stats)
            stage_stats['total_games'] += batch_stats['total_games']
            stage_stats['total_moves'] += batch_stats['total_moves']
            stage_stats['total_simulations'] += batch_stats['total_simulations']

            # Periodic checkpoint; an interval of 0 disables it instead of
            # raising ZeroDivisionError.
            if checkpoint_interval > 0 and (batch + 1) % checkpoint_interval == 0:
                self.save_checkpoint(stage, batch)

            # Periodic data-quality evaluation (same convention for 0).
            if quality_check_interval > 0 and (batch + 1) % quality_check_interval == 0:
                quality_stats = self.evaluate_data_quality()
                self.logger.info(f"数据质量评估: {quality_stats}")

        stage_stats['stage_time'] = time.time() - stage_start_time

        # Save the data collected up to the end of the stage.
        stage_data_file = f"l0_stage{stage}_complete"
        self.training_manager.save_current_cache(stage_data_file)

        # Record the stage in the global statistics.
        self.stats['stage_stats'].append(stage_stats)

        self.logger.info(f"阶段{stage}完成: "
                         f"游戏数={stage_stats['total_games']}, "
                         f"移动数={stage_stats['total_moves']}, "
                         f"模拟数={stage_stats['total_simulations']}, "
                         f"用时={stage_stats['stage_time']:.1f}秒")

        return stage_stats

    def evaluate_data_quality(self) -> Dict:
        """Compute summary statistics over the cached training examples.

        Returns:
            ``{'error': 'No data to evaluate'}`` when the cache is empty,
            otherwise a dictionary with ``total_examples``, ``value_stats``
            (mean/min/max/std), ``action_distribution`` (counts for actions
            0-3) and ``unique_states``.
        """
        cache_stats = self.training_manager.get_cache_stats()

        if cache_stats['cache_size'] == 0:
            return {'error': 'No data to evaluate'}

        examples = self.training_manager.cache.get_all_examples()

        values = [ex.value for ex in examples]
        actions = [ex.action for ex in examples]

        if values:
            value_mean = float(sum(values) / len(values))
            value_min = float(min(values))
            value_max = float(max(values))
        else:
            value_mean = value_min = value_max = 0
        # The standard deviation needs at least two samples.
        value_std = float(torch.tensor(values).std().item()) if len(values) > 1 else 0

        return {
            'total_examples': len(examples),
            'value_stats': {
                'mean': value_mean,
                'min': value_min,
                'max': value_max,
                'std': value_std
            },
            'action_distribution': {
                i: actions.count(i) for i in range(4)
            },
            'unique_states': len(set(ex.canonical_hash for ex in examples))
        }

    def generate_final_report(self) -> Dict:
        """Build the final report, write it to ``log_dir`` as JSON and return it.

        Returns:
            Report with the keys ``generation_summary``, ``data_quality``,
            ``stage_summary`` and ``config``, converted to builtin types.
        """
        total_time = time.time() - self.stats['start_time']
        has_time = total_time > 0

        report = {
            'generation_summary': {
                'total_time': total_time,
                'total_games': self.stats['total_games'],
                'total_moves': self.stats['total_moves'],
                'total_simulations': self.stats['total_simulations'],
                'games_per_hour': self.stats['total_games'] / (total_time / 3600) if has_time else 0,
                'simulations_per_second': self.stats['total_simulations'] / total_time if has_time else 0
            },
            'data_quality': self.evaluate_data_quality(),
            'stage_summary': self.stats['stage_stats'],
            'config': self.config
        }

        # json cannot serialise numpy types; convert everything first.
        report = _to_builtin(report)

        report_file = Path(self.config['log_dir']) / f"l0_final_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)

        self.logger.info(f"最终报告已保存: {report_file}")
        return report


def get_default_config() -> Dict:
    """Return a fresh copy of the default configuration."""
    return {
        # Game settings
        'board_height': 3,
        'board_width': 3,
        'base_seed': 42,
        'max_moves_per_game': 100,

        # MCTS settings
        'mcts_c_param': 1.414,
        'max_simulation_depth': 80,
        'num_threads': 1,

        # Data management
        'cache_size': 50000,
        'data_dir': 'data/l0_training',
        'log_dir': 'logs/l0_generation',
        'checkpoint_dir': 'checkpoints/l0',

        # Stage configuration: simulations per move grow from stage to stage.
        'stages': [
            {
                'description': '初始探索阶段 - 少量模拟快速生成基础数据',
                'num_batches': 5,
                'games_per_batch': 20,
                'simulations_per_move': 50
            },
            {
                'description': '深度搜索阶段 - 中等模拟生成质量数据',
                'num_batches': 10,
                'games_per_batch': 30,
                'simulations_per_move': 100
            },
            {
                'description': '精细优化阶段 - 大量模拟生成高质量数据',
                'num_batches': 15,
                'games_per_batch': 40,
                'simulations_per_move': 200
            },
            {
                'description': '最终收集阶段 - 超大模拟生成顶级数据',
                'num_batches': 10,
                'games_per_batch': 50,
                'simulations_per_move': 300
            }
        ],

        # Control settings
        'checkpoint_interval': 2,      # save a checkpoint every 2 batches
        'quality_check_interval': 5,   # evaluate data quality every 5 batches
        'verbose': True
    }


def run_l0_generation(config: Optional[Dict] = None,
                      resume_from: Optional[str] = None) -> Optional[Dict]:
    """Run the L0 training-data generation.

    Args:
        config: Configuration dictionary; ``None`` selects the defaults.
        resume_from: Path of a checkpoint file to resume from. A path that
            does not exist is reported with a warning and generation starts
            from the beginning.

    Returns:
        The final report, or ``None`` if the run was interrupted with
        Ctrl-C (an emergency checkpoint is saved in that case).

    Raises:
        Exception: Any error raised during generation is logged and re-raised.
    """
    if config is None:
        config = get_default_config()

    generator = L0DataGenerator(config)

    start_stage = 0
    start_batch = 0  # first batch still to be run inside ``start_stage``

    if resume_from:
        if Path(resume_from).exists():
            saved_stage, saved_batch = generator.load_checkpoint(resume_from)
            generator.logger.info(f"从检查点恢复: 阶段{saved_stage}, 批次{saved_batch}")
            # A checkpoint names the last COMPLETED batch, so continue with
            # the one after it. Emergency checkpoints carry stage -1; they
            # hold no position, so generation restarts at stage 0.
            if saved_stage >= 0:
                start_stage = saved_stage
                start_batch = saved_batch + 1
        else:
            generator.logger.warning(f"检查点文件不存在,将从头开始: {resume_from}")

    try:
        for stage_idx, stage_config in enumerate(config['stages']):
            if stage_idx < start_stage:
                continue
            # Only the stage that was interrupted starts part-way through.
            first_batch = start_batch if stage_idx == start_stage else 0
            generator.run_stage(stage_idx, stage_config, start_batch=first_batch)

        final_report = generator.generate_final_report()

        generator.logger.info("L0训练数据生成完成!")
        generator.logger.info(f"总游戏数: {final_report['generation_summary']['total_games']}")
        # 'data_quality' only holds an 'error' entry when no data was cached.
        generator.logger.info(f"总训练样本: {final_report['data_quality'].get('total_examples', 0)}")
        generator.logger.info(f"生成速度: {final_report['generation_summary']['games_per_hour']:.1f} 游戏/小时")

        return final_report

    except KeyboardInterrupt:
        generator.logger.info("用户中断,保存当前进度...")
        # Emergency checkpoint, marked with the special stage/batch -1.
        generator.save_checkpoint(-1, -1)
        generator.logger.info("紧急检查点已保存")
        return None

    except Exception as e:
        # ``exception`` logs the message together with the traceback.
        generator.logger.exception(f"生成过程中出现错误: {e}")
        raise


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(description='L0阶段MCTS训练数据生成')
    parser.add_argument('--config', type=str, help='配置文件路径')
    parser.add_argument('--resume', type=str, help='检查点文件路径')
    parser.add_argument('--quick', action='store_true', help='快速测试模式')
    parser.add_argument('--stages', type=int, default=None, help='运行的阶段数')

    args = parser.parse_args()

    # Load the configuration. Values from the file override the defaults,
    # so a partial file does not fail later on a missing key.
    config = get_default_config()
    if args.config:
        if Path(args.config).exists():
            with open(args.config, 'r', encoding='utf-8-sig') as f:
                config.update(json.load(f))
        else:
            print(f"警告: 配置文件不存在,使用默认配置: {args.config}")

    # Quick test mode: one tiny stage written to separate directories.
    if args.quick:
        config['stages'] = [
            {
                'description': '快速测试阶段',
                'num_batches': 2,
                'games_per_batch': 5,
                'simulations_per_move': 20
            }
        ]
        config['data_dir'] = 'data/l0_test'
        config['log_dir'] = 'logs/l0_test'
        config['checkpoint_dir'] = 'checkpoints/l0_test'

    # Limit the number of stages.
    if args.stages is not None:
        config['stages'] = config['stages'][:args.stages]

    print("L0训练数据生成器")
    print("=" * 50)
    print(f"棋盘大小: {config['board_height']}x{config['board_width']}")
    print(f"训练阶段数: {len(config['stages'])}")
    print(f"预计总游戏数: {sum(s['num_batches'] * s['games_per_batch'] for s in config['stages'])}")
    print(f"数据目录: {config['data_dir']}")

    if not args.quick:
        input("按Enter键开始生成...")

    final_report = run_l0_generation(config, args.resume)

    # ``None`` means the run was interrupted; there is no report to print.
    if final_report is None:
        print("\n生成已中断,未生成最终报告。")
        return

    quality = final_report['data_quality']
    mean_value = quality.get('value_stats', {}).get('mean', 0.0)

    print("\n" + "=" * 50)
    print("生成完成!")
    print(f"总游戏数: {final_report['generation_summary']['total_games']}")
    print(f"总训练样本: {quality.get('total_examples', 0)}")
    print(f"平均分数: {mean_value:.1f}")


if __name__ == "__main__":
    main()