Files
deep2048/tests/test_torch_mcts.py
2025-07-23 07:04:10 +08:00

296 lines
9.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
PyTorch MCTS测试
测试统一的PyTorch MCTS实现
"""
import pytest
import torch
import time
import numpy as np
from game import Game2048
from torch_mcts import TorchMCTS
from training_data import TrainingDataManager
class TestTorchMCTS:
    """Tests for the PyTorch MCTS implementation (``TorchMCTS``).

    Covers search results on CPU and CUDA, device and batch-size
    selection, throughput sanity checks, training-data collection and
    GPU memory usage. CUDA-only tests are skipped when
    ``torch.cuda.is_available()`` is false.
    """

    @pytest.fixture
    def game(self):
        """Return a 3x3 ``Game2048`` created with seed 42."""
        return Game2048(height=3, width=3, seed=42)

    @pytest.fixture
    def cpu_mcts(self):
        """Return a ``TorchMCTS`` instance configured for the CPU."""
        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu"
        )

    @pytest.fixture
    def gpu_mcts(self):
        """Return a ``TorchMCTS`` instance on CUDA, skipping if CUDA is absent."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")
        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=4096,
            device="cuda"
        )

    def test_cpu_mcts_basic_functionality(self, game, cpu_mcts):
        """A CPU search returns a valid action and complete statistics."""
        # Run the search.
        action, stats = cpu_mcts.search(game, 1000)

        # Validate the result.
        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cpu'

        # Per-action visit counts must add up to the requested simulations.
        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 1000, f"访问次数不匹配: {total_visits}"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_gpu_mcts_basic_functionality(self, game, gpu_mcts):
        """A CUDA search returns a valid action and complete statistics."""
        # Run the search.
        action, stats = gpu_mcts.search(game, 2000)

        # Validate the result.
        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cuda'

        # Per-action visit counts must add up to the requested simulations.
        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 2000, f"访问次数不匹配: {total_visits}"

    def test_action_distribution_quality(self, game, cpu_mcts):
        """Visit counts are non-uniform and the chosen action is most visited.

        Also checks that every per-action average value lies strictly
        between 0 and 100000.
        """
        action, stats = cpu_mcts.search(game, 5000)
        action_visits = stats['action_visits']
        visit_values = list(action_visits.values())

        # The distribution should not be perfectly uniform: MCTS is
        # expected to favour some actions over others.
        assert len(set(visit_values)) > 1, "动作分布完全均匀不符合MCTS预期"

        # The returned action should be the one with the most visits.
        best_action_visits = action_visits[action]
        assert best_action_visits == max(visit_values), "最佳动作访问次数不是最多"

        # Sanity-check the value estimates.
        action_values = stats['action_avg_values']
        for act, value in action_values.items():
            assert value > 0, f"动作{act}的价值应该为正: {value}"
            assert value < 100000, f"动作{act}的价值过大: {value}"

    def test_device_auto_selection(self, game):
        """``device="auto"`` resolves to CUDA when available, otherwise CPU."""
        mcts = TorchMCTS(device="auto", batch_size=1024)

        # Verify the selected device.
        if torch.cuda.is_available():
            assert mcts.device.type == "cuda"
        else:
            assert mcts.device.type == "cpu"

        # Run a search to confirm the instance actually works.
        action, stats = mcts.search(game, 1000)
        assert action in game.get_valid_moves()

        if mcts.device.type == "cuda":
            del mcts
            torch.cuda.empty_cache()

    def test_batch_size_auto_selection(self, game):
        """``batch_size=None`` resolves to the per-device default."""
        # Automatic selection on CPU.
        cpu_mcts = TorchMCTS(device="cpu", batch_size=None)
        assert cpu_mcts.batch_size == 4096  # expected CPU default batch size

        # Automatic selection on GPU, if one is available.
        if torch.cuda.is_available():
            gpu_mcts = TorchMCTS(device="cuda", batch_size=None)
            assert gpu_mcts.batch_size == 32768  # expected GPU default batch size
            del gpu_mcts
            torch.cuda.empty_cache()

    def test_performance_cpu(self, game, cpu_mcts):
        """CPU search exceeds 100 simulations/second and reports it accurately."""
        simulations = 2000

        # perf_counter is the clock intended for measuring short durations;
        # time.time() is adjustable and may be too coarse, which could make
        # the elapsed interval zero and break the division below.
        start_time = time.perf_counter()
        action, stats = cpu_mcts.search(game, simulations)
        elapsed_time = time.perf_counter() - start_time
        speed = simulations / elapsed_time

        # The CPU should meet a baseline throughput.
        assert speed > 100, f"CPU性能过低: {speed:.1f} 模拟/秒"

        # The reported rate must be within 20% of the externally measured one.
        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_performance_gpu(self, game, gpu_mcts):
        """CUDA search exceeds 50 simulations/second and reports it accurately."""
        simulations = 5000

        # Synchronize before and after so that pending GPU work is not
        # attributed to the wrong side of the timed interval.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        action, stats = gpu_mcts.search(game, simulations)
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time
        speed = simulations / elapsed_time

        # The GPU should reach a reasonable throughput.
        assert speed > 50, f"GPU性能过低: {speed:.1f} 模拟/秒"

        # The reported rate must be within 20% of the externally measured one.
        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    def test_training_data_collection(self, game):
        """A search with a training manager attached fills the manager's cache."""
        # Create the training-data manager.
        # NOTE(review): data_dir is a relative path, so where it resolves
        # depends on the working directory pytest is started from.
        training_manager = TrainingDataManager(
            data_dir="data/test_torch_training",
            cache_size=5000,
            board_size=(3, 3)
        )
        mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu",
            training_manager=training_manager
        )

        # Run the search.
        action, stats = mcts.search(game, 2000)

        # Some training samples must have been collected.
        cache_stats = training_manager.get_cache_stats()
        assert cache_stats['cache_size'] > 0, "未收集到训练数据"

        # No more samples than simulations are expected.
        assert cache_stats['cache_size'] <= 2000, "收集的样本数超出预期"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_memory_management(self, game):
        """A CUDA search stays under 200 MB and its memory is released afterwards."""
        torch.cuda.empty_cache()
        # max_memory_allocated() reports the peak since the last reset (or
        # process start). Reset it so allocations made by earlier tests in
        # the same process do not count against this search.
        torch.cuda.reset_peak_memory_stats()
        initial_memory = torch.cuda.memory_allocated()

        gpu_mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=8192,
            device="cuda"
        )

        # Run the search.
        action, stats = gpu_mcts.search(game, 3000)

        # Check peak memory used on top of the starting allocation.
        peak_memory = torch.cuda.max_memory_allocated()
        memory_used = (peak_memory - initial_memory) / 1e6  # MB
        assert memory_used < 200, f"GPU内存使用过多: {memory_used:.1f} MB"

        # Clean up and verify that the memory was released.
        del gpu_mcts
        torch.cuda.empty_cache()
        final_memory = torch.cuda.memory_allocated()
        assert final_memory <= initial_memory * 1.1, "GPU内存未正确释放"

    def test_device_switching(self, game):
        """``set_device`` moves an existing instance between CPU and CUDA."""
        mcts = TorchMCTS(device="cpu", batch_size=1024)

        # Starts on the CPU.
        assert mcts.device.type == "cpu"
        action1, stats1 = mcts.search(game.copy(), 1000)
        assert stats1['device'] == 'cpu'

        # Switch to the GPU, if one is available.
        if torch.cuda.is_available():
            mcts.set_device("cuda")
            assert mcts.device.type == "cuda"
            action2, stats2 = mcts.search(game.copy(), 1000)
            assert stats2['device'] == 'cuda'

            # Switch back to the CPU.
            mcts.set_device("cpu")
            assert mcts.device.type == "cpu"
            torch.cuda.empty_cache()

    def test_consistency_across_devices(self, game):
        """CPU and CUDA searches both yield valid actions and full visit totals."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        # Use the same random seed for both runs. Which generator TorchMCTS
        # draws from is not visible in this file, so seed NumPy and torch
        # alike; seeding only NumPy would leave torch's generator unseeded.
        np.random.seed(42)
        torch.manual_seed(42)
        cpu_mcts = TorchMCTS(device="cpu", batch_size=2048)
        cpu_action, cpu_stats = cpu_mcts.search(game.copy(), 3000)

        np.random.seed(42)
        torch.manual_seed(42)
        gpu_mcts = TorchMCTS(device="cuda", batch_size=2048)
        gpu_action, gpu_stats = gpu_mcts.search(game.copy(), 3000)

        # Because of randomness the two actions may differ; this test mainly
        # verifies that both devices produce a usable result.
        assert cpu_action in game.get_valid_moves()
        assert gpu_action in game.get_valid_moves()

        # Both runs must account for every simulation.
        cpu_total = sum(cpu_stats['action_visits'].values())
        gpu_total = sum(gpu_stats['action_visits'].values())
        assert cpu_total == gpu_total == 3000

        del gpu_mcts
        torch.cuda.empty_cache()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_batch_size_optimization():
    """Batch-size tuning on CUDA yields a positive size that the instance adopts.

    After tuning, a follow-up search must still return a valid move and
    report a positive simulation rate.
    """
    board = Game2048(height=3, width=3, seed=42)
    searcher = TorchMCTS(device="cuda", batch_size=4096)

    # Let the searcher pick its own batch size.
    tuned_size = searcher.optimize_batch_size(board, test_simulations=1000)

    # The tuned size must be usable and stored on the instance.
    assert tuned_size > 0
    assert searcher.batch_size == tuned_size

    # A search with the tuned size should still behave normally.
    chosen_move, report = searcher.search(board, 2000)
    assert chosen_move in board.get_valid_moves()
    assert report['sims_per_second'] > 0

    # Drop the searcher and release cached GPU memory.
    del searcher
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)