增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

View File

@@ -0,0 +1,251 @@
"""
棋盘压缩算法测试
验证二面体群D4变换和规范化的正确性
"""
import numpy as np
import pytest
from training_data import BoardTransform
class TestBoardTransform:
    """Tests for BoardTransform: log transform, the eight board transforms,
    canonical form and hashing."""

    def setup_method(self):
        """Create the fixture boards shared by the tests below."""
        # Asymmetric 4x4 board: every cell holds a distinct value, which makes
        # it easy to tell the individual transforms apart.
        self.test_board = np.array([
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ])
        # Small 2x2 board whose transforms can be checked by hand.
        self.simple_board = np.array([
            [1, 2],
            [3, 4],
        ])

    def test_log_transform(self):
        """log_transform must yield the expected exponents and be invertible."""
        # Regular case: the expected result maps 2 -> 1, 4 -> 2, 8 -> 3,
        # 16 -> 4 and leaves 0 as 0.
        board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2],
        ])
        expected = np.array([
            [1, 2, 3, 4],
            [0, 1, 2, 3],
            [0, 0, 1, 2],
            [0, 0, 0, 1],
        ])
        result = BoardTransform.log_transform(board)
        np.testing.assert_array_equal(result, expected)

        # The inverse transform must restore the original board.
        restored = BoardTransform.inverse_log_transform(result)
        np.testing.assert_array_equal(restored, board)

    def test_rotate_90(self):
        """rotate_90 must rotate the 2x2 board by 90 degrees clockwise."""
        # Hand-computed clockwise rotation:
        # [1, 2] -> [3, 1]
        # [3, 4]    [4, 2]
        expected = np.array([
            [3, 1],
            [4, 2],
        ])
        result = BoardTransform.rotate_90(self.simple_board)
        np.testing.assert_array_equal(result, expected)

    def test_flip_horizontal(self):
        """flip_horizontal must mirror the 2x2 board left-to-right."""
        # Hand-computed horizontal flip:
        # [1, 2] -> [2, 1]
        # [3, 4]    [4, 3]
        expected = np.array([
            [2, 1],
            [4, 3],
        ])
        result = BoardTransform.flip_horizontal(self.simple_board)
        np.testing.assert_array_equal(result, expected)

    def test_all_transforms_count(self):
        """get_all_transforms must return exactly eight transforms."""
        transforms = BoardTransform.get_all_transforms(self.test_board)
        assert len(transforms) == 8, "应该生成8种变换"

    def test_all_transforms_uniqueness(self):
        """All eight transforms of an asymmetric board must be distinct."""
        transforms = BoardTransform.get_all_transforms(self.test_board)
        # Key each transform by its shape and exact values rather than by its
        # string rendering: str() of an array depends on the global NumPy
        # print options and a flattened rendering ignores the shape.
        unique_transforms = {
            (t.shape, tuple(t.flatten().tolist())) for t in transforms
        }
        assert len(unique_transforms) == 8, "对于非对称矩阵8种变换应该都不相同"

    def test_transform_properties(self):
        """Check basic algebraic properties of rotation and flip."""
        board = self.test_board

        # Four successive 90-degree rotations must give back the original.
        result = board.copy()
        for _ in range(4):
            result = BoardTransform.rotate_90(result)
        np.testing.assert_array_equal(result, board)

        # Flipping horizontally twice must give back the original.
        flipped = BoardTransform.flip_horizontal(board)
        double_flipped = BoardTransform.flip_horizontal(flipped)
        np.testing.assert_array_equal(double_flipped, board)

    def test_canonical_form_consistency(self):
        """Every transform of a board must share the same canonical form."""
        board = self.test_board
        transforms = BoardTransform.get_all_transforms(board)

        # Only the canonical board is compared; the second element returned
        # by get_canonical_form is not checked by this test.
        canonical_forms = []
        for transform in transforms:
            canonical, _ = BoardTransform.get_canonical_form(transform)
            canonical_forms.append(canonical)

        # Each canonical form must equal the first one.
        first_canonical = canonical_forms[0]
        for canonical in canonical_forms[1:]:
            np.testing.assert_array_equal(canonical, first_canonical)

    def test_hash_consistency(self):
        """Every transform of a board must produce the same hash."""
        board = self.test_board
        transforms = BoardTransform.get_all_transforms(board)

        hashes = [BoardTransform.compute_hash(transform) for transform in transforms]
        first_hash = hashes[0]
        for hash_val in hashes[1:]:
            assert hash_val == first_hash, "所有等价变换的哈希值应该相同"

    def test_symmetric_board(self):
        """Canonical form must stay consistent for a fully symmetric board."""
        symmetric_board = np.array([
            [1, 2, 2, 1],
            [2, 3, 3, 2],
            [2, 3, 3, 2],
            [1, 2, 2, 1],
        ])
        transforms = BoardTransform.get_all_transforms(symmetric_board)

        # For such a board some transforms may coincide, but each of them
        # must still reduce to the same canonical form.
        canonical, _ = BoardTransform.get_canonical_form(symmetric_board)
        for transform in transforms:
            transform_canonical, _ = BoardTransform.get_canonical_form(transform)
            np.testing.assert_array_equal(transform_canonical, canonical)

    def test_edge_cases(self):
        """Exercise degenerate inputs: all zeros, a single cell, a single row."""
        # All-zero board: its canonical form must be the board itself.
        zero_board = np.zeros((4, 4), dtype=int)
        canonical_zero, _ = BoardTransform.get_canonical_form(zero_board)
        np.testing.assert_array_equal(canonical_zero, zero_board)

        # 1x1 board: its canonical form must be the board itself.
        single_element = np.array([[1]])
        canonical_single, _ = BoardTransform.get_canonical_form(single_element)
        np.testing.assert_array_equal(canonical_single, single_element)

        # 1x4 board: eight transforms must still be returned.
        row_matrix = np.array([[1, 2, 3, 4]])
        transforms_row = BoardTransform.get_all_transforms(row_matrix)
        assert len(transforms_row) == 8

    def test_different_board_sizes(self):
        """Transform count and hash consistency must hold for other sizes."""
        # 3x3 square board.
        board_3x3 = np.array([
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9],
        ])
        transforms_3x3 = BoardTransform.get_all_transforms(board_3x3)
        assert len(transforms_3x3) == 8

        # All transforms of the 3x3 board must hash identically.
        hashes_3x3 = [BoardTransform.compute_hash(t) for t in transforms_3x3]
        assert all(h == hashes_3x3[0] for h in hashes_3x3)

        # 2x3 rectangular board.
        board_2x3 = np.array([
            [1, 2, 3],
            [4, 5, 6],
        ])
        transforms_2x3 = BoardTransform.get_all_transforms(board_2x3)
        assert len(transforms_2x3) == 8

        # All transforms of the 2x3 board must hash identically.
        hashes_2x3 = [BoardTransform.compute_hash(t) for t in transforms_2x3]
        assert all(h == hashes_2x3[0] for h in hashes_2x3)
def test_manual_verification():
    """Compare every transform of a 2x2 board with hand-computed results.

    The actual and expected matrices are printed for each transform so the
    comparison can also be inspected by eye when the module is run as a
    script.

    Raises:
        AssertionError: if the number of transforms differs from the number
            of expected results, or if any transform does not match.
    """
    board = np.array([
        [1, 2],
        [3, 4],
    ])
    transforms = BoardTransform.get_all_transforms(board)

    # Expected results, in the order the transforms are compared below.
    expected_transforms = [
        np.array([[1, 2], [3, 4]]),  # R0: original
        np.array([[3, 1], [4, 2]]),  # R90: rotated 90 degrees
        np.array([[4, 3], [2, 1]]),  # R180: rotated 180 degrees
        np.array([[2, 4], [1, 3]]),  # R270: rotated 270 degrees
        np.array([[2, 1], [4, 3]]),  # F: horizontal flip
        np.array([[4, 2], [3, 1]]),  # F+R90: flip, then rotate 90 degrees
        np.array([[3, 4], [1, 2]]),  # F+R180: flip, then rotate 180 degrees
        np.array([[1, 3], [2, 4]]),  # F+R270: flip, then rotate 270 degrees
    ]

    # zip() stops at the shorter sequence, so without this check a result
    # list with missing transforms would pass without being fully compared.
    assert len(transforms) == len(expected_transforms), (
        f"应该生成{len(expected_transforms)}种变换,实际为{len(transforms)}"
    )

    print("手动验证2x2矩阵的8种变换:")
    print(f"原始矩阵:\n{board}")
    for i, (actual, expected) in enumerate(zip(transforms, expected_transforms)):
        print(f"\n变换 {i}:")
        print(f"实际结果:\n{actual}")
        print(f"预期结果:\n{expected}")
        np.testing.assert_array_equal(actual, expected,
                                      err_msg=f"变换 {i} 不匹配")
    print("\n所有变换验证通过!")
if __name__ == "__main__":
    # Run the hand-checked comparison first so its printout is shown directly.
    test_manual_verification()
    # Then run this module under pytest and propagate its exit code, so a
    # failing run does not leave the script exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))