增加L0训练阶段的MCTS部分
This commit is contained in:
295
tests/test_torch_mcts.py
Normal file
295
tests/test_torch_mcts.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
PyTorch MCTS测试
|
||||
|
||||
测试统一的PyTorch MCTS实现
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from game import Game2048
|
||||
from torch_mcts import TorchMCTS
|
||||
from training_data import TrainingDataManager
|
||||
|
||||
|
||||
class TestTorchMCTS:
    """Tests for the unified PyTorch MCTS implementation.

    Covers basic search correctness on CPU and GPU, visit-distribution
    quality, device selection/switching, batch-size defaults, throughput
    floors, training-data collection, and GPU memory hygiene.
    """

    @pytest.fixture
    def game(self):
        """A small deterministic 3x3 game state (fixed seed) for testing."""
        return Game2048(height=3, width=3, seed=42)

    @pytest.fixture
    def cpu_mcts(self):
        """A CPU-backed MCTS instance with a modest batch size."""
        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu"
        )

    @pytest.fixture
    def gpu_mcts(self):
        """A GPU-backed MCTS instance; skips the test when CUDA is absent."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        return TorchMCTS(
            c_param=1.414,
            max_simulation_depth=30,
            batch_size=4096,
            device="cuda"
        )

    def test_cpu_mcts_basic_functionality(self, game, cpu_mcts):
        """CPU search returns a valid action and a complete stats dict."""
        # Run the search.
        action, stats = cpu_mcts.search(game, 1000)

        # Validate the result.
        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cpu'

        # Every simulation must be accounted for in the visit counts.
        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 1000, f"访问次数不匹配: {total_visits}"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_gpu_mcts_basic_functionality(self, game, gpu_mcts):
        """GPU search returns a valid action and a complete stats dict."""
        # Run the search.
        action, stats = gpu_mcts.search(game, 2000)

        # Validate the result.
        assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
        assert 'action_visits' in stats
        assert 'action_avg_values' in stats
        assert 'sims_per_second' in stats
        assert stats['device'] == 'cuda'

        # Every simulation must be accounted for in the visit counts.
        total_visits = sum(stats['action_visits'].values())
        assert total_visits == 2000, f"访问次数不匹配: {total_visits}"

    def test_action_distribution_quality(self, game, cpu_mcts):
        """The visit distribution should be non-uniform and favor the chosen action."""
        action, stats = cpu_mcts.search(game, 5000)

        action_visits = stats['action_visits']
        visit_values = list(action_visits.values())

        # MCTS should concentrate visits; a perfectly uniform split is suspicious.
        assert len(set(visit_values)) > 1, "动作分布完全均匀,不符合MCTS预期"

        # The returned action must be the most-visited one.
        best_action_visits = action_visits[action]
        assert best_action_visits == max(visit_values), "最佳动作访问次数不是最多"

        # Sanity-check the average value estimates per action.
        action_values = stats['action_avg_values']
        for act, value in action_values.items():
            assert value > 0, f"动作{act}的价值应该为正: {value}"
            assert value < 100000, f"动作{act}的价值过大: {value}"

    def test_device_auto_selection(self, game):
        """device="auto" should pick CUDA when available, else CPU."""
        mcts = TorchMCTS(device="auto", batch_size=1024)

        # Verify the resolved device matches availability.
        if torch.cuda.is_available():
            assert mcts.device.type == "cuda"
        else:
            assert mcts.device.type == "cpu"

        # A search on the auto-selected device must still work.
        action, stats = mcts.search(game, 1000)
        assert action in game.get_valid_moves()

        # Release GPU memory if we landed on CUDA.
        if mcts.device.type == "cuda":
            del mcts
            torch.cuda.empty_cache()

    def test_batch_size_auto_selection(self, game):
        """batch_size=None should fall back to per-device defaults."""
        # CPU default.
        cpu_mcts = TorchMCTS(device="cpu", batch_size=None)
        assert cpu_mcts.batch_size == 4096  # CPU default batch size

        # GPU default (only when CUDA is available).
        if torch.cuda.is_available():
            gpu_mcts = TorchMCTS(device="cuda", batch_size=None)
            assert gpu_mcts.batch_size == 32768  # GPU default batch size
            del gpu_mcts
            torch.cuda.empty_cache()

    def test_performance_cpu(self, game, cpu_mcts):
        """CPU search should exceed a minimal throughput floor."""
        simulations = 2000

        start_time = time.time()
        action, stats = cpu_mcts.search(game, simulations)
        elapsed_time = time.time() - start_time

        speed = simulations / elapsed_time

        # Minimal acceptable CPU throughput.
        assert speed > 100, f"CPU性能过低: {speed:.1f} 模拟/秒"

        # The reported rate should agree with our measurement within 20%.
        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_performance_gpu(self, game, gpu_mcts):
        """GPU search should exceed a minimal throughput floor."""
        simulations = 5000

        # Synchronize around the timed region so CUDA async launches
        # do not skew the wall-clock measurement.
        torch.cuda.synchronize()
        start_time = time.time()
        action, stats = gpu_mcts.search(game, simulations)
        torch.cuda.synchronize()
        elapsed_time = time.time() - start_time

        speed = simulations / elapsed_time

        # Minimal acceptable GPU throughput.
        assert speed > 50, f"GPU性能过低: {speed:.1f} 模拟/秒"

        # The reported rate should agree with our measurement within 20%.
        assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"

    def test_training_data_collection(self, game):
        """Search with a training manager attached should collect samples."""
        # Build the training-data manager backing the collection.
        training_manager = TrainingDataManager(
            data_dir="data/test_torch_training",
            cache_size=5000,
            board_size=(3, 3)
        )

        mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=1024,
            device="cpu",
            training_manager=training_manager
        )

        # Run the search.
        action, stats = mcts.search(game, 2000)

        # Some training data must have been cached.
        cache_stats = training_manager.get_cache_stats()
        assert cache_stats['cache_size'] > 0, "未收集到训练数据"

        # Sample count must not exceed the number of simulations run.
        assert cache_stats['cache_size'] <= 2000, "收集的样本数超出预期"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_memory_management(self, game):
        """GPU memory usage should stay bounded and be released on cleanup."""
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated()

        gpu_mcts = TorchMCTS(
            max_simulation_depth=30,
            batch_size=8192,
            device="cuda"
        )

        # Run the search.
        action, stats = gpu_mcts.search(game, 3000)

        # Peak usage during the search, relative to the starting point.
        peak_memory = torch.cuda.max_memory_allocated()
        memory_used = (peak_memory - initial_memory) / 1e6  # MB

        assert memory_used < 200, f"GPU内存使用过多: {memory_used:.1f} MB"

        # Drop the instance and verify memory returns (10% slack allowed
        # for allocator bookkeeping).
        del gpu_mcts
        torch.cuda.empty_cache()

        final_memory = torch.cuda.memory_allocated()
        assert final_memory <= initial_memory * 1.1, "GPU内存未正确释放"

    def test_device_switching(self, game):
        """set_device should move the instance between CPU and GPU at runtime."""
        mcts = TorchMCTS(device="cpu", batch_size=1024)

        # Starts on CPU.
        assert mcts.device.type == "cpu"
        action1, stats1 = mcts.search(game.copy(), 1000)
        assert stats1['device'] == 'cpu'

        # Switch to GPU (when available) and back.
        if torch.cuda.is_available():
            mcts.set_device("cuda")
            assert mcts.device.type == "cuda"

            action2, stats2 = mcts.search(game.copy(), 1000)
            assert stats2['device'] == 'cuda'

            # Switch back to CPU.
            mcts.set_device("cpu")
            assert mcts.device.type == "cpu"

            torch.cuda.empty_cache()

    def test_consistency_across_devices(self, game):
        """CPU and GPU searches should both work and account for all simulations."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        # Use the same seed for both runs so any divergence is device-driven.
        np.random.seed(42)
        cpu_mcts = TorchMCTS(device="cpu", batch_size=2048)
        cpu_action, cpu_stats = cpu_mcts.search(game.copy(), 3000)

        np.random.seed(42)
        gpu_mcts = TorchMCTS(device="cuda", batch_size=2048)
        gpu_action, gpu_stats = gpu_mcts.search(game.copy(), 3000)

        # Actions may differ due to stochasticity; we only require that
        # both devices produce valid moves.
        assert cpu_action in game.get_valid_moves()
        assert gpu_action in game.get_valid_moves()

        # Both runs must account for every simulation.
        cpu_total = sum(cpu_stats['action_visits'].values())
        gpu_total = sum(gpu_stats['action_visits'].values())
        assert cpu_total == gpu_total == 3000

        del gpu_mcts
        torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_batch_size_optimization():
    """optimize_batch_size should pick a positive size and keep search working."""
    game = Game2048(height=3, width=3, seed=42)

    mcts = TorchMCTS(device="cuda", batch_size=4096)

    # Run the batch-size optimization routine.
    optimal_size = mcts.optimize_batch_size(game, test_simulations=1000)

    # The optimizer must return a usable size and apply it to the instance.
    assert optimal_size > 0
    assert mcts.batch_size == optimal_size

    # Search must still function with the optimized batch size.
    action, stats = mcts.search(game, 2000)
    assert action in game.get_valid_moves()
    assert stats['sims_per_second'] > 0

    # Release GPU memory.
    del mcts
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly: python test_torch_mcts.py
    pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user