252 lines
8.3 KiB
Python
252 lines
8.3 KiB
Python
"""
|
||
棋盘压缩算法测试
|
||
|
||
验证二面体群D4变换和规范化的正确性
|
||
"""
|
||
|
||
import numpy as np
|
||
import pytest
|
||
from training_data import BoardTransform
|
||
|
||
|
||
class TestBoardTransform:
|
||
"""棋盘变换测试类"""
|
||
|
||
def setup_method(self):
|
||
"""测试前的设置"""
|
||
# 创建一个非对称的测试棋盘,便于验证变换
|
||
self.test_board = np.array([
|
||
[1, 2, 3, 4],
|
||
[5, 6, 7, 8],
|
||
[9, 10, 11, 12],
|
||
[13, 14, 15, 16]
|
||
])
|
||
|
||
# 创建一个简单的2x2棋盘用于手动验证
|
||
self.simple_board = np.array([
|
||
[1, 2],
|
||
[3, 4]
|
||
])
|
||
|
||
def test_log_transform(self):
|
||
"""测试对数变换"""
|
||
# 测试正常情况
|
||
board = np.array([
|
||
[2, 4, 8, 16],
|
||
[0, 2, 4, 8],
|
||
[0, 0, 2, 4],
|
||
[0, 0, 0, 2]
|
||
])
|
||
|
||
expected = np.array([
|
||
[1, 2, 3, 4],
|
||
[0, 1, 2, 3],
|
||
[0, 0, 1, 2],
|
||
[0, 0, 0, 1]
|
||
])
|
||
|
||
result = BoardTransform.log_transform(board)
|
||
np.testing.assert_array_equal(result, expected)
|
||
|
||
# 测试逆变换
|
||
restored = BoardTransform.inverse_log_transform(result)
|
||
np.testing.assert_array_equal(restored, board)
|
||
|
||
def test_rotate_90(self):
|
||
"""测试90度旋转"""
|
||
# 手动验证2x2矩阵的90度顺时针旋转
|
||
# [1, 2] -> [3, 1]
|
||
# [3, 4] [4, 2]
|
||
|
||
expected = np.array([
|
||
[3, 1],
|
||
[4, 2]
|
||
])
|
||
|
||
result = BoardTransform.rotate_90(self.simple_board)
|
||
np.testing.assert_array_equal(result, expected)
|
||
|
||
def test_flip_horizontal(self):
|
||
"""测试水平翻转"""
|
||
# 手动验证2x2矩阵的水平翻转
|
||
# [1, 2] -> [2, 1]
|
||
# [3, 4] [4, 3]
|
||
|
||
expected = np.array([
|
||
[2, 1],
|
||
[4, 3]
|
||
])
|
||
|
||
result = BoardTransform.flip_horizontal(self.simple_board)
|
||
np.testing.assert_array_equal(result, expected)
|
||
|
||
def test_all_transforms_count(self):
|
||
"""测试是否生成了正确数量的变换"""
|
||
transforms = BoardTransform.get_all_transforms(self.test_board)
|
||
assert len(transforms) == 8, "应该生成8种变换"
|
||
|
||
def test_all_transforms_uniqueness(self):
|
||
"""测试所有变换是否唯一(对于非对称矩阵)"""
|
||
transforms = BoardTransform.get_all_transforms(self.test_board)
|
||
|
||
# 将每个变换转换为字符串进行比较
|
||
transform_strings = [str(t.flatten()) for t in transforms]
|
||
unique_transforms = set(transform_strings)
|
||
|
||
assert len(unique_transforms) == 8, "对于非对称矩阵,8种变换应该都不相同"
|
||
|
||
def test_transform_properties(self):
|
||
"""测试变换的数学性质"""
|
||
board = self.test_board
|
||
|
||
# 测试4次90度旋转应该回到原始状态
|
||
result = board.copy()
|
||
for _ in range(4):
|
||
result = BoardTransform.rotate_90(result)
|
||
np.testing.assert_array_equal(result, board)
|
||
|
||
# 测试两次水平翻转应该回到原始状态
|
||
flipped = BoardTransform.flip_horizontal(board)
|
||
double_flipped = BoardTransform.flip_horizontal(flipped)
|
||
np.testing.assert_array_equal(double_flipped, board)
|
||
|
||
def test_canonical_form_consistency(self):
|
||
"""测试规范形式的一致性"""
|
||
board = self.test_board
|
||
transforms = BoardTransform.get_all_transforms(board)
|
||
|
||
# 所有变换的规范形式应该相同
|
||
canonical_forms = []
|
||
transform_indices = []
|
||
|
||
for transform in transforms:
|
||
canonical, idx = BoardTransform.get_canonical_form(transform)
|
||
canonical_forms.append(canonical)
|
||
transform_indices.append(idx)
|
||
|
||
# 所有规范形式应该相同
|
||
first_canonical = canonical_forms[0]
|
||
for canonical in canonical_forms[1:]:
|
||
np.testing.assert_array_equal(canonical, first_canonical)
|
||
|
||
def test_hash_consistency(self):
|
||
"""测试哈希值的一致性"""
|
||
board = self.test_board
|
||
transforms = BoardTransform.get_all_transforms(board)
|
||
|
||
# 所有变换的哈希值应该相同
|
||
hashes = [BoardTransform.compute_hash(transform) for transform in transforms]
|
||
|
||
first_hash = hashes[0]
|
||
for hash_val in hashes[1:]:
|
||
assert hash_val == first_hash, "所有等价变换的哈希值应该相同"
|
||
|
||
def test_symmetric_board(self):
|
||
"""测试对称棋盘的情况"""
|
||
# 创建一个完全对称的棋盘
|
||
symmetric_board = np.array([
|
||
[1, 2, 2, 1],
|
||
[2, 3, 3, 2],
|
||
[2, 3, 3, 2],
|
||
[1, 2, 2, 1]
|
||
])
|
||
|
||
transforms = BoardTransform.get_all_transforms(symmetric_board)
|
||
|
||
# 对于这个特殊的对称棋盘,某些变换可能相同
|
||
# 但规范形式应该仍然一致
|
||
canonical, _ = BoardTransform.get_canonical_form(symmetric_board)
|
||
|
||
for transform in transforms:
|
||
transform_canonical, _ = BoardTransform.get_canonical_form(transform)
|
||
np.testing.assert_array_equal(transform_canonical, canonical)
|
||
|
||
def test_edge_cases(self):
|
||
"""测试边界情况"""
|
||
# 测试全零矩阵
|
||
zero_board = np.zeros((4, 4), dtype=int)
|
||
canonical_zero, _ = BoardTransform.get_canonical_form(zero_board)
|
||
np.testing.assert_array_equal(canonical_zero, zero_board)
|
||
|
||
# 测试单元素矩阵
|
||
single_element = np.array([[1]])
|
||
canonical_single, _ = BoardTransform.get_canonical_form(single_element)
|
||
np.testing.assert_array_equal(canonical_single, single_element)
|
||
|
||
# 测试1x4矩阵
|
||
row_matrix = np.array([[1, 2, 3, 4]])
|
||
transforms_row = BoardTransform.get_all_transforms(row_matrix)
|
||
assert len(transforms_row) == 8
|
||
|
||
def test_different_board_sizes(self):
|
||
"""测试不同大小的棋盘"""
|
||
# 测试3x3棋盘
|
||
board_3x3 = np.array([
|
||
[1, 2, 3],
|
||
[4, 5, 6],
|
||
[7, 8, 9]
|
||
])
|
||
|
||
transforms_3x3 = BoardTransform.get_all_transforms(board_3x3)
|
||
assert len(transforms_3x3) == 8
|
||
|
||
# 验证规范形式一致性
|
||
hashes_3x3 = [BoardTransform.compute_hash(t) for t in transforms_3x3]
|
||
assert all(h == hashes_3x3[0] for h in hashes_3x3)
|
||
|
||
# 测试2x3矩形棋盘
|
||
board_2x3 = np.array([
|
||
[1, 2, 3],
|
||
[4, 5, 6]
|
||
])
|
||
|
||
transforms_2x3 = BoardTransform.get_all_transforms(board_2x3)
|
||
assert len(transforms_2x3) == 8
|
||
|
||
# 验证规范形式一致性
|
||
hashes_2x3 = [BoardTransform.compute_hash(t) for t in transforms_2x3]
|
||
assert all(h == hashes_2x3[0] for h in hashes_2x3)
|
||
|
||
|
||
def test_manual_verification():
|
||
"""手动验证一些关键变换"""
|
||
# 创建一个简单的测试用例进行手动验证
|
||
board = np.array([
|
||
[1, 2],
|
||
[3, 4]
|
||
])
|
||
|
||
transforms = BoardTransform.get_all_transforms(board)
|
||
|
||
# 预期的8种变换结果
|
||
expected_transforms = [
|
||
np.array([[1, 2], [3, 4]]), # R0: 原始
|
||
np.array([[3, 1], [4, 2]]), # R90: 旋转90°
|
||
np.array([[4, 3], [2, 1]]), # R180: 旋转180°
|
||
np.array([[2, 4], [1, 3]]), # R270: 旋转270°
|
||
np.array([[2, 1], [4, 3]]), # F: 水平翻转
|
||
np.array([[4, 2], [3, 1]]), # F+R90: 翻转后旋转90°
|
||
np.array([[3, 4], [1, 2]]), # F+R180: 翻转后旋转180°
|
||
np.array([[1, 3], [2, 4]]) # F+R270: 翻转后旋转270°
|
||
]
|
||
|
||
print("手动验证2x2矩阵的8种变换:")
|
||
print(f"原始矩阵:\n{board}")
|
||
|
||
for i, (actual, expected) in enumerate(zip(transforms, expected_transforms)):
|
||
print(f"\n变换 {i}:")
|
||
print(f"实际结果:\n{actual}")
|
||
print(f"预期结果:\n{expected}")
|
||
np.testing.assert_array_equal(actual, expected,
|
||
err_msg=f"变换 {i} 不匹配")
|
||
|
||
print("\n所有变换验证通过!")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 运行手动验证
|
||
test_manual_verification()
|
||
|
||
# 运行pytest测试
|
||
pytest.main([__file__, "-v"])
|