Files
deep2048/l0_play.py
2025-07-23 07:04:10 +08:00

589 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
L0 阶段的纯蒙特卡洛树搜索数据生成
根据论文要求在3x3小棋盘上使用纯MCTS生成高质量的初始训练数据。
特点:
1. 使用3x3棋盘快速收敛
2. 大量MCTS模拟生成高质量数据
3. 多阶段数据生成和保存
4. 支持断点续传和增量训练
5. 自动数据质量评估和过滤
"""
import os
import time
import json
import argparse
import logging
from datetime import datetime
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import numpy as np
import torch
from game import Game2048
from torch_mcts import TorchMCTS
from training_data import TrainingDataManager, TrainingExample
class L0DataGenerator:
    """Training-data generator for the L0 stage, driven by pure MCTS self-play.

    Games are played on a ``board_height`` x ``board_width`` board taken from
    the config. The shared ``TrainingDataManager`` is handed to the MCTS
    instance; presumably ``TorchMCTS`` records examples into its cache during
    search (confirm in torch_mcts). The cache is written to disk together with
    JSON checkpoints so that an interrupted run can be resumed.
    """

    def __init__(self, config: Dict):
        """Create the logger, output directories, data manager and MCTS.

        Args:
            config: Configuration dictionary. Keys read directly include
                ``data_dir``, ``log_dir``, ``checkpoint_dir``, ``cache_size``,
                ``board_height``, ``board_width``, ``mcts_c_param``,
                ``max_simulation_depth``, ``base_seed``,
                ``max_moves_per_game``, ``checkpoint_interval`` and
                ``quality_check_interval``; ``batch_size``, ``device`` and
                ``verbose`` are optional.
        """
        self.config = config
        self.setup_logging()
        self.setup_directories()
        # Manager that owns the training-example cache and its files.
        self.training_manager = TrainingDataManager(
            data_dir=self.config['data_dir'],
            cache_size=self.config['cache_size'],
            board_size=(self.config['board_height'], self.config['board_width'])
        )
        # MCTS search; it receives the training manager directly.
        self.mcts = TorchMCTS(
            c_param=self.config['mcts_c_param'],
            max_simulation_depth=self.config['max_simulation_depth'],
            batch_size=self.config.get('batch_size', None),
            device=self.config.get('device', 'auto'),
            training_manager=self.training_manager
        )
        # Cumulative statistics; persisted in checkpoints and the final report.
        self.stats = {
            'total_games': 0,
            'total_moves': 0,
            'total_simulations': 0,
            'total_training_examples': 0,
            'start_time': time.time(),
            'stage_start_time': time.time(),
            'best_scores': [],
            'average_scores': [],
            'stage_stats': []
        }
        # (stage, batch) of the most recently loaded checkpoint. run_stage
        # consumes it to skip batches already contained in the restored data.
        self._resume_point: Optional[Tuple[int, int]] = None
        self.logger.info("L0数据生成器初始化完成")
        self.logger.info(f"配置: {json.dumps(config, indent=2, ensure_ascii=False)}")

    def setup_logging(self):
        """Log to a timestamped file under ``log_dir`` and to the console.

        Note: ``logging.basicConfig`` does nothing if the root logger already
        has handlers, in which case no log file is attached.
        """
        log_dir = Path(self.config['log_dir'])
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"l0_generation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def setup_directories(self):
        """Create the data, log and checkpoint directories if missing."""
        for dir_key in ('data_dir', 'log_dir', 'checkpoint_dir'):
            Path(self.config[dir_key]).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _to_builtin(obj):
        """Recursively convert numpy/torch values into JSON-serializable builtins.

        Containers and arrays are handled before scalars so that multi-element
        arrays become lists instead of being passed to ``.item()`` (which
        raises for more than one element). Numpy-scalar dict keys are
        converted too, since ``json`` rejects them.
        """
        if isinstance(obj, dict):
            return {
                (k.item() if isinstance(k, np.generic) else k): L0DataGenerator._to_builtin(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [L0DataGenerator._to_builtin(v) for v in obj]
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return obj

    def save_checkpoint(self, stage: int, batch: int):
        """Write a JSON checkpoint and snapshot the training-data cache.

        Args:
            stage: Stage index the checkpoint belongs to.
            batch: Index of the batch that has just completed.
        """
        checkpoint = {
            'stage': int(stage),
            'batch': int(batch),
            'stats': self._to_builtin(self.stats),
            'config': self._to_builtin(self.config),
            'timestamp': datetime.now().isoformat()
        }
        checkpoint_file = Path(self.config['checkpoint_dir']) / f"checkpoint_stage{stage}_batch{batch}.json"
        with open(checkpoint_file, 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f, indent=2, ensure_ascii=False)
        # Cache snapshot name mirrors the one load_checkpoint rebuilds.
        data_file = f"l0_stage{stage}_batch{batch}"
        self.training_manager.save_current_cache(data_file)
        self.logger.info(f"检查点已保存: {checkpoint_file}")

    def load_checkpoint(self, checkpoint_file: str) -> Tuple[int, int]:
        """Restore statistics and cached training data from a checkpoint.

        The loaded (stage, batch) pair is remembered so that the next
        ``run_stage`` call for that stage continues after ``batch`` instead of
        regenerating batches that are already in the restored data.

        Args:
            checkpoint_file: Path of a JSON file written by ``save_checkpoint``.

        Returns:
            The ``(stage, batch)`` stored in the checkpoint.
        """
        with open(checkpoint_file, 'r', encoding='utf-8') as f:
            checkpoint = json.load(f)
        self.stats = checkpoint['stats']
        stage = checkpoint['stage']
        batch = checkpoint['batch']
        data_file = f"l0_stage{stage}_batch{batch}"
        self.training_manager.load_from_file(data_file)
        self._resume_point = (stage, batch)
        self.logger.info(f"检查点已加载: {checkpoint_file}")
        return stage, batch

    def play_single_game(self, game_id: int, simulations_per_move: int) -> Dict:
        """Play one game, choosing every move with an MCTS search.

        Args:
            game_id: Global game index; the game is seeded with
                ``base_seed + game_id``.
            simulations_per_move: Number of MCTS simulations per search.

        Returns:
            Per-game statistics: move count, final score, max tile, total
            simulations, and per-move search times and score gains.
        """
        game = Game2048(
            height=self.config['board_height'],
            width=self.config['board_width'],
            seed=self.config['base_seed'] + game_id
        )
        game_stats = {
            'game_id': game_id,
            'moves': 0,
            'final_score': 0,
            'max_tile': 0,
            'simulations': 0,
            'search_times': [],
            'move_scores': []
        }
        move_count = 0
        max_moves = self.config['max_moves_per_game']
        verbose = self.config.get('verbose', False)
        while not game.is_over and move_count < max_moves:
            start_time = time.time()
            best_action, _root = self.mcts.search(game, simulations_per_move)
            search_time = time.time() - start_time
            # -1 is treated as "no action available" and ends the game.
            if best_action == -1:
                break
            old_score = game.score
            # A falsy return from move() also ends the game.
            if not game.move(best_action):
                break
            move_count += 1
            game_stats['search_times'].append(search_time)
            game_stats['move_scores'].append(game.score - old_score)
            game_stats['simulations'] += simulations_per_move
            if verbose and move_count % 10 == 0:
                self.logger.info(f"游戏{game_id} 第{move_count}步: "
                                 f"动作={best_action}, 分数={game.score}, "
                                 f"最大数字={game.get_max_tile()}")
        game_stats['moves'] = move_count
        game_stats['final_score'] = game.score
        game_stats['max_tile'] = game.get_max_tile()
        return game_stats

    def generate_batch_data(self, stage: int, batch: int,
                            games_per_batch: int, simulations_per_move: int) -> Dict:
        """Play one batch of games and fold the results into ``self.stats``.

        Args:
            stage: Stage index (used for logging and the returned stats).
            batch: Batch index within the stage.
            games_per_batch: Number of games to play.
            simulations_per_move: MCTS simulations per move.

        Returns:
            Batch statistics, including per-game stats and score/move averages.
        """
        self.logger.info(f"开始生成阶段{stage}批次{batch}数据 "
                         f"({games_per_batch}局游戏, {simulations_per_move}模拟/步)")
        batch_start_time = time.time()
        batch_stats = {
            'stage': stage,
            'batch': batch,
            'games': [],
            'total_games': games_per_batch,
            'total_moves': 0,
            'total_simulations': 0,
            'avg_score': 0,
            'max_score': 0,
            'avg_moves': 0,
            'generation_time': 0
        }
        # Report progress roughly every 10% of the batch.
        report_every = max(1, games_per_batch // 10)
        for game_id in range(games_per_batch):
            # Continue global numbering so every game gets a distinct seed.
            global_game_id = self.stats['total_games'] + game_id
            game_stats = self.play_single_game(global_game_id, simulations_per_move)
            batch_stats['games'].append(game_stats)
            batch_stats['total_moves'] += game_stats['moves']
            batch_stats['total_simulations'] += game_stats['simulations']
            if (game_id + 1) % report_every == 0:
                progress = (game_id + 1) / games_per_batch * 100
                self.logger.info(f"批次进度: {progress:.1f}% "
                                 f"({game_id + 1}/{games_per_batch})")
        scores = [g['final_score'] for g in batch_stats['games']]
        moves = [g['moves'] for g in batch_stats['games']]
        batch_stats['avg_score'] = sum(scores) / len(scores) if scores else 0
        batch_stats['max_score'] = max(scores) if scores else 0
        batch_stats['avg_moves'] = sum(moves) / len(moves) if moves else 0
        batch_stats['generation_time'] = time.time() - batch_start_time
        self.stats['total_games'] += games_per_batch
        self.stats['total_moves'] += batch_stats['total_moves']
        self.stats['total_simulations'] += batch_stats['total_simulations']
        self.stats['best_scores'].extend(scores)
        self.stats['average_scores'].append(batch_stats['avg_score'])
        self.logger.info(f"批次{batch}完成: 平均分数={batch_stats['avg_score']:.1f}, "
                         f"最高分数={batch_stats['max_score']}, "
                         f"平均步数={batch_stats['avg_moves']:.1f}, "
                         f"用时={batch_stats['generation_time']:.1f}")
        return batch_stats

    def run_stage(self, stage: int, stage_config: Dict) -> Dict:
        """Run all batches of one stage, checkpointing and evaluating periodically.

        If a checkpoint for this stage was loaded with ``load_checkpoint``,
        batches up to and including the checkpointed one are skipped. In that
        case the returned stats only cover the batches generated in this call.

        Args:
            stage: Stage index.
            stage_config: Dict with ``description``, ``num_batches``,
                ``games_per_batch`` and ``simulations_per_move``.

        Returns:
            Stage statistics; ``data_quality`` holds the latest evaluation, if any.
        """
        self.logger.info(f"开始阶段{stage}: {stage_config['description']}")
        stage_start_time = time.time()
        stage_stats = {
            'stage': stage,
            'description': stage_config['description'],
            'batches': [],
            'total_games': 0,
            'total_moves': 0,
            'total_simulations': 0,
            'stage_time': 0,
            'data_quality': {}
        }
        first_batch = 0
        resume_point = getattr(self, '_resume_point', None)
        if resume_point is not None and resume_point[0] == stage:
            # Batches <= the checkpointed one are already in the restored cache/stats.
            first_batch = resume_point[1] + 1
            self._resume_point = None
        # Intervals <= 0 disable the feature instead of dividing by zero.
        checkpoint_interval = self.config['checkpoint_interval']
        quality_interval = self.config['quality_check_interval']
        for batch in range(first_batch, stage_config['num_batches']):
            batch_stats = self.generate_batch_data(
                stage=stage,
                batch=batch,
                games_per_batch=stage_config['games_per_batch'],
                simulations_per_move=stage_config['simulations_per_move']
            )
            stage_stats['batches'].append(batch_stats)
            stage_stats['total_games'] += batch_stats['total_games']
            stage_stats['total_moves'] += batch_stats['total_moves']
            stage_stats['total_simulations'] += batch_stats['total_simulations']
            if checkpoint_interval > 0 and (batch + 1) % checkpoint_interval == 0:
                self.save_checkpoint(stage, batch)
            if quality_interval > 0 and (batch + 1) % quality_interval == 0:
                quality_stats = self.evaluate_data_quality()
                stage_stats['data_quality'] = quality_stats
                self.logger.info(f"数据质量评估: {quality_stats}")
        stage_stats['stage_time'] = time.time() - stage_start_time
        stage_data_file = f"l0_stage{stage}_complete"
        self.training_manager.save_current_cache(stage_data_file)
        self.stats['stage_stats'].append(stage_stats)
        self.logger.info(f"阶段{stage}完成: "
                         f"游戏数={stage_stats['total_games']}, "
                         f"移动数={stage_stats['total_moves']}, "
                         f"模拟数={stage_stats['total_simulations']}, "
                         f"用时={stage_stats['stage_time']:.1f}")
        return stage_stats

    def evaluate_data_quality(self) -> Dict:
        """Summarize the cached training examples.

        Returns:
            ``{'error': 'No data to evaluate'}`` when the cache is empty;
            otherwise the example count, value mean/min/max/std (torch's
            default, unbiased std), counts for actions 0-3, and the number of
            distinct ``canonical_hash`` values.
        """
        cache_stats = self.training_manager.get_cache_stats()
        if cache_stats['cache_size'] == 0:
            return {'error': 'No data to evaluate'}
        examples = self.training_manager.cache.get_all_examples()
        values = [ex.value for ex in examples]
        actions = [ex.action for ex in examples]
        return {
            'total_examples': len(examples),
            'value_stats': {
                'mean': float(sum(values) / len(values)) if values else 0,
                'min': float(min(values)) if values else 0,
                'max': float(max(values)) if values else 0,
                'std': float(torch.tensor(values).std().item()) if len(values) > 1 else 0
            },
            'action_distribution': {
                i: actions.count(i) for i in range(4)
            },
            'unique_states': len(set(ex.canonical_hash for ex in examples))
        }

    def generate_final_report(self) -> Dict:
        """Build the final report, save it as JSON under ``log_dir`` and return it."""
        total_time = time.time() - self.stats['start_time']
        report = {
            'generation_summary': {
                'total_time': total_time,
                'total_games': self.stats['total_games'],
                'total_moves': self.stats['total_moves'],
                'total_simulations': self.stats['total_simulations'],
                'games_per_hour': self.stats['total_games'] / (total_time / 3600) if total_time > 0 else 0,
                'simulations_per_second': self.stats['total_simulations'] / total_time if total_time > 0 else 0
            },
            'data_quality': self.evaluate_data_quality(),
            'stage_summary': self.stats['stage_stats'],
            'config': self.config
        }
        report = self._to_builtin(report)
        report_file = Path(self.config['log_dir']) / f"l0_final_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)
        self.logger.info(f"最终报告已保存: {report_file}")
        return report
def get_default_config() -> Dict:
    """Return a freshly built default configuration for L0 data generation.

    Every call constructs a new dictionary (including a new ``stages`` list
    of new stage dicts), so callers may mutate the result without affecting
    later calls.
    """
    # (description, num_batches, games_per_batch, simulations_per_move):
    # each stage plays more games with more simulations per move.
    stage_plan = [
        ('初始探索阶段 - 少量模拟快速生成基础数据', 5, 20, 50),
        ('深度搜索阶段 - 中等模拟生成质量数据', 10, 30, 100),
        ('精细优化阶段 - 大量模拟生成高质量数据', 15, 40, 200),
        ('最终收集阶段 - 超大模拟生成顶级数据', 10, 50, 300),
    ]
    stage_list = [
        {
            'description': text,
            'num_batches': n_batches,
            'games_per_batch': n_games,
            'simulations_per_move': n_sims,
        }
        for text, n_batches, n_games, n_sims in stage_plan
    ]
    return dict(
        # Game settings
        board_height=3,
        board_width=3,
        base_seed=42,
        max_moves_per_game=100,
        # MCTS settings
        mcts_c_param=1.414,
        max_simulation_depth=80,
        num_threads=1,
        # Data management
        cache_size=50000,
        data_dir='data/l0_training',
        log_dir='logs/l0_generation',
        checkpoint_dir='checkpoints/l0',
        # Generation stages
        stages=stage_list,
        # Control: checkpoint every 2 batches, quality check every 5 batches
        checkpoint_interval=2,
        quality_check_interval=5,
        verbose=True,
    )
def run_l0_generation(config: Optional[Dict] = None, resume_from: Optional[str] = None) -> Optional[Dict]:
    """Run the full L0 training-data generation pipeline.

    Args:
        config: Configuration dictionary; ``None`` uses ``get_default_config()``.
        resume_from: Optional checkpoint file to resume from. Stages before
            the checkpointed stage are skipped. A path that does not exist is
            reported with a warning and generation starts from scratch.

    Returns:
        The final report dictionary, or ``None`` when the run is interrupted
        with Ctrl+C. In that case an emergency checkpoint (batch ``-1``) is
        saved for the stage that was running.

    Raises:
        Any other exception raised during generation is logged with its
        traceback and re-raised.
    """
    if config is None:
        config = get_default_config()
    generator = L0DataGenerator(config)
    start_stage = 0
    start_batch = 0
    if resume_from:
        if Path(resume_from).exists():
            start_stage, start_batch = generator.load_checkpoint(resume_from)
            generator.logger.info(f"从检查点恢复: 阶段{start_stage}, 批次{start_batch}")
        else:
            generator.logger.warning(f"检查点文件不存在, 从头开始: {resume_from}")
    # Stage being generated; tags the emergency checkpoint so a resume skips
    # the stages that were already finished (instead of restarting at -1).
    current_stage = max(start_stage, 0)
    try:
        for stage_idx, stage_config in enumerate(config['stages']):
            if stage_idx < start_stage:
                continue
            current_stage = stage_idx
            generator.run_stage(stage_idx, stage_config)
        final_report = generator.generate_final_report()
        summary = final_report['generation_summary']
        # data_quality is {'error': ...} (no 'total_examples') for an empty cache.
        total_examples = final_report['data_quality'].get('total_examples', 0)
        generator.logger.info("L0训练数据生成完成!")
        generator.logger.info(f"总游戏数: {summary['total_games']}")
        generator.logger.info(f"总训练样本: {total_examples}")
        generator.logger.info(f"生成速度: {summary['games_per_hour']:.1f} 游戏/小时")
        return final_report
    except KeyboardInterrupt:
        generator.logger.info("用户中断,保存当前进度...")
        # batch=-1 marks a mid-stage emergency checkpoint.
        generator.save_checkpoint(current_stage, -1)
        generator.logger.info("紧急检查点已保存")
        return None
    except Exception as e:
        generator.logger.exception(f"生成过程中出现错误: {e}")
        raise
def main():
    """Command-line entry point for L0 training-data generation.

    Options:
        --config: JSON config file; its keys are overlaid on the defaults.
        --resume: Checkpoint file to resume from.
        --quick:  Tiny single-stage smoke test writing to separate *_test dirs.
        --stages: Keep only the first N stages.
    """
    parser = argparse.ArgumentParser(description='L0阶段MCTS训练数据生成')
    parser.add_argument('--config', type=str, help='配置文件路径')
    parser.add_argument('--resume', type=str, help='检查点文件路径')
    parser.add_argument('--quick', action='store_true', help='快速测试模式')
    parser.add_argument('--stages', type=int, default=None, help='运行的阶段数')
    args = parser.parse_args()
    config = get_default_config()
    if args.config:
        config_path = Path(args.config)
        if config_path.exists():
            # utf-8-sig tolerates a leading BOM in the config file.
            with open(config_path, 'r', encoding='utf-8-sig') as f:
                # Overlay on the defaults so partial config files still work.
                config.update(json.load(f))
        else:
            print(f"警告: 配置文件不存在, 使用默认配置: {args.config}")
    if args.quick:
        config['stages'] = [
            {
                'description': '快速测试阶段',
                'num_batches': 2,
                'games_per_batch': 5,
                'simulations_per_move': 20
            }
        ]
        config['data_dir'] = 'data/l0_test'
        config['log_dir'] = 'logs/l0_test'
        config['checkpoint_dir'] = 'checkpoints/l0_test'
    if args.stages is not None:
        config['stages'] = config['stages'][:args.stages]
    print("L0训练数据生成器")
    print("=" * 50)
    print(f"棋盘大小: {config['board_height']}x{config['board_width']}")
    print(f"训练阶段数: {len(config['stages'])}")
    print(f"预计总游戏数: {sum(s['num_batches'] * s['games_per_batch'] for s in config['stages'])}")
    print(f"数据目录: {config['data_dir']}")
    if not args.quick:
        input("按Enter键开始生成...")
    final_report = run_l0_generation(config, args.resume)
    if final_report is None:
        # run_l0_generation returns None after a Ctrl+C interrupt.
        print("生成已中断, 进度已保存到紧急检查点。")
        return
    data_quality = final_report['data_quality']
    print("\n" + "=" * 50)
    print("生成完成!")
    print(f"总游戏数: {final_report['generation_summary']['total_games']}")
    # data_quality lacks these keys when no examples were produced.
    print(f"总训练样本: {data_quality.get('total_examples', 0)}")
    # NOTE(review): this is the mean training-example value, labelled as an
    # average score -- confirm the label is intended.
    mean_value = data_quality.get('value_stats', {}).get('mean', 0)
    print(f"平均分数: {mean_value:.1f}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()