"""PyTorch MCTS tests.

Tests for the unified PyTorch MCTS implementation.
"""
import pytest
|
||
import torch
|
||
import time
|
||
import numpy as np
|
||
from game import Game2048
|
||
from torch_mcts import TorchMCTS
|
||
from training_data import TrainingDataManager
|
||
|
||
|
||
class TestTorchMCTS:
    """Tests for the unified PyTorch MCTS implementation (CPU and CUDA)."""

    @pytest.fixture
    def game(self):
        """A small, seeded 3x3 game state so every test is reproducible."""
        return Game2048(height=3, width=3, seed=42)

    @pytest.fixture
    def cpu_mcts(self):
        """A CPU-backed MCTS instance with a modest batch size."""
        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu"
        )

    @pytest.fixture
    def gpu_mcts(self):
        """A CUDA-backed MCTS instance; skips the test when CUDA is absent."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=4096,
            device="cuda"
        )

    def test_cpu_mcts_basic_functionality(self, game, cpu_mcts):
        """CPU search returns a valid action and complete statistics."""
        action, stats = cpu_mcts.search(game, 1000)

        # The chosen action must be legal and the stats dict fully populated.
        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cpu'

        # Every requested simulation must be accounted for in the visit counts.
        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 1000, f"访问次数不匹配: {total_visits}"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_gpu_mcts_basic_functionality(self, game, gpu_mcts):
        """GPU search returns a valid action and complete statistics."""
        action, stats = gpu_mcts.search(game, 2000)

        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cuda'

        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 2000, f"访问次数不匹配: {total_visits}"

    def test_action_distribution_quality(self, game, cpu_mcts):
        """The visit distribution should show MCTS's characteristic bias."""
        action, stats = cpu_mcts.search(game, 5000)

        action_visits = stats['action_visits']
        visit_values = list(action_visits.values())

        # A perfectly uniform distribution would mean no exploitation at all.
        assert len(set(visit_values)) > 1, "动作分布完全均匀,不符合MCTS预期"

        # The returned action should be the most-visited one.
        best_action_visits = action_visits[action]
        assert best_action_visits == max(visit_values), "最佳动作访问次数不是最多"

        # Average values should be positive and within a sane magnitude.
        action_values = stats['action_avg_values']
        for act, value in action_values.items():
            assert value > 0, f"动作{act}的价值应该为正: {value}"
            assert value < 100000, f"动作{act}的价值过大: {value}"

    def test_device_auto_selection(self, game):
        """device="auto" must pick CUDA when available, CPU otherwise."""
        mcts = TorchMCTS(device="auto", batch_size=1024)

        if torch.cuda.is_available():
            assert mcts.device.type == "cuda"
        else:
            assert mcts.device.type == "cpu"

        # The auto-selected instance must also search successfully.
        action, stats = mcts.search(game, 1000)
        assert action in game.get_valid_moves()

        if mcts.device.type == "cuda":
            del mcts
            torch.cuda.empty_cache()

    def test_batch_size_auto_selection(self, game):
        """batch_size=None must fall back to the per-device default."""
        # CPU default.
        cpu_mcts = TorchMCTS(device="cpu", batch_size=None)
        assert cpu_mcts.batch_size == 4096  # CPU default batch size

        # GPU default (only when CUDA is available).
        if torch.cuda.is_available():
            gpu_mcts = TorchMCTS(device="cuda", batch_size=None)
            assert gpu_mcts.batch_size == 32768  # GPU default batch size
            del gpu_mcts
            torch.cuda.empty_cache()

    def test_performance_cpu(self, game, cpu_mcts):
        """CPU search meets a minimum throughput and reports it accurately."""
        simulations = 2000

        # perf_counter is monotonic; time.time can jump (NTP, DST) and
        # would make this timing flaky.
        start_time = time.perf_counter()
        action, stats = cpu_mcts.search(game, simulations)
        elapsed_time = time.perf_counter() - start_time

        speed = simulations / elapsed_time

        # Loose lower bound: this only guards against pathological slowdowns.
        assert speed > 100, f"CPU性能过低: {speed:.1f} 模拟/秒"

        # The self-reported rate should agree with our measurement within 20%.
        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_performance_gpu(self, game, gpu_mcts):
        """GPU search meets a minimum throughput and reports it accurately."""
        simulations = 5000

        # Synchronize around the timed region so async CUDA work is included.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        action, stats = gpu_mcts.search(game, simulations)
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time

        speed = simulations / elapsed_time

        assert speed > 50, f"GPU性能过低: {speed:.1f} 模拟/秒"

        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    def test_training_data_collection(self, game):
        """Searching with a training manager attached must collect samples."""
        training_manager = TrainingDataManager(
            data_dir="data/test_torch_training",
            cache_size=5000,
            board_size=(3, 3)
        )

        mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu",
            training_manager=training_manager
        )

        action, stats = mcts.search(game, 2000)

        # Some samples must have been cached...
        cache_stats = training_manager.get_cache_stats()
        assert cache_stats['cache_size'] > 0, "未收集到训练数据"

        # ...but never more than one per simulation.
        assert cache_stats['cache_size'] <= 2000, "收集的样本数超出预期"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_memory_management(self, game):
        """GPU memory stays bounded during search and is released after."""
        torch.cuda.empty_cache()
        # Scope the peak measurement to this test: without resetting the peak
        # stats, max_memory_allocated() reports the process-wide high-water
        # mark and earlier tests' allocations would pollute the bound below.
        torch.cuda.reset_peak_memory_stats()
        initial_memory = torch.cuda.memory_allocated()

        gpu_mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=8192,
            device="cuda"
        )

        action, stats = gpu_mcts.search(game, 3000)

        # Peak usage during the search must stay under the budget.
        peak_memory = torch.cuda.max_memory_allocated()
        memory_used = (peak_memory - initial_memory) / 1e6  # MB

        assert memory_used < 200, f"GPU内存使用过多: {memory_used:.1f} MB"

        # Dropping the instance should return memory to the allocator.
        del gpu_mcts
        torch.cuda.empty_cache()

        final_memory = torch.cuda.memory_allocated()
        # 10% slack tolerates allocator bookkeeping that is never returned.
        assert final_memory <= initial_memory * 1.1, "GPU内存未正确释放"

    def test_device_switching(self, game):
        """set_device() must move the instance between CPU and CUDA."""
        mcts = TorchMCTS(device="cpu", batch_size=1024)

        # Starts on CPU.
        assert mcts.device.type == "cpu"
        action1, stats1 = mcts.search(game.copy(), 1000)
        assert stats1['device'] == 'cpu'

        # Round-trip through CUDA when available.
        if torch.cuda.is_available():
            mcts.set_device("cuda")
            assert mcts.device.type == "cuda"

            action2, stats2 = mcts.search(game.copy(), 1000)
            assert stats2['device'] == 'cuda'

            mcts.set_device("cpu")
            assert mcts.device.type == "cpu"

            torch.cuda.empty_cache()

    def test_consistency_across_devices(self, game):
        """CPU and GPU searches must both work and count visits identically."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        # Same seed for both runs; exact action equality is not required
        # because simulation randomness differs across devices.
        np.random.seed(42)
        cpu_mcts = TorchMCTS(device="cpu", batch_size=2048)
        cpu_action, cpu_stats = cpu_mcts.search(game.copy(), 3000)

        np.random.seed(42)
        gpu_mcts = TorchMCTS(device="cuda", batch_size=2048)
        gpu_action, gpu_stats = gpu_mcts.search(game.copy(), 3000)

        # Both devices must at least return a legal move.
        assert cpu_action in game.get_valid_moves()
        assert gpu_action in game.get_valid_moves()

        # Both must perform exactly the requested number of simulations.
        cpu_total = sum(cpu_stats['action_visits'].values())
        gpu_total = sum(gpu_stats['action_visits'].values())
        assert cpu_total == gpu_total == 3000

        del gpu_mcts
        torch.cuda.empty_cache()
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_batch_size_optimization():
    """Exercise automatic batch-size tuning on the GPU."""
    board = Game2048(height=3, width=3, seed=42)
    searcher = TorchMCTS(device="cuda", batch_size=4096)

    # Tune the batch size with a short probe run and check it was adopted.
    tuned = searcher.optimize_batch_size(board, test_simulations=1000)
    assert tuned > 0
    assert searcher.batch_size == tuned

    # The tuned searcher must still return a legal move and report throughput.
    move, info = searcher.search(board, 2000)
    assert move in board.get_valid_moves()
    assert info['sims_per_second'] > 0

    del searcher
    torch.cuda.empty_cache()
|
||
|
||
if __name__ == "__main__":
    # Run this module's tests verbosely when executed as a script.
    pytest.main([__file__, "-v"])