"""PyTorch MCTS tests.

Exercises the unified PyTorch MCTS implementation on CPU and — when
available — CUDA: basic search correctness, action-distribution quality,
device auto-selection, performance floors, training-data collection,
GPU memory management, device switching and cross-device consistency.
"""

import time

import numpy as np
import pytest
import torch

from game import Game2048
from torch_mcts import TorchMCTS
from training_data import TrainingDataManager


class TestTorchMCTS:
    """Test suite for the unified PyTorch MCTS implementation."""

    @pytest.fixture
    def game(self):
        """Small deterministic 3x3 game state shared by the tests."""
        return Game2048(height=3, width=3, seed=42)

    @pytest.fixture
    def cpu_mcts(self):
        """CPU-bound MCTS instance."""
        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu"
        )

    @pytest.fixture
    def gpu_mcts(self):
        """CUDA-bound MCTS instance; skipped when no GPU is present."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")
        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=4096,
            device="cuda"
        )

    def test_cpu_mcts_basic_functionality(self, game, cpu_mcts):
        """Basic CPU search: valid action, complete stats, exact visit count."""
        action, stats = cpu_mcts.search(game, 1000)

        # The chosen action must be legal and the stats dict complete.
        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cpu'

        # Every requested simulation must be accounted for.
        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 1000, f"访问次数不匹配: {total_visits}"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_gpu_mcts_basic_functionality(self, game, gpu_mcts):
        """Basic GPU search: valid action, complete stats, exact visit count."""
        action, stats = gpu_mcts.search(game, 2000)

        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cuda'

        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 2000, f"访问次数不匹配: {total_visits}"

    def test_action_distribution_quality(self, game, cpu_mcts):
        """The visit distribution should be biased, not uniform."""
        action, stats = cpu_mcts.search(game, 5000)

        action_visits = stats['action_visits']
        visit_values = list(action_visits.values())

        # MCTS should concentrate visits; a perfectly uniform spread
        # suggests the tree policy is not discriminating between moves.
        assert len(set(visit_values)) > 1, "动作分布完全均匀,不符合MCTS预期"

        # The returned action must be the most-visited one.
        best_action_visits = action_visits[action]
        assert best_action_visits == max(visit_values), "最佳动作访问次数不是最多"

        # Sanity bounds on the averaged action values.
        action_values = stats['action_avg_values']
        for act, value in action_values.items():
            assert value > 0, f"动作{act}的价值应该为正: {value}"
            assert value < 100000, f"动作{act}的价值过大: {value}"

    def test_device_auto_selection(self, game):
        """device="auto" must pick CUDA when available, else CPU."""
        mcts = TorchMCTS(device="auto", batch_size=1024)

        if torch.cuda.is_available():
            assert mcts.device.type == "cuda"
        else:
            assert mcts.device.type == "cpu"

        # The auto-selected device must actually be able to search.
        action, stats = mcts.search(game, 1000)
        assert action in game.get_valid_moves()

        if mcts.device.type == "cuda":
            del mcts
            torch.cuda.empty_cache()

    def test_batch_size_auto_selection(self, game):
        """batch_size=None must fall back to the per-device default."""
        cpu_mcts = TorchMCTS(device="cpu", batch_size=None)
        assert cpu_mcts.batch_size == 4096  # documented CPU default

        if torch.cuda.is_available():
            gpu_mcts = TorchMCTS(device="cuda", batch_size=None)
            assert gpu_mcts.batch_size == 32768  # documented GPU default
            del gpu_mcts
            torch.cuda.empty_cache()

    def test_performance_cpu(self, game, cpu_mcts):
        """CPU search must clear a minimal throughput floor."""
        simulations = 2000
        # perf_counter is monotonic — safe for interval timing, unlike time().
        start_time = time.perf_counter()
        action, stats = cpu_mcts.search(game, simulations)
        elapsed_time = time.perf_counter() - start_time
        speed = simulations / elapsed_time

        assert speed > 100, f"CPU性能过低: {speed:.1f} 模拟/秒"

        # The self-reported rate should agree with our measurement (±20%).
        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_performance_gpu(self, game, gpu_mcts):
        """GPU search must clear a minimal throughput floor."""
        simulations = 5000
        # Synchronize around the timed region so async kernels are included.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        action, stats = gpu_mcts.search(game, simulations)
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time
        speed = simulations / elapsed_time

        assert speed > 50, f"GPU性能过低: {speed:.1f} 模拟/秒"

        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    def test_training_data_collection(self, game):
        """A search wired to a TrainingDataManager must populate its cache."""
        training_manager = TrainingDataManager(
            data_dir="data/test_torch_training",
            cache_size=5000,
            board_size=(3, 3)
        )

        mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu",
            training_manager=training_manager
        )

        action, stats = mcts.search(game, 2000)

        # At least one sample collected, and never more than one per simulation.
        cache_stats = training_manager.get_cache_stats()
        assert cache_stats['cache_size'] > 0, "未收集到训练数据"
        assert cache_stats['cache_size'] <= 2000, "收集的样本数超出预期"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_memory_management(self, game):
        """GPU memory use stays bounded and is released on teardown."""
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated()

        gpu_mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=8192,
            device="cuda"
        )

        action, stats = gpu_mcts.search(game, 3000)

        # Peak working set for a 3x3 search should stay under 200 MB.
        peak_memory = torch.cuda.max_memory_allocated()
        memory_used = (peak_memory - initial_memory) / 1e6  # MB
        assert memory_used < 200, f"GPU内存使用过多: {memory_used:.1f} MB"

        # After deletion the allocator should return (close) to baseline.
        del gpu_mcts
        torch.cuda.empty_cache()
        final_memory = torch.cuda.memory_allocated()
        assert final_memory <= initial_memory * 1.1, "GPU内存未正确释放"

    def test_device_switching(self, game):
        """set_device() must move a live instance between CPU and CUDA."""
        mcts = TorchMCTS(device="cpu", batch_size=1024)

        # Starts on CPU and searches there.
        assert mcts.device.type == "cpu"
        action1, stats1 = mcts.search(game.copy(), 1000)
        assert stats1['device'] == 'cpu'

        # Round-trip to GPU and back, when a GPU exists.
        if torch.cuda.is_available():
            mcts.set_device("cuda")
            assert mcts.device.type == "cuda"
            action2, stats2 = mcts.search(game.copy(), 1000)
            assert stats2['device'] == 'cuda'

            mcts.set_device("cpu")
            assert mcts.device.type == "cpu"

            torch.cuda.empty_cache()

    def test_consistency_across_devices(self, game):
        """CPU and GPU runs both work and spend the same simulation budget."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        # Seed NumPy identically before each run. Rollout randomness may
        # still diverge between devices, so only coarse agreement is checked.
        np.random.seed(42)
        cpu_mcts = TorchMCTS(device="cpu", batch_size=2048)
        cpu_action, cpu_stats = cpu_mcts.search(game.copy(), 3000)

        np.random.seed(42)
        gpu_mcts = TorchMCTS(device="cuda", batch_size=2048)
        gpu_action, gpu_stats = gpu_mcts.search(game.copy(), 3000)

        # Both devices must at least return legal moves.
        assert cpu_action in game.get_valid_moves()
        assert gpu_action in game.get_valid_moves()

        # Both must have spent exactly the requested simulation budget.
        cpu_total = sum(cpu_stats['action_visits'].values())
        gpu_total = sum(gpu_stats['action_visits'].values())
        assert cpu_total == gpu_total == 3000

        del gpu_mcts
        torch.cuda.empty_cache()


# Module-level (not a method): takes no `self` and builds its own game
# rather than using the class fixtures.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_batch_size_optimization():
    """optimize_batch_size() returns a usable, adopted batch size."""
    game = Game2048(height=3, width=3, seed=42)
    mcts = TorchMCTS(device="cuda", batch_size=4096)

    optimal_size = mcts.optimize_batch_size(game, test_simulations=1000)

    # The optimizer must produce a positive size and install it.
    assert optimal_size > 0
    assert mcts.batch_size == optimal_size

    # Searching with the tuned size must still work end to end.
    action, stats = mcts.search(game, 2000)
    assert action in game.get_valid_moves()
    assert stats['sims_per_second'] > 0

    del mcts
    torch.cuda.empty_cache()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])