Add the MCTS component for the L0 training stage.

This commit adds a new file, tests/test_performance_benchmark.py (210 lines):
|
||||
"""
|
||||
性能基准测试
|
||||
|
||||
测试不同MCTS实现的性能对比
|
||||
"""
|
||||
|
||||
import time
|
||||
import torch
|
||||
import pytest
|
||||
from game import Game2048
|
||||
from torch_mcts import TorchMCTS
|
||||
|
||||
|
||||
class TestPerformanceBenchmark:
    """Performance benchmark suite for the MCTS implementations.

    Compares simulation throughput across CPU, GPU, and automatic device
    selection, and measures the effect of the GPU batch size.

    All timings use ``time.perf_counter()`` — a monotonic, high-resolution
    clock — instead of ``time.time()``, whose wall-clock readings can jump
    (e.g. NTP adjustments) and distort short benchmark intervals.
    """

    @pytest.fixture
    def game(self):
        """Small, seeded 3x3 board so every benchmark run is reproducible."""
        return Game2048(height=3, width=3, seed=42)

    def test_cpu_mcts_performance(self, game):
        """CPU MCTS must sustain a minimum simulations/second rate."""
        mcts = TorchMCTS(
            c_param=1.414,
            max_simulation_depth=50,
            device="cpu"
        )

        simulations = 2000
        start_time = time.perf_counter()
        action, stats = mcts.search(game, simulations)
        elapsed_time = time.perf_counter() - start_time

        speed = simulations / elapsed_time

        # CPU MCTS should meet the baseline throughput requirement.
        assert speed > 500, f"CPU MCTS性能过低: {speed:.1f} 模拟/秒"
        assert action in game.get_valid_moves()

    def test_auto_device_mcts_performance(self, game):
        """MCTS with automatic device selection keeps reasonable throughput."""
        mcts = TorchMCTS(
            c_param=1.414,
            max_simulation_depth=50,
            device="auto"
        )

        simulations = 2000
        start_time = time.perf_counter()
        action, stats = mcts.search(game, simulations)
        elapsed_time = time.perf_counter() - start_time

        speed = simulations / elapsed_time

        # Lower floor than the CPU test: "auto" may resolve to a device
        # with extra transfer overhead.
        assert speed > 100, f"自动设备MCTS性能过低: {speed:.1f} 模拟/秒"
        assert action in game.get_valid_moves()

        # Release GPU memory if "auto" resolved to CUDA.
        if mcts.device.type == "cuda":
            del mcts
            torch.cuda.empty_cache()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_gpu_mcts_performance(self, game):
        """GPU MCTS must sustain a minimum simulations/second rate."""
        gpu_mcts = TorchMCTS(
            max_simulation_depth=50,
            batch_size=8192,
            device="cuda"
        )

        simulations = 5000

        # Synchronize around the timed region: CUDA kernels launch
        # asynchronously, so without this the timer would measure launch
        # overhead rather than actual execution time.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        action, stats = gpu_mcts.search(game, simulations)
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time

        speed = simulations / elapsed_time

        # The GPU path should show a clear throughput benefit.
        assert speed > 200, f"GPU MCTS性能过低: {speed:.1f} 模拟/秒"
        assert action in game.get_valid_moves()

        del gpu_mcts
        torch.cuda.empty_cache()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
    def test_performance_comparison(self, game):
        """Head-to-head CPU vs GPU throughput on identical workloads."""
        simulations = 3000
        results = {}

        # CPU MCTS (searched on a copy so both runs see the same state).
        cpu_mcts = TorchMCTS(c_param=1.414, max_simulation_depth=50, device="cpu")
        start_time = time.perf_counter()
        cpu_action, cpu_stats = cpu_mcts.search(game.copy(), simulations)
        cpu_time = time.perf_counter() - start_time
        results['CPU'] = simulations / cpu_time

        # GPU MCTS, synchronized around the timed region (see
        # test_gpu_mcts_performance for why).
        gpu_mcts = TorchMCTS(max_simulation_depth=50, batch_size=8192, device="cuda")
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        gpu_action, gpu_stats = gpu_mcts.search(game.copy(), simulations)
        torch.cuda.synchronize()
        gpu_time = time.perf_counter() - start_time
        results['GPU'] = simulations / gpu_time

        # Report the speedup.
        speedup = results['GPU'] / results['CPU']
        print(f"\n性能对比:")
        print(f"  CPU: {results['CPU']:.1f} 模拟/秒")
        print(f"  GPU: {results['GPU']:.1f} 模拟/秒")
        print(f"  加速比: {speedup:.1f}x")

        # The GPU must at least not be catastrophically slower than the CPU.
        assert speedup > 0.1, f"GPU性能严重低于CPU: {speedup:.2f}x"

        # Cleanup.
        del cpu_mcts, gpu_mcts
        torch.cuda.empty_cache()

    def test_batch_size_scaling(self):
        """Measure how the GPU batch size affects search throughput."""
        if not torch.cuda.is_available():
            pytest.skip("CUDA不可用")

        game = Game2048(height=3, width=3, seed=42)
        batch_sizes = [1024, 4096, 16384]
        simulations = 2000

        results = {}

        for batch_size in batch_sizes:
            gpu_mcts = TorchMCTS(
                max_simulation_depth=50,
                batch_size=batch_size,
                device="cuda"
            )

            torch.cuda.synchronize()
            start_time = time.perf_counter()
            action, stats = gpu_mcts.search(game.copy(), simulations)
            torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - start_time

            speed = simulations / elapsed_time
            results[batch_size] = speed

            # Release GPU memory before trying the next batch size.
            del gpu_mcts
            torch.cuda.empty_cache()

        # The spread between the best and worst batch size should be bounded.
        speeds = list(results.values())
        max_speed = max(speeds)
        min_speed = min(speeds)

        speed_ratio = max_speed / min_speed
        assert speed_ratio < 10, f"批次大小性能差异过大: {speed_ratio:.2f}"

        print(f"\n批次大小性能测试:")
        for batch_size, speed in results.items():
            print(f"  {batch_size:,}: {speed:.1f} 模拟/秒")
def test_memory_efficiency():
    """GPU memory cost of a large-batch MCTS search.

    Verifies that a 10,000-simulation search with a 32,768 batch stays
    within a 500 MB allocation budget, and reports throughput per MB.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA不可用")

    game = Game2048(height=3, width=3, seed=42)

    torch.cuda.empty_cache()
    # Reset the peak-allocation counter so max_memory_allocated() below
    # reflects only this test. Without the reset, a larger peak left over
    # from an earlier test in the same process would inflate the reading.
    torch.cuda.reset_peak_memory_stats()
    initial_memory = torch.cuda.memory_allocated()

    gpu_mcts = TorchMCTS(
        max_simulation_depth=50,
        batch_size=32768,
        device="cuda"
    )

    # Run the search whose memory footprint is being measured.
    action, stats = gpu_mcts.search(game, 10000)

    peak_memory = torch.cuda.max_memory_allocated()
    memory_used = (peak_memory - initial_memory) / 1e6  # MB

    # Memory usage must stay within budget.
    assert memory_used < 500, f"GPU内存使用过多: {memory_used:.1f} MB"

    # Memory efficiency: simulations per second per MB of GPU memory.
    # Guard against division by zero if the search allocated nothing extra.
    speed = stats['sims_per_second']
    memory_efficiency = speed / memory_used if memory_used > 0 else 0

    print(f"\n内存效率测试:")
    print(f"  内存使用: {memory_used:.1f} MB")
    print(f"  模拟速度: {speed:.1f} 模拟/秒")
    print(f"  内存效率: {memory_efficiency:.1f} 模拟/秒/MB")

    # Cleanup.
    del gpu_mcts
    torch.cuda.empty_cache()
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user