Files
deep2048/tests/test_board_compression.py
2025-07-23 07:04:10 +08:00

252 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
棋盘压缩算法测试
验证二面体群D4变换和规范化的正确性
"""
import numpy as np
import pytest
from training_data import BoardTransform
class TestBoardTransform:
"""棋盘变换测试类"""
def setup_method(self):
"""测试前的设置"""
# 创建一个非对称的测试棋盘,便于验证变换
self.test_board = np.array([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
])
# 创建一个简单的2x2棋盘用于手动验证
self.simple_board = np.array([
[1, 2],
[3, 4]
])
def test_log_transform(self):
"""测试对数变换"""
# 测试正常情况
board = np.array([
[2, 4, 8, 16],
[0, 2, 4, 8],
[0, 0, 2, 4],
[0, 0, 0, 2]
])
expected = np.array([
[1, 2, 3, 4],
[0, 1, 2, 3],
[0, 0, 1, 2],
[0, 0, 0, 1]
])
result = BoardTransform.log_transform(board)
np.testing.assert_array_equal(result, expected)
# 测试逆变换
restored = BoardTransform.inverse_log_transform(result)
np.testing.assert_array_equal(restored, board)
def test_rotate_90(self):
"""测试90度旋转"""
# 手动验证2x2矩阵的90度顺时针旋转
# [1, 2] -> [3, 1]
# [3, 4] [4, 2]
expected = np.array([
[3, 1],
[4, 2]
])
result = BoardTransform.rotate_90(self.simple_board)
np.testing.assert_array_equal(result, expected)
def test_flip_horizontal(self):
"""测试水平翻转"""
# 手动验证2x2矩阵的水平翻转
# [1, 2] -> [2, 1]
# [3, 4] [4, 3]
expected = np.array([
[2, 1],
[4, 3]
])
result = BoardTransform.flip_horizontal(self.simple_board)
np.testing.assert_array_equal(result, expected)
def test_all_transforms_count(self):
"""测试是否生成了正确数量的变换"""
transforms = BoardTransform.get_all_transforms(self.test_board)
assert len(transforms) == 8, "应该生成8种变换"
def test_all_transforms_uniqueness(self):
"""测试所有变换是否唯一(对于非对称矩阵)"""
transforms = BoardTransform.get_all_transforms(self.test_board)
# 将每个变换转换为字符串进行比较
transform_strings = [str(t.flatten()) for t in transforms]
unique_transforms = set(transform_strings)
assert len(unique_transforms) == 8, "对于非对称矩阵8种变换应该都不相同"
def test_transform_properties(self):
"""测试变换的数学性质"""
board = self.test_board
# 测试4次90度旋转应该回到原始状态
result = board.copy()
for _ in range(4):
result = BoardTransform.rotate_90(result)
np.testing.assert_array_equal(result, board)
# 测试两次水平翻转应该回到原始状态
flipped = BoardTransform.flip_horizontal(board)
double_flipped = BoardTransform.flip_horizontal(flipped)
np.testing.assert_array_equal(double_flipped, board)
def test_canonical_form_consistency(self):
"""测试规范形式的一致性"""
board = self.test_board
transforms = BoardTransform.get_all_transforms(board)
# 所有变换的规范形式应该相同
canonical_forms = []
transform_indices = []
for transform in transforms:
canonical, idx = BoardTransform.get_canonical_form(transform)
canonical_forms.append(canonical)
transform_indices.append(idx)
# 所有规范形式应该相同
first_canonical = canonical_forms[0]
for canonical in canonical_forms[1:]:
np.testing.assert_array_equal(canonical, first_canonical)
def test_hash_consistency(self):
"""测试哈希值的一致性"""
board = self.test_board
transforms = BoardTransform.get_all_transforms(board)
# 所有变换的哈希值应该相同
hashes = [BoardTransform.compute_hash(transform) for transform in transforms]
first_hash = hashes[0]
for hash_val in hashes[1:]:
assert hash_val == first_hash, "所有等价变换的哈希值应该相同"
def test_symmetric_board(self):
"""测试对称棋盘的情况"""
# 创建一个完全对称的棋盘
symmetric_board = np.array([
[1, 2, 2, 1],
[2, 3, 3, 2],
[2, 3, 3, 2],
[1, 2, 2, 1]
])
transforms = BoardTransform.get_all_transforms(symmetric_board)
# 对于这个特殊的对称棋盘,某些变换可能相同
# 但规范形式应该仍然一致
canonical, _ = BoardTransform.get_canonical_form(symmetric_board)
for transform in transforms:
transform_canonical, _ = BoardTransform.get_canonical_form(transform)
np.testing.assert_array_equal(transform_canonical, canonical)
def test_edge_cases(self):
"""测试边界情况"""
# 测试全零矩阵
zero_board = np.zeros((4, 4), dtype=int)
canonical_zero, _ = BoardTransform.get_canonical_form(zero_board)
np.testing.assert_array_equal(canonical_zero, zero_board)
# 测试单元素矩阵
single_element = np.array([[1]])
canonical_single, _ = BoardTransform.get_canonical_form(single_element)
np.testing.assert_array_equal(canonical_single, single_element)
# 测试1x4矩阵
row_matrix = np.array([[1, 2, 3, 4]])
transforms_row = BoardTransform.get_all_transforms(row_matrix)
assert len(transforms_row) == 8
def test_different_board_sizes(self):
"""测试不同大小的棋盘"""
# 测试3x3棋盘
board_3x3 = np.array([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
transforms_3x3 = BoardTransform.get_all_transforms(board_3x3)
assert len(transforms_3x3) == 8
# 验证规范形式一致性
hashes_3x3 = [BoardTransform.compute_hash(t) for t in transforms_3x3]
assert all(h == hashes_3x3[0] for h in hashes_3x3)
# 测试2x3矩形棋盘
board_2x3 = np.array([
[1, 2, 3],
[4, 5, 6]
])
transforms_2x3 = BoardTransform.get_all_transforms(board_2x3)
assert len(transforms_2x3) == 8
# 验证规范形式一致性
hashes_2x3 = [BoardTransform.compute_hash(t) for t in transforms_2x3]
assert all(h == hashes_2x3[0] for h in hashes_2x3)
def test_manual_verification():
"""手动验证一些关键变换"""
# 创建一个简单的测试用例进行手动验证
board = np.array([
[1, 2],
[3, 4]
])
transforms = BoardTransform.get_all_transforms(board)
# 预期的8种变换结果
expected_transforms = [
np.array([[1, 2], [3, 4]]), # R0: 原始
np.array([[3, 1], [4, 2]]), # R90: 旋转90°
np.array([[4, 3], [2, 1]]), # R180: 旋转180°
np.array([[2, 4], [1, 3]]), # R270: 旋转270°
np.array([[2, 1], [4, 3]]), # F: 水平翻转
np.array([[4, 2], [3, 1]]), # F+R90: 翻转后旋转90°
np.array([[3, 4], [1, 2]]), # F+R180: 翻转后旋转180°
np.array([[1, 3], [2, 4]]) # F+R270: 翻转后旋转270°
]
print("手动验证2x2矩阵的8种变换:")
print(f"原始矩阵:\n{board}")
for i, (actual, expected) in enumerate(zip(transforms, expected_transforms)):
print(f"\n变换 {i}:")
print(f"实际结果:\n{actual}")
print(f"预期结果:\n{expected}")
np.testing.assert_array_equal(actual, expected,
err_msg=f"变换 {i} 不匹配")
print("\n所有变换验证通过!")
if __name__ == "__main__":
# 运行手动验证
test_manual_verification()
# 运行pytest测试
pytest.main([__file__, "-v"])