增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

5
tests/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""
测试模块
包含所有的测试文件和基准测试
"""

100
tests/run_all_tests.py Normal file
View File

@@ -0,0 +1,100 @@
"""
统一测试运行器
运行所有测试并生成报告
"""
import pytest
import sys
import time
from pathlib import Path
def run_all_tests():
"""运行所有测试"""
print("Deep2048 项目测试套件")
print("=" * 50)
test_dir = Path(__file__).parent
# 测试文件列表
test_files = [
"test_training_data.py",
"test_game_engine.py",
"test_torch_mcts.py",
"test_board_compression.py",
"test_cache_system.py",
"test_persistence.py",
"test_performance_benchmark.py"
]
# 检查测试文件是否存在
existing_tests = []
for test_file in test_files:
test_path = test_dir / test_file
if test_path.exists():
existing_tests.append(str(test_path))
else:
print(f"警告: 测试文件不存在 {test_file}")
if not existing_tests:
print("错误: 没有找到测试文件")
return False
print(f"找到 {len(existing_tests)} 个测试文件")
# 运行测试
start_time = time.time()
# pytest参数
args = [
"-v", # 详细输出
"--tb=short", # 简短的错误回溯
"--durations=10", # 显示最慢的10个测试
] + existing_tests
result = pytest.main(args)
elapsed_time = time.time() - start_time
print(f"\n测试完成,用时: {elapsed_time:.2f}")
if result == 0:
print("✅ 所有测试通过!")
return True
else:
print("❌ 部分测试失败")
return False
def run_quick_tests():
"""运行快速测试(跳过性能测试)"""
print("快速测试模式")
print("=" * 30)
test_dir = Path(__file__).parent
args = [
"-v",
"-k", "not performance and not slow", # 跳过性能测试
str(test_dir)
]
result = pytest.main(args)
return result == 0
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="运行Deep2048测试套件")
parser.add_argument("--quick", action="store_true", help="快速测试模式")
args = parser.parse_args()
if args.quick:
success = run_quick_tests()
else:
success = run_all_tests()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,251 @@
"""
棋盘压缩算法测试
验证二面体群D4变换和规范化的正确性
"""
import numpy as np
import pytest
from training_data import BoardTransform
class TestBoardTransform:
"""棋盘变换测试类"""
def setup_method(self):
"""测试前的设置"""
# 创建一个非对称的测试棋盘,便于验证变换
self.test_board = np.array([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
])
# 创建一个简单的2x2棋盘用于手动验证
self.simple_board = np.array([
[1, 2],
[3, 4]
])
def test_log_transform(self):
"""测试对数变换"""
# 测试正常情况
board = np.array([
[2, 4, 8, 16],
[0, 2, 4, 8],
[0, 0, 2, 4],
[0, 0, 0, 2]
])
expected = np.array([
[1, 2, 3, 4],
[0, 1, 2, 3],
[0, 0, 1, 2],
[0, 0, 0, 1]
])
result = BoardTransform.log_transform(board)
np.testing.assert_array_equal(result, expected)
# 测试逆变换
restored = BoardTransform.inverse_log_transform(result)
np.testing.assert_array_equal(restored, board)
def test_rotate_90(self):
"""测试90度旋转"""
# 手动验证2x2矩阵的90度顺时针旋转
# [1, 2] -> [3, 1]
# [3, 4] [4, 2]
expected = np.array([
[3, 1],
[4, 2]
])
result = BoardTransform.rotate_90(self.simple_board)
np.testing.assert_array_equal(result, expected)
def test_flip_horizontal(self):
"""测试水平翻转"""
# 手动验证2x2矩阵的水平翻转
# [1, 2] -> [2, 1]
# [3, 4] [4, 3]
expected = np.array([
[2, 1],
[4, 3]
])
result = BoardTransform.flip_horizontal(self.simple_board)
np.testing.assert_array_equal(result, expected)
def test_all_transforms_count(self):
"""测试是否生成了正确数量的变换"""
transforms = BoardTransform.get_all_transforms(self.test_board)
assert len(transforms) == 8, "应该生成8种变换"
def test_all_transforms_uniqueness(self):
"""测试所有变换是否唯一(对于非对称矩阵)"""
transforms = BoardTransform.get_all_transforms(self.test_board)
# 将每个变换转换为字符串进行比较
transform_strings = [str(t.flatten()) for t in transforms]
unique_transforms = set(transform_strings)
assert len(unique_transforms) == 8, "对于非对称矩阵8种变换应该都不相同"
def test_transform_properties(self):
"""测试变换的数学性质"""
board = self.test_board
# 测试4次90度旋转应该回到原始状态
result = board.copy()
for _ in range(4):
result = BoardTransform.rotate_90(result)
np.testing.assert_array_equal(result, board)
# 测试两次水平翻转应该回到原始状态
flipped = BoardTransform.flip_horizontal(board)
double_flipped = BoardTransform.flip_horizontal(flipped)
np.testing.assert_array_equal(double_flipped, board)
def test_canonical_form_consistency(self):
"""测试规范形式的一致性"""
board = self.test_board
transforms = BoardTransform.get_all_transforms(board)
# 所有变换的规范形式应该相同
canonical_forms = []
transform_indices = []
for transform in transforms:
canonical, idx = BoardTransform.get_canonical_form(transform)
canonical_forms.append(canonical)
transform_indices.append(idx)
# 所有规范形式应该相同
first_canonical = canonical_forms[0]
for canonical in canonical_forms[1:]:
np.testing.assert_array_equal(canonical, first_canonical)
def test_hash_consistency(self):
"""测试哈希值的一致性"""
board = self.test_board
transforms = BoardTransform.get_all_transforms(board)
# 所有变换的哈希值应该相同
hashes = [BoardTransform.compute_hash(transform) for transform in transforms]
first_hash = hashes[0]
for hash_val in hashes[1:]:
assert hash_val == first_hash, "所有等价变换的哈希值应该相同"
def test_symmetric_board(self):
"""测试对称棋盘的情况"""
# 创建一个完全对称的棋盘
symmetric_board = np.array([
[1, 2, 2, 1],
[2, 3, 3, 2],
[2, 3, 3, 2],
[1, 2, 2, 1]
])
transforms = BoardTransform.get_all_transforms(symmetric_board)
# 对于这个特殊的对称棋盘,某些变换可能相同
# 但规范形式应该仍然一致
canonical, _ = BoardTransform.get_canonical_form(symmetric_board)
for transform in transforms:
transform_canonical, _ = BoardTransform.get_canonical_form(transform)
np.testing.assert_array_equal(transform_canonical, canonical)
def test_edge_cases(self):
"""测试边界情况"""
# 测试全零矩阵
zero_board = np.zeros((4, 4), dtype=int)
canonical_zero, _ = BoardTransform.get_canonical_form(zero_board)
np.testing.assert_array_equal(canonical_zero, zero_board)
# 测试单元素矩阵
single_element = np.array([[1]])
canonical_single, _ = BoardTransform.get_canonical_form(single_element)
np.testing.assert_array_equal(canonical_single, single_element)
# 测试1x4矩阵
row_matrix = np.array([[1, 2, 3, 4]])
transforms_row = BoardTransform.get_all_transforms(row_matrix)
assert len(transforms_row) == 8
def test_different_board_sizes(self):
"""测试不同大小的棋盘"""
# 测试3x3棋盘
board_3x3 = np.array([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
transforms_3x3 = BoardTransform.get_all_transforms(board_3x3)
assert len(transforms_3x3) == 8
# 验证规范形式一致性
hashes_3x3 = [BoardTransform.compute_hash(t) for t in transforms_3x3]
assert all(h == hashes_3x3[0] for h in hashes_3x3)
# 测试2x3矩形棋盘
board_2x3 = np.array([
[1, 2, 3],
[4, 5, 6]
])
transforms_2x3 = BoardTransform.get_all_transforms(board_2x3)
assert len(transforms_2x3) == 8
# 验证规范形式一致性
hashes_2x3 = [BoardTransform.compute_hash(t) for t in transforms_2x3]
assert all(h == hashes_2x3[0] for h in hashes_2x3)
def test_manual_verification():
"""手动验证一些关键变换"""
# 创建一个简单的测试用例进行手动验证
board = np.array([
[1, 2],
[3, 4]
])
transforms = BoardTransform.get_all_transforms(board)
# 预期的8种变换结果
expected_transforms = [
np.array([[1, 2], [3, 4]]), # R0: 原始
np.array([[3, 1], [4, 2]]), # R90: 旋转90°
np.array([[4, 3], [2, 1]]), # R180: 旋转180°
np.array([[2, 4], [1, 3]]), # R270: 旋转270°
np.array([[2, 1], [4, 3]]), # F: 水平翻转
np.array([[4, 2], [3, 1]]), # F+R90: 翻转后旋转90°
np.array([[3, 4], [1, 2]]), # F+R180: 翻转后旋转180°
np.array([[1, 3], [2, 4]]) # F+R270: 翻转后旋转270°
]
print("手动验证2x2矩阵的8种变换:")
print(f"原始矩阵:\n{board}")
for i, (actual, expected) in enumerate(zip(transforms, expected_transforms)):
print(f"\n变换 {i}:")
print(f"实际结果:\n{actual}")
print(f"预期结果:\n{expected}")
np.testing.assert_array_equal(actual, expected,
err_msg=f"变换 {i} 不匹配")
print("\n所有变换验证通过!")
if __name__ == "__main__":
# 运行手动验证
test_manual_verification()
# 运行pytest测试
pytest.main([__file__, "-v"])

311
tests/test_cache_system.py Normal file
View File

@@ -0,0 +1,311 @@
"""
内存缓存系统测试
验证TrainingDataCache的功能和性能
"""
import numpy as np
import pytest
import time
from training_data import TrainingDataCache, TrainingExample
class TestTrainingDataCache:
"""训练数据缓存测试类"""
def setup_method(self):
"""测试前的设置"""
self.cache = TrainingDataCache(max_size=5) # 小缓存便于测试
# 创建测试样本
self.sample_examples = []
for i in range(10):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i * 100),
canonical_hash=f"hash_{i}"
)
self.sample_examples.append(example)
def test_basic_operations(self):
"""测试基本的存取操作"""
# 测试空缓存
assert self.cache.size() == 0
assert self.cache.get("nonexistent") is None
# 添加一个项目
example = self.sample_examples[0]
self.cache.put("key1", example)
assert self.cache.size() == 1
retrieved = self.cache.get("key1")
assert retrieved is not None
assert retrieved.value == example.value
assert retrieved.action == example.action
def test_lru_eviction(self):
"""测试LRU淘汰机制"""
# 填满缓存
for i in range(5):
self.cache.put(f"key_{i}", self.sample_examples[i])
assert self.cache.size() == 5
# 访问key_1使其成为最近使用的
self.cache.get("key_1")
# 添加新项目应该淘汰key_0最久未使用
self.cache.put("key_5", self.sample_examples[5])
assert self.cache.size() == 5
assert self.cache.get("key_0") is None # 应该被淘汰
assert self.cache.get("key_1") is not None # 应该还在
assert self.cache.get("key_5") is not None # 新添加的
def test_update_existing(self):
"""测试更新现有项目"""
example1 = self.sample_examples[0]
example2 = self.sample_examples[1]
# 添加项目
self.cache.put("key1", example1)
assert self.cache.get("key1").value == example1.value
# 更新项目
self.cache.put("key1", example2)
assert self.cache.size() == 1 # 大小不变
assert self.cache.get("key1").value == example2.value # 值已更新
def test_update_if_better(self):
"""测试条件更新功能"""
low_value_example = TrainingExample(
board_state=np.zeros((4, 4)),
action=0,
value=100.0,
canonical_hash="test_hash"
)
high_value_example = TrainingExample(
board_state=np.zeros((4, 4)),
action=0,
value=200.0,
canonical_hash="test_hash"
)
# 首次添加
result = self.cache.update_if_better("key1", low_value_example)
assert result is True
assert self.cache.get("key1").value == 100.0
# 用更高价值更新
result = self.cache.update_if_better("key1", high_value_example)
assert result is True
assert self.cache.get("key1").value == 200.0
# 用更低价值尝试更新(应该失败)
result = self.cache.update_if_better("key1", low_value_example)
assert result is False
assert self.cache.get("key1").value == 200.0 # 值不变
def test_clear(self):
"""测试清空缓存"""
# 添加一些项目
for i in range(3):
self.cache.put(f"key_{i}", self.sample_examples[i])
assert self.cache.size() == 3
# 清空缓存
self.cache.clear()
assert self.cache.size() == 0
for i in range(3):
assert self.cache.get(f"key_{i}") is None
def test_get_all_examples(self):
"""测试获取所有样本"""
# 添加一些项目
added_examples = []
for i in range(3):
example = self.sample_examples[i]
self.cache.put(f"key_{i}", example)
added_examples.append(example)
# 获取所有样本
all_examples = self.cache.get_all_examples()
assert len(all_examples) == 3
# 验证所有样本都在其中(顺序可能不同)
all_values = {ex.value for ex in all_examples}
expected_values = {ex.value for ex in added_examples}
assert all_values == expected_values
def test_access_order_tracking(self):
"""测试访问顺序跟踪"""
# 先填满缓存
for i in range(5):
self.cache.put(f"key_{i}", self.sample_examples[i])
# 访问key_1使其成为最近使用的
self.cache.get("key_1")
# 访问key_3
self.cache.get("key_3")
# 现在访问顺序应该是key_0最久, key_2, key_4, key_1, key_3最新
# 添加两个新项目应该淘汰key_0和key_2
self.cache.put("key_5", self.sample_examples[5])
self.cache.put("key_6", self.sample_examples[6])
assert self.cache.get("key_0") is None # 最久的,应该被淘汰
assert self.cache.get("key_2") is None # 第二久的,应该被淘汰
assert self.cache.get("key_1") is not None # 应该还在
assert self.cache.get("key_3") is not None # 应该还在
assert self.cache.get("key_4") is not None # 应该还在
assert self.cache.get("key_5") is not None # 新添加的
assert self.cache.get("key_6") is not None # 新添加的
class TestCachePerformance:
"""缓存性能测试"""
def test_large_cache_performance(self):
"""测试大缓存的性能"""
large_cache = TrainingDataCache(max_size=10000)
# 创建大量测试数据
examples = []
for i in range(5000):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i),
canonical_hash=f"hash_{i}"
)
examples.append(example)
# 测试插入性能
start_time = time.time()
for i, example in enumerate(examples):
large_cache.put(f"key_{i}", example)
insert_time = time.time() - start_time
print(f"插入5000个项目耗时: {insert_time:.3f}")
assert insert_time < 1.0, "插入操作应该很快"
# 测试查询性能
start_time = time.time()
for i in range(1000): # 随机查询1000次
key = f"key_{np.random.randint(0, 5000)}"
large_cache.get(key)
query_time = time.time() - start_time
print(f"1000次随机查询耗时: {query_time:.3f}")
assert query_time < 0.1, "查询操作应该很快"
# 验证缓存大小
assert large_cache.size() == 5000
def test_memory_usage(self):
"""测试内存使用情况"""
import sys
cache = TrainingDataCache(max_size=1000)
# 测量空缓存的内存使用
initial_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)
# 添加数据
for i in range(500):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i),
canonical_hash=f"hash_{i}"
)
cache.put(f"key_{i}", example)
# 测量填充后的内存使用
filled_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)
print(f"空缓存内存使用: {initial_size} bytes")
print(f"500项目缓存内存使用: {filled_size} bytes")
print(f"平均每项目内存使用: {(filled_size - initial_size) / 500:.2f} bytes")
def test_cache_thread_safety():
"""测试缓存的线程安全性(基础测试)"""
import threading
import time
cache = TrainingDataCache(max_size=1000)
errors = []
def worker(worker_id):
"""工作线程函数"""
try:
for i in range(100):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(worker_id * 100 + i),
canonical_hash=f"hash_{worker_id}_{i}"
)
key = f"worker_{worker_id}_key_{i}"
cache.put(key, example)
# 随机读取
if i % 10 == 0:
cache.get(key)
# 短暂休眠
time.sleep(0.001)
except Exception as e:
errors.append(f"Worker {worker_id}: {e}")
# 创建多个线程
threads = []
for i in range(5):
thread = threading.Thread(target=worker, args=(i,))
threads.append(thread)
thread.start()
# 等待所有线程完成
for thread in threads:
thread.join()
# 检查是否有错误
if errors:
print("线程安全测试中的错误:")
for error in errors:
print(f" {error}")
print(f"最终缓存大小: {cache.size()}")
print(f"线程安全测试完成,错误数: {len(errors)}")
if __name__ == "__main__":
# 运行基本测试
print("运行缓存系统测试...")
# 运行性能测试
print("\n运行性能测试...")
perf_test = TestCachePerformance()
perf_test.test_large_cache_performance()
perf_test.test_memory_usage()
# 运行线程安全测试
print("\n运行线程安全测试...")
test_cache_thread_safety()
# 运行pytest测试
print("\n运行pytest测试...")
pytest.main([__file__, "-v"])

289
tests/test_game_engine.py Normal file
View File

@@ -0,0 +1,289 @@
"""
2048游戏引擎测试
验证新游戏引擎的功能和正确性
"""
import numpy as np
import pytest
from game import Game2048, GameState
class TestGame2048:
"""2048游戏引擎测试类"""
def setup_method(self):
"""测试前的设置"""
self.game = Game2048(height=4, width=4, seed=42)
def test_initialization(self):
"""测试游戏初始化"""
game = Game2048(height=3, width=4, seed=123)
assert game.height == 3
assert game.width == 4
assert game.score == 0
assert game.moves == 0
assert not game.is_over
# 应该有两个初始数字
non_zero_count = np.count_nonzero(game.board)
assert non_zero_count == 2
# 初始数字应该是1或2对数形式的2或4
non_zero_values = game.board[game.board != 0]
assert all(val in [1, 2] for val in non_zero_values)
def test_move_row_left(self):
"""测试行向左移动逻辑"""
# 测试简单移动
row = np.array([0, 1, 0, 2])
result, score = self.game._move_row_left(row)
expected = np.array([1, 2, 0, 0])
np.testing.assert_array_equal(result, expected)
assert score == 0
# 测试合并
row = np.array([1, 1, 2, 2])
result, score = self.game._move_row_left(row)
expected = np.array([2, 3, 0, 0])
np.testing.assert_array_equal(result, expected)
# 分数应该是 2^2 + 2^3 = 4 + 8 = 12
assert score == 12
# 测试复杂情况
row = np.array([1, 1, 1, 1])
result, score = self.game._move_row_left(row)
expected = np.array([2, 2, 0, 0])
np.testing.assert_array_equal(result, expected)
# 分数应该是 2^2 + 2^2 = 4 + 4 = 8
assert score == 8
def test_move_directions(self):
"""测试四个方向的移动"""
# 创建特定的棋盘状态
game = Game2048(height=3, width=3, seed=42)
game.board = np.array([
[1, 0, 1],
[0, 2, 0],
[1, 0, 1]
])
initial_score = game.score
# 测试向左移动
game_left = game.copy()
success = game_left.move(2) # 左
assert success
# 测试向右移动
game_right = game.copy()
success = game_right.move(3) # 右
assert success
# 测试向上移动
game_up = game.copy()
success = game_up.move(0) # 上
assert success
# 测试向下移动
game_down = game.copy()
success = game_down.move(1) # 下
assert success
# 所有移动都应该改变棋盘状态
assert not np.array_equal(game.board, game_left.board)
assert not np.array_equal(game.board, game_right.board)
assert not np.array_equal(game.board, game_up.board)
assert not np.array_equal(game.board, game_down.board)
def test_score_calculation(self):
"""测试分数计算"""
game = Game2048(height=2, width=2, seed=42)
# 设置特定棋盘状态
game.board = np.array([
[1, 2], # 2, 4
[3, 4] # 8, 16
])
# 计算累积分数
total_score = game.calculate_total_score()
# 根据论文公式V(N) = (log2(N) - 1) * N
# V(2) = 0, V(4) = 4, V(8) = 16, V(16) = 48
expected = 0 + 4 + 16 + 48
assert total_score == expected
def test_game_over_detection(self):
"""测试游戏结束检测"""
game = Game2048(height=2, width=2, seed=42)
# 设置无法移动的棋盘
game.board = np.array([
[1, 2], # 2, 4
[3, 4] # 8, 16
])
game._check_game_over()
assert game.is_over
# 测试可以移动的棋盘
game.board = np.array([
[1, 1], # 2, 2 (可以合并)
[3, 4] # 8, 16
])
game.is_over = False
game._check_game_over()
assert not game.is_over
def test_valid_moves(self):
"""测试有效移动检测"""
game = Game2048(height=2, width=2, seed=42)
# 设置可以向所有方向移动的棋盘
game.board = np.array([
[1, 0],
[0, 1]
])
valid_moves = game.get_valid_moves()
assert len(valid_moves) == 4 # 所有方向都可以移动
# 设置无法移动的棋盘
game.board = np.array([
[1, 2],
[3, 4]
])
valid_moves = game.get_valid_moves()
assert len(valid_moves) == 0 # 无法移动
def test_board_display(self):
"""测试棋盘显示"""
game = Game2048(height=2, width=2, seed=42)
# 设置对数形式的棋盘
game.board = np.array([
[0, 1], # 0, 2
[2, 3] # 4, 8
])
display_board = game.get_board_display()
expected = np.array([
[0, 2],
[4, 8]
])
np.testing.assert_array_equal(display_board, expected)
def test_max_tile(self):
"""测试最大数字获取"""
game = Game2048(height=2, width=2, seed=42)
game.board = np.array([
[1, 2], # 2, 4
[3, 4] # 8, 16
])
max_tile = game.get_max_tile()
assert max_tile == 16
def test_state_management(self):
"""测试游戏状态管理"""
game = Game2048(height=2, width=2, seed=42)
# 获取初始状态
initial_state = game.get_state()
assert isinstance(initial_state, GameState)
assert initial_state.score == game.score
assert initial_state.moves == game.moves
assert np.array_equal(initial_state.board, game.board)
# 执行移动
move_success = game.move(2) # 左移
# 获取新状态
new_state = game.get_state()
# 只有移动成功时才检查移动次数
if move_success:
assert new_state.moves == initial_state.moves + 1
assert not np.array_equal(new_state.board, initial_state.board)
else:
# 如果移动失败,尝试其他方向
for direction in range(4):
if game.move(direction):
new_state = game.get_state()
assert new_state.moves == initial_state.moves + 1
assert not np.array_equal(new_state.board, initial_state.board)
break
# 恢复状态
game.set_state(initial_state)
assert game.score == initial_state.score
assert game.moves == initial_state.moves
np.testing.assert_array_equal(game.board, initial_state.board)
def test_copy_functionality(self):
"""测试游戏复制功能"""
game = Game2048(height=3, width=3, seed=42)
# 执行一些操作
game.move(2)
game.move(0)
# 创建副本
game_copy = game.copy()
# 验证副本
assert game_copy.height == game.height
assert game_copy.width == game.width
assert game_copy.score == game.score
assert game_copy.moves == game.moves
assert game_copy.is_over == game.is_over
np.testing.assert_array_equal(game_copy.board, game.board)
# 修改副本不应影响原游戏
game_copy.move(1)
assert game_copy.moves != game.moves
def test_different_board_sizes(self):
"""测试不同大小的棋盘"""
# 测试3x3棋盘
game_3x3 = Game2048(height=3, width=3, seed=42)
assert game_3x3.board.shape == (3, 3)
# 测试2x4矩形棋盘
game_2x4 = Game2048(height=2, width=4, seed=42)
assert game_2x4.board.shape == (2, 4)
# 测试移动功能
success = game_3x3.move(2)
assert isinstance(success, bool)
success = game_2x4.move(0)
assert isinstance(success, bool)
def test_spawn_probability(self):
"""测试数字生成概率"""
# 测试只生成2的情况
game_only_2 = Game2048(height=4, width=4, spawn_prob_4=0.0, seed=42)
# 重置并检查生成的数字
game_only_2.reset()
non_zero_values = game_only_2.board[game_only_2.board != 0]
assert all(val == 1 for val in non_zero_values) # 只有1对数形式的2
# 测试只生成4的情况
game_only_4 = Game2048(height=4, width=4, spawn_prob_4=1.0, seed=42)
game_only_4.reset()
non_zero_values = game_only_4.board[game_only_4.board != 0]
assert all(val == 2 for val in non_zero_values) # 只有2对数形式的4
if __name__ == "__main__":
# 运行测试
print("运行2048游戏引擎测试...")
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,210 @@
"""
性能基准测试
测试不同MCTS实现的性能对比
"""
import time
import torch
import pytest
from game import Game2048
from torch_mcts import TorchMCTS
class TestPerformanceBenchmark:
"""性能基准测试类"""
@pytest.fixture
def game(self):
"""测试游戏状态"""
return Game2048(height=3, width=3, seed=42)
def test_cpu_mcts_performance(self, game):
"""测试CPU MCTS性能"""
mcts = TorchMCTS(
c_param=1.414,
max_simulation_depth=50,
device="cpu"
)
simulations = 2000
start_time = time.time()
action, stats = mcts.search(game, simulations)
elapsed_time = time.time() - start_time
speed = simulations / elapsed_time
# CPU MCTS应该达到基本性能要求
assert speed > 500, f"CPU MCTS性能过低: {speed:.1f} 模拟/秒"
assert action in game.get_valid_moves()
def test_auto_device_mcts_performance(self, game):
"""测试自动设备选择MCTS性能"""
mcts = TorchMCTS(
c_param=1.414,
max_simulation_depth=50,
device="auto"
)
simulations = 2000
start_time = time.time()
action, stats = mcts.search(game, simulations)
elapsed_time = time.time() - start_time
speed = simulations / elapsed_time
# 自动设备选择应该有合理性能
assert speed > 100, f"自动设备MCTS性能过低: {speed:.1f} 模拟/秒"
assert action in game.get_valid_moves()
if mcts.device.type == "cuda":
del mcts
torch.cuda.empty_cache()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_gpu_mcts_performance(self, game):
"""测试GPU MCTS性能"""
gpu_mcts = TorchMCTS(
max_simulation_depth=50,
batch_size=8192,
device="cuda"
)
simulations = 5000
torch.cuda.synchronize()
start_time = time.time()
action, stats = gpu_mcts.search(game, simulations)
torch.cuda.synchronize()
elapsed_time = time.time() - start_time
speed = simulations / elapsed_time
# GPU MCTS应该有显著性能提升
assert speed > 200, f"GPU MCTS性能过低: {speed:.1f} 模拟/秒"
assert action in game.get_valid_moves()
del gpu_mcts
torch.cuda.empty_cache()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_performance_comparison(self, game):
"""性能对比测试"""
simulations = 3000
results = {}
# CPU MCTS
cpu_mcts = TorchMCTS(c_param=1.414, max_simulation_depth=50, device="cpu")
start_time = time.time()
cpu_action, cpu_stats = cpu_mcts.search(game.copy(), simulations)
cpu_time = time.time() - start_time
results['CPU'] = simulations / cpu_time
# GPU MCTS
gpu_mcts = TorchMCTS(max_simulation_depth=50, batch_size=8192, device="cuda")
torch.cuda.synchronize()
start_time = time.time()
gpu_action, gpu_stats = gpu_mcts.search(game.copy(), simulations)
torch.cuda.synchronize()
gpu_time = time.time() - start_time
results['GPU'] = simulations / gpu_time
# 验证性能提升
speedup = results['GPU'] / results['CPU']
print(f"\n性能对比:")
print(f" CPU: {results['CPU']:.1f} 模拟/秒")
print(f" GPU: {results['GPU']:.1f} 模拟/秒")
print(f" 加速比: {speedup:.1f}x")
# GPU应该有一定的性能优势至少不能太慢
assert speedup > 0.1, f"GPU性能严重低于CPU: {speedup:.2f}x"
# 清理
del cpu_mcts, gpu_mcts
torch.cuda.empty_cache()
def test_batch_size_scaling(self):
"""测试批次大小对性能的影响"""
if not torch.cuda.is_available():
pytest.skip("CUDA不可用")
game = Game2048(height=3, width=3, seed=42)
batch_sizes = [1024, 4096, 16384]
simulations = 2000
results = {}
for batch_size in batch_sizes:
gpu_mcts = TorchMCTS(
max_simulation_depth=50,
batch_size=batch_size,
device="cuda"
)
torch.cuda.synchronize()
start_time = time.time()
action, stats = gpu_mcts.search(game.copy(), simulations)
torch.cuda.synchronize()
elapsed_time = time.time() - start_time
speed = simulations / elapsed_time
results[batch_size] = speed
del gpu_mcts
torch.cuda.empty_cache()
# 验证批次大小的影响
speeds = list(results.values())
max_speed = max(speeds)
min_speed = min(speeds)
# 不同批次大小的性能差异应该在合理范围内
speed_ratio = max_speed / min_speed
assert speed_ratio < 10, f"批次大小性能差异过大: {speed_ratio:.2f}"
print(f"\n批次大小性能测试:")
for batch_size, speed in results.items():
print(f" {batch_size:,}: {speed:.1f} 模拟/秒")
def test_memory_efficiency():
"""内存效率测试"""
if not torch.cuda.is_available():
pytest.skip("CUDA不可用")
game = Game2048(height=3, width=3, seed=42)
torch.cuda.empty_cache()
initial_memory = torch.cuda.memory_allocated()
gpu_mcts = TorchMCTS(
max_simulation_depth=50,
batch_size=32768,
device="cuda"
)
# 执行搜索
action, stats = gpu_mcts.search(game, 10000)
peak_memory = torch.cuda.max_memory_allocated()
memory_used = (peak_memory - initial_memory) / 1e6 # MB
# 内存使用应该合理
assert memory_used < 500, f"GPU内存使用过多: {memory_used:.1f} MB"
# 计算内存效率
speed = stats['sims_per_second']
memory_efficiency = speed / memory_used if memory_used > 0 else 0
print(f"\n内存效率测试:")
print(f" 内存使用: {memory_used:.1f} MB")
print(f" 模拟速度: {speed:.1f} 模拟/秒")
print(f" 内存效率: {memory_efficiency:.1f} 模拟/秒/MB")
# 清理
del gpu_mcts
torch.cuda.empty_cache()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

329
tests/test_persistence.py Normal file
View File

@@ -0,0 +1,329 @@
"""
硬盘持久化系统测试
验证TrainingDataPersistence的功能和可靠性
"""
import numpy as np
import torch
import pytest
import tempfile
import shutil
import os
from pathlib import Path
from training_data import (
TrainingDataPersistence,
TrainingDataCache,
TrainingExample,
TrainingDataManager
)
class TestTrainingDataPersistence:
"""训练数据持久化测试类"""
def setup_method(self):
"""测试前的设置"""
# 创建临时目录用于测试
self.temp_dir = tempfile.mkdtemp()
self.persistence = TrainingDataPersistence(self.temp_dir)
# 创建测试样本
self.test_examples = []
for i in range(10):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i * 100),
canonical_hash=f"hash_{i}"
)
self.test_examples.append(example)
def teardown_method(self):
"""测试后的清理"""
# 删除临时目录
shutil.rmtree(self.temp_dir)
def test_save_and_load_cache(self):
"""测试缓存的保存和加载"""
# 创建缓存并添加数据
cache = TrainingDataCache(max_size=100)
for i, example in enumerate(self.test_examples[:5]):
cache.put(f"key_{i}", example)
# 保存缓存
filename = "test_cache"
self.persistence.save_cache(cache, filename)
# 验证文件存在
expected_path = Path(self.temp_dir) / f"{filename}.pkl"
assert expected_path.exists()
# 加载缓存
loaded_examples = self.persistence.load_cache(filename)
# 验证加载的数据
assert len(loaded_examples) == 5
# 验证数据内容
loaded_values = {ex.value for ex in loaded_examples}
original_values = {ex.value for ex in self.test_examples[:5]}
assert loaded_values == original_values
def test_save_examples_batch(self):
"""测试批量保存样本"""
batch_name = "test_batch"
examples = self.test_examples[:7]
# 保存批次
self.persistence.save_examples_batch(examples, batch_name)
# 验证文件存在
expected_path = Path(self.temp_dir) / f"{batch_name}.pkl"
assert expected_path.exists()
# 加载并验证
loaded_examples = self.persistence.load_cache(batch_name)
assert len(loaded_examples) == 7
# 验证数据完整性
for original, loaded in zip(examples, loaded_examples):
assert original.action == loaded.action
assert original.value == loaded.value
assert original.canonical_hash == loaded.canonical_hash
np.testing.assert_array_equal(original.board_state, loaded.board_state)
def test_load_nonexistent_file(self):
"""测试加载不存在的文件"""
loaded_examples = self.persistence.load_cache("nonexistent_file")
assert loaded_examples == []
def test_list_saved_files(self):
"""测试列出保存的文件"""
# 初始应该没有文件
files = self.persistence.list_saved_files()
assert len(files) == 0
# 保存一些文件
cache = TrainingDataCache(max_size=100)
cache.put("key1", self.test_examples[0])
self.persistence.save_cache(cache, "file1")
self.persistence.save_cache(cache, "file2")
self.persistence.save_examples_batch(self.test_examples[:3], "batch1")
# 检查文件列表
files = self.persistence.list_saved_files()
assert len(files) == 3
assert "file1" in files
assert "file2" in files
assert "batch1" in files
def test_large_data_persistence(self):
"""测试大数据量的持久化"""
# 创建大量测试数据
large_examples = []
for i in range(1000):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i),
canonical_hash=f"hash_{i}"
)
large_examples.append(example)
# 保存大批次
batch_name = "large_batch"
self.persistence.save_examples_batch(large_examples, batch_name)
# 加载并验证
loaded_examples = self.persistence.load_cache(batch_name)
assert len(loaded_examples) == 1000
# 验证一些随机样本
for i in [0, 100, 500, 999]:
assert loaded_examples[i].value == float(i)
assert loaded_examples[i].action == i % 4
def test_data_integrity(self):
"""测试数据完整性"""
# 创建包含特殊值的测试数据
special_board = np.array([
[0, 1, 2, 17], # 包含边界值
[3, 4, 5, 6],
[7, 8, 9, 10],
[11, 12, 13, 14]
])
special_example = TrainingExample(
board_state=special_board,
action=3,
value=12345.67,
canonical_hash="special_hash_123"
)
# 保存
self.persistence.save_examples_batch([special_example], "special_test")
# 加载
loaded = self.persistence.load_cache("special_test")
assert len(loaded) == 1
loaded_example = loaded[0]
# 验证所有字段
np.testing.assert_array_equal(loaded_example.board_state, special_board)
assert loaded_example.action == 3
assert abs(loaded_example.value - 12345.67) < 1e-6
assert loaded_example.canonical_hash == "special_hash_123"
class TestTrainingDataManager:
"""训练数据管理器测试类"""
def setup_method(self):
"""测试前的设置"""
self.temp_dir = tempfile.mkdtemp()
self.manager = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
def teardown_method(self):
"""测试后的清理"""
shutil.rmtree(self.temp_dir)
def test_add_and_retrieve_examples(self):
"""测试添加和检索训练样本"""
# 创建测试棋盘
board = np.array([
[2, 4, 8, 16],
[0, 2, 4, 8],
[0, 0, 2, 4],
[0, 0, 0, 2]
])
# 添加训练样本
cache_key = self.manager.add_training_example(board, action=1, value=500.0)
assert cache_key is not None
# 验证缓存统计
stats = self.manager.get_cache_stats()
assert stats["cache_size"] == 1
# 获取PyTorch数据集
dataset = self.manager.get_pytorch_dataset()
assert len(dataset) == 1
# 验证数据集内容
board_tensor, action_tensor, value_tensor = dataset[0]
assert action_tensor.item() == 1
assert abs(value_tensor.item() - 500.0) < 1e-6
def test_save_and_load_workflow(self):
"""测试完整的保存和加载工作流"""
# 添加一些训练样本
boards = [
np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]]),
np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]]),
np.array([[8, 16, 32, 64], [4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8]])
]
for i, board in enumerate(boards):
for action in range(4):
value = (i + 1) * 100 + action * 10
self.manager.add_training_example(board, action, value)
# 保存当前缓存
self.manager.save_current_cache("workflow_test")
# 创建新的管理器
new_manager = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
# 加载数据
loaded_count = new_manager.load_from_file("workflow_test")
assert loaded_count == 12 # 3个棋盘 × 4个动作
# 验证数据
dataset = new_manager.get_pytorch_dataset()
assert len(dataset) == 12
def test_merge_caches(self):
"""测试缓存合并功能"""
# 在第一个管理器中添加数据
board1 = np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]])
self.manager.add_training_example(board1, 0, 100.0)
self.manager.add_training_example(board1, 1, 200.0)
# 创建第二个管理器
manager2 = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
# 在第二个管理器中添加不同的数据
board2 = np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]])
manager2.add_training_example(board2, 0, 300.0)
manager2.add_training_example(board2, 1, 400.0)
# 合并缓存
merged_count = self.manager.merge_caches(manager2)
assert merged_count == 2
# 验证合并后的数据
stats = self.manager.get_cache_stats()
assert stats["cache_size"] == 4
dataset = self.manager.get_pytorch_dataset()
assert len(dataset) == 4
def test_pytorch_integration(self):
"""测试PyTorch集成"""
# 添加测试数据
for i in range(10):
board = np.random.randint(0, 16, size=(4, 4))
# 确保至少有一些非零值
board[0, 0] = 2 ** (i % 4 + 1)
action = i % 4
value = float(i * 50)
self.manager.add_training_example(board, action, value)
# 获取DataLoader
dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
# 验证批次
batch_count = 0
total_samples = 0
for boards, actions, values in dataloader:
batch_count += 1
batch_size = boards.shape[0]
total_samples += batch_size
# 验证张量形状
assert boards.shape == (batch_size, 18, 4, 4) # max_tile_value + 1 = 18
assert actions.shape == (batch_size,)
assert values.shape == (batch_size,)
# 验证数据类型
assert boards.dtype == torch.float32
assert actions.dtype == torch.long
assert values.dtype == torch.float32
assert total_samples == 10
assert batch_count == 4 # ceil(10/3) = 4
if __name__ == "__main__":
# 运行测试
print("运行持久化系统测试...")
pytest.main([__file__, "-v"])

295
tests/test_torch_mcts.py Normal file
View File

@@ -0,0 +1,295 @@
"""
PyTorch MCTS测试
测试统一的PyTorch MCTS实现
"""
import pytest
import torch
import time
import numpy as np
from game import Game2048
from torch_mcts import TorchMCTS
from training_data import TrainingDataManager
class TestTorchMCTS:
"""PyTorch MCTS测试类"""
@pytest.fixture
def game(self):
"""测试游戏状态"""
return Game2048(height=3, width=3, seed=42)
@pytest.fixture
def cpu_mcts(self):
"""CPU MCTS实例"""
return TorchMCTS(
c_param=1.414,
max_simulation_depth=30,
batch_size=1024,
device="cpu"
)
@pytest.fixture
def gpu_mcts(self):
"""GPU MCTS实例"""
if not torch.cuda.is_available():
pytest.skip("CUDA不可用")
return TorchMCTS(
c_param=1.414,
max_simulation_depth=30,
batch_size=4096,
device="cuda"
)
def test_cpu_mcts_basic_functionality(self, game, cpu_mcts):
"""测试CPU MCTS基本功能"""
# 执行搜索
action, stats = cpu_mcts.search(game, 1000)
# 验证结果
assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
assert 'action_visits' in stats
assert 'action_avg_values' in stats
assert 'sims_per_second' in stats
assert stats['device'] == 'cpu'
# 验证访问次数
total_visits = sum(stats['action_visits'].values())
assert total_visits == 1000, f"访问次数不匹配: {total_visits}"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_gpu_mcts_basic_functionality(self, game, gpu_mcts):
"""测试GPU MCTS基本功能"""
# 执行搜索
action, stats = gpu_mcts.search(game, 2000)
# 验证结果
assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
assert 'action_visits' in stats
assert 'action_avg_values' in stats
assert 'sims_per_second' in stats
assert stats['device'] == 'cuda'
# 验证访问次数
total_visits = sum(stats['action_visits'].values())
assert total_visits == 2000, f"访问次数不匹配: {total_visits}"
def test_action_distribution_quality(self, game, cpu_mcts):
"""测试动作分布质量"""
action, stats = cpu_mcts.search(game, 5000)
action_visits = stats['action_visits']
visit_values = list(action_visits.values())
# 检查分布不应该完全均匀MCTS应该有偏向性
assert len(set(visit_values)) > 1, "动作分布完全均匀不符合MCTS预期"
# 检查最佳动作应该有最多访问次数
best_action_visits = action_visits[action]
assert best_action_visits == max(visit_values), "最佳动作访问次数不是最多"
# 检查价值的合理性
action_values = stats['action_avg_values']
for act, value in action_values.items():
assert value > 0, f"动作{act}的价值应该为正: {value}"
assert value < 100000, f"动作{act}的价值过大: {value}"
def test_device_auto_selection(self, game):
"""测试设备自动选择"""
mcts = TorchMCTS(device="auto", batch_size=1024)
# 验证设备选择
if torch.cuda.is_available():
assert mcts.device.type == "cuda"
else:
assert mcts.device.type == "cpu"
# 执行搜索验证功能
action, stats = mcts.search(game, 1000)
assert action in game.get_valid_moves()
if mcts.device.type == "cuda":
del mcts
torch.cuda.empty_cache()
def test_batch_size_auto_selection(self, game):
"""测试批次大小自动选择"""
# CPU自动选择
cpu_mcts = TorchMCTS(device="cpu", batch_size=None)
assert cpu_mcts.batch_size == 4096 # CPU默认批次大小
# GPU自动选择如果可用
if torch.cuda.is_available():
gpu_mcts = TorchMCTS(device="cuda", batch_size=None)
assert gpu_mcts.batch_size == 32768 # GPU默认批次大小
del gpu_mcts
torch.cuda.empty_cache()
def test_performance_cpu(self, game, cpu_mcts):
"""测试CPU性能"""
simulations = 2000
start_time = time.time()
action, stats = cpu_mcts.search(game, simulations)
elapsed_time = time.time() - start_time
speed = simulations / elapsed_time
# CPU应该达到基本性能要求
assert speed > 100, f"CPU性能过低: {speed:.1f} 模拟/秒"
# 验证统计信息准确性
assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_performance_gpu(self, game, gpu_mcts):
"""测试GPU性能"""
simulations = 5000
torch.cuda.synchronize()
start_time = time.time()
action, stats = gpu_mcts.search(game, simulations)
torch.cuda.synchronize()
elapsed_time = time.time() - start_time
speed = simulations / elapsed_time
# GPU应该有合理的性能
assert speed > 50, f"GPU性能过低: {speed:.1f} 模拟/秒"
# 验证统计信息准确性
assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"
def test_training_data_collection(self, game):
"""测试训练数据收集"""
# 创建训练数据管理器
training_manager = TrainingDataManager(
data_dir="data/test_torch_training",
cache_size=5000,
board_size=(3, 3)
)
mcts = TorchMCTS(
max_simulation_depth=30,
batch_size=1024,
device="cpu",
training_manager=training_manager
)
# 执行搜索
action, stats = mcts.search(game, 2000)
# 验证训练数据收集
cache_stats = training_manager.get_cache_stats()
assert cache_stats['cache_size'] > 0, "未收集到训练数据"
# 验证数据质量
assert cache_stats['cache_size'] <= 2000, "收集的样本数超出预期"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_memory_management(self, game):
"""测试GPU内存管理"""
torch.cuda.empty_cache()
initial_memory = torch.cuda.memory_allocated()
gpu_mcts = TorchMCTS(
max_simulation_depth=30,
batch_size=8192,
device="cuda"
)
# 执行搜索
action, stats = gpu_mcts.search(game, 3000)
# 检查内存使用
peak_memory = torch.cuda.max_memory_allocated()
memory_used = (peak_memory - initial_memory) / 1e6 # MB
assert memory_used < 200, f"GPU内存使用过多: {memory_used:.1f} MB"
# 清理并验证内存释放
del gpu_mcts
torch.cuda.empty_cache()
final_memory = torch.cuda.memory_allocated()
assert final_memory <= initial_memory * 1.1, "GPU内存未正确释放"
def test_device_switching(self, game):
"""测试设备动态切换"""
mcts = TorchMCTS(device="cpu", batch_size=1024)
# 初始为CPU
assert mcts.device.type == "cpu"
action1, stats1 = mcts.search(game.copy(), 1000)
assert stats1['device'] == 'cpu'
# 切换到GPU如果可用
if torch.cuda.is_available():
mcts.set_device("cuda")
assert mcts.device.type == "cuda"
action2, stats2 = mcts.search(game.copy(), 1000)
assert stats2['device'] == 'cuda'
# 切换回CPU
mcts.set_device("cpu")
assert mcts.device.type == "cpu"
torch.cuda.empty_cache()
def test_consistency_across_devices(self, game):
"""测试不同设备间的一致性"""
if not torch.cuda.is_available():
pytest.skip("CUDA不可用")
# 使用相同的随机种子
np.random.seed(42)
cpu_mcts = TorchMCTS(device="cpu", batch_size=2048)
cpu_action, cpu_stats = cpu_mcts.search(game.copy(), 3000)
np.random.seed(42)
gpu_mcts = TorchMCTS(device="cuda", batch_size=2048)
gpu_action, gpu_stats = gpu_mcts.search(game.copy(), 3000)
# 由于随机性,动作可能不完全一致,但应该在合理范围内
# 这里主要验证两个设备都能正常工作
assert cpu_action in game.get_valid_moves()
assert gpu_action in game.get_valid_moves()
# 验证访问次数总和
cpu_total = sum(cpu_stats['action_visits'].values())
gpu_total = sum(gpu_stats['action_visits'].values())
assert cpu_total == gpu_total == 3000
del gpu_mcts
torch.cuda.empty_cache()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
def test_batch_size_optimization():
"""测试批次大小优化"""
game = Game2048(height=3, width=3, seed=42)
mcts = TorchMCTS(device="cuda", batch_size=4096)
# 执行批次大小优化
optimal_size = mcts.optimize_batch_size(game, test_simulations=1000)
# 验证优化结果
assert optimal_size > 0
assert mcts.batch_size == optimal_size
# 验证优化后的性能
action, stats = mcts.search(game, 2000)
assert action in game.get_valid_moves()
assert stats['sims_per_second'] > 0
del mcts
torch.cuda.empty_cache()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

209
tests/test_training_data.py Normal file
View File

@@ -0,0 +1,209 @@
"""
训练数据模块测试
测试棋盘变换、缓存系统、持久化等核心功能
"""
import numpy as np
import pytest
import tempfile
import shutil
from pathlib import Path
import torch
from training_data import (
BoardTransform,
ScoreCalculator,
TrainingDataCache,
TrainingExample,
TrainingDataManager
)
class TestBoardTransform:
"""棋盘变换测试"""
def test_log_transform(self):
"""测试对数变换"""
board = np.array([
[2, 4, 8, 16],
[0, 2, 4, 8],
[0, 0, 2, 4],
[0, 0, 0, 2]
])
expected = np.array([
[1, 2, 3, 4],
[0, 1, 2, 3],
[0, 0, 1, 2],
[0, 0, 0, 1]
])
result = BoardTransform.log_transform(board)
np.testing.assert_array_equal(result, expected)
# 测试逆变换
restored = BoardTransform.inverse_log_transform(result)
np.testing.assert_array_equal(restored, board)
def test_canonical_form(self):
"""测试规范形式"""
board = np.array([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
])
transforms = BoardTransform.get_all_transforms(board)
assert len(transforms) == 8
# 所有变换的规范形式应该相同
canonical_forms = []
for transform in transforms:
canonical, _ = BoardTransform.get_canonical_form(transform)
canonical_forms.append(canonical)
first_canonical = canonical_forms[0]
for canonical in canonical_forms[1:]:
np.testing.assert_array_equal(canonical, first_canonical)
def test_hash_consistency(self):
"""测试哈希一致性"""
board = np.array([[1, 2], [3, 4]])
transforms = BoardTransform.get_all_transforms(board)
hashes = [BoardTransform.compute_hash(t) for t in transforms]
first_hash = hashes[0]
for hash_val in hashes[1:]:
assert hash_val == first_hash
class TestScoreCalculator:
"""分数计算测试"""
def test_tile_value_calculation(self):
"""测试瓦片价值计算"""
# V(2) = 0, V(4) = 4, V(8) = 16, V(16) = 48
assert ScoreCalculator.calculate_tile_value(1) == 0 # 2^1 = 2
assert ScoreCalculator.calculate_tile_value(2) == 4 # 2^2 = 4
assert ScoreCalculator.calculate_tile_value(3) == 16 # 2^3 = 8
assert ScoreCalculator.calculate_tile_value(4) == 48 # 2^4 = 16
def test_board_score_calculation(self):
"""测试棋盘分数计算"""
log_board = np.array([
[1, 2], # 2, 4
[3, 4] # 8, 16
])
total_score = ScoreCalculator.calculate_board_score(log_board)
expected = 0 + 4 + 16 + 48 # V(2) + V(4) + V(8) + V(16)
assert total_score == expected
class TestTrainingDataCache:
"""训练数据缓存测试"""
def setup_method(self):
"""测试前的设置"""
self.cache = TrainingDataCache(max_size=5)
# 创建测试样本
self.sample_examples = []
for i in range(10):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i * 100),
canonical_hash=f"hash_{i}"
)
self.sample_examples.append(example)
def test_basic_operations(self):
"""测试基本操作"""
assert self.cache.size() == 0
assert self.cache.get("nonexistent") is None
example = self.sample_examples[0]
self.cache.put("key1", example)
assert self.cache.size() == 1
retrieved = self.cache.get("key1")
assert retrieved is not None
assert retrieved.value == example.value
def test_lru_eviction(self):
"""测试LRU淘汰"""
# 填满缓存
for i in range(5):
self.cache.put(f"key_{i}", self.sample_examples[i])
assert self.cache.size() == 5
# 访问key_1
self.cache.get("key_1")
# 添加新项目应该淘汰key_0
self.cache.put("key_5", self.sample_examples[5])
assert self.cache.size() == 5
assert self.cache.get("key_0") is None
assert self.cache.get("key_1") is not None
assert self.cache.get("key_5") is not None
class TestTrainingDataManager:
"""训练数据管理器测试"""
def setup_method(self):
"""测试前的设置"""
self.temp_dir = tempfile.mkdtemp()
self.manager = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
def teardown_method(self):
"""测试后的清理"""
shutil.rmtree(self.temp_dir)
def test_add_and_retrieve_examples(self):
"""测试添加和检索样本"""
board = np.array([
[2, 4, 8, 16],
[0, 2, 4, 8],
[0, 0, 2, 4],
[0, 0, 0, 2]
])
cache_key = self.manager.add_training_example(board, action=1, value=500.0)
assert cache_key is not None
stats = self.manager.get_cache_stats()
assert stats["cache_size"] == 1
dataset = self.manager.get_pytorch_dataset()
assert len(dataset) == 1
def test_pytorch_integration(self):
"""测试PyTorch集成"""
for i in range(5):
board = np.random.randint(0, 16, size=(4, 4))
board[0, 0] = 2 ** (i % 4 + 1)
self.manager.add_training_example(board, i % 4, float(i * 50))
dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
for boards, actions, values in dataloader:
assert boards.shape[0] <= 3 # 批次大小
assert boards.shape[1] == 18 # 通道数 (max_tile_value + 1)
assert boards.shape[2:] == (4, 4) # 棋盘大小
break # 只测试第一个批次
if __name__ == "__main__":
pytest.main([__file__, "-v"])