增加L0训练阶段的MCTS部分
This commit is contained in:
5
tests/__init__.py
Normal file
5
tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
测试模块
|
||||
|
||||
包含所有的测试文件和基准测试
|
||||
"""
|
||||
100
tests/run_all_tests.py
Normal file
100
tests/run_all_tests.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
统一测试运行器
|
||||
|
||||
运行所有测试并生成报告
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""运行所有测试"""
|
||||
print("Deep2048 项目测试套件")
|
||||
print("=" * 50)
|
||||
|
||||
test_dir = Path(__file__).parent
|
||||
|
||||
# 测试文件列表
|
||||
test_files = [
|
||||
"test_training_data.py",
|
||||
"test_game_engine.py",
|
||||
"test_torch_mcts.py",
|
||||
"test_board_compression.py",
|
||||
"test_cache_system.py",
|
||||
"test_persistence.py",
|
||||
"test_performance_benchmark.py"
|
||||
]
|
||||
|
||||
# 检查测试文件是否存在
|
||||
existing_tests = []
|
||||
for test_file in test_files:
|
||||
test_path = test_dir / test_file
|
||||
if test_path.exists():
|
||||
existing_tests.append(str(test_path))
|
||||
else:
|
||||
print(f"警告: 测试文件不存在 {test_file}")
|
||||
|
||||
if not existing_tests:
|
||||
print("错误: 没有找到测试文件")
|
||||
return False
|
||||
|
||||
print(f"找到 {len(existing_tests)} 个测试文件")
|
||||
|
||||
# 运行测试
|
||||
start_time = time.time()
|
||||
|
||||
# pytest参数
|
||||
args = [
|
||||
"-v", # 详细输出
|
||||
"--tb=short", # 简短的错误回溯
|
||||
"--durations=10", # 显示最慢的10个测试
|
||||
] + existing_tests
|
||||
|
||||
result = pytest.main(args)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
print(f"\n测试完成,用时: {elapsed_time:.2f}秒")
|
||||
|
||||
if result == 0:
|
||||
print("✅ 所有测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("❌ 部分测试失败")
|
||||
return False
|
||||
|
||||
|
||||
def run_quick_tests():
|
||||
"""运行快速测试(跳过性能测试)"""
|
||||
print("快速测试模式")
|
||||
print("=" * 30)
|
||||
|
||||
test_dir = Path(__file__).parent
|
||||
|
||||
args = [
|
||||
"-v",
|
||||
"-k", "not performance and not slow", # 跳过性能测试
|
||||
str(test_dir)
|
||||
]
|
||||
|
||||
result = pytest.main(args)
|
||||
return result == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="运行Deep2048测试套件")
|
||||
parser.add_argument("--quick", action="store_true", help="快速测试模式")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quick:
|
||||
success = run_quick_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
251
tests/test_board_compression.py
Normal file
251
tests/test_board_compression.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
棋盘压缩算法测试
|
||||
|
||||
验证二面体群D4变换和规范化的正确性
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from training_data import BoardTransform
|
||||
|
||||
|
||||
class TestBoardTransform:
|
||||
"""棋盘变换测试类"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
# 创建一个非对称的测试棋盘,便于验证变换
|
||||
self.test_board = np.array([
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
])
|
||||
|
||||
# 创建一个简单的2x2棋盘用于手动验证
|
||||
self.simple_board = np.array([
|
||||
[1, 2],
|
||||
[3, 4]
|
||||
])
|
||||
|
||||
def test_log_transform(self):
|
||||
"""测试对数变换"""
|
||||
# 测试正常情况
|
||||
board = np.array([
|
||||
[2, 4, 8, 16],
|
||||
[0, 2, 4, 8],
|
||||
[0, 0, 2, 4],
|
||||
[0, 0, 0, 2]
|
||||
])
|
||||
|
||||
expected = np.array([
|
||||
[1, 2, 3, 4],
|
||||
[0, 1, 2, 3],
|
||||
[0, 0, 1, 2],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
|
||||
result = BoardTransform.log_transform(board)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
# 测试逆变换
|
||||
restored = BoardTransform.inverse_log_transform(result)
|
||||
np.testing.assert_array_equal(restored, board)
|
||||
|
||||
def test_rotate_90(self):
|
||||
"""测试90度旋转"""
|
||||
# 手动验证2x2矩阵的90度顺时针旋转
|
||||
# [1, 2] -> [3, 1]
|
||||
# [3, 4] [4, 2]
|
||||
|
||||
expected = np.array([
|
||||
[3, 1],
|
||||
[4, 2]
|
||||
])
|
||||
|
||||
result = BoardTransform.rotate_90(self.simple_board)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
def test_flip_horizontal(self):
|
||||
"""测试水平翻转"""
|
||||
# 手动验证2x2矩阵的水平翻转
|
||||
# [1, 2] -> [2, 1]
|
||||
# [3, 4] [4, 3]
|
||||
|
||||
expected = np.array([
|
||||
[2, 1],
|
||||
[4, 3]
|
||||
])
|
||||
|
||||
result = BoardTransform.flip_horizontal(self.simple_board)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
def test_all_transforms_count(self):
|
||||
"""测试是否生成了正确数量的变换"""
|
||||
transforms = BoardTransform.get_all_transforms(self.test_board)
|
||||
assert len(transforms) == 8, "应该生成8种变换"
|
||||
|
||||
def test_all_transforms_uniqueness(self):
|
||||
"""测试所有变换是否唯一(对于非对称矩阵)"""
|
||||
transforms = BoardTransform.get_all_transforms(self.test_board)
|
||||
|
||||
# 将每个变换转换为字符串进行比较
|
||||
transform_strings = [str(t.flatten()) for t in transforms]
|
||||
unique_transforms = set(transform_strings)
|
||||
|
||||
assert len(unique_transforms) == 8, "对于非对称矩阵,8种变换应该都不相同"
|
||||
|
||||
def test_transform_properties(self):
|
||||
"""测试变换的数学性质"""
|
||||
board = self.test_board
|
||||
|
||||
# 测试4次90度旋转应该回到原始状态
|
||||
result = board.copy()
|
||||
for _ in range(4):
|
||||
result = BoardTransform.rotate_90(result)
|
||||
np.testing.assert_array_equal(result, board)
|
||||
|
||||
# 测试两次水平翻转应该回到原始状态
|
||||
flipped = BoardTransform.flip_horizontal(board)
|
||||
double_flipped = BoardTransform.flip_horizontal(flipped)
|
||||
np.testing.assert_array_equal(double_flipped, board)
|
||||
|
||||
def test_canonical_form_consistency(self):
|
||||
"""测试规范形式的一致性"""
|
||||
board = self.test_board
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
|
||||
# 所有变换的规范形式应该相同
|
||||
canonical_forms = []
|
||||
transform_indices = []
|
||||
|
||||
for transform in transforms:
|
||||
canonical, idx = BoardTransform.get_canonical_form(transform)
|
||||
canonical_forms.append(canonical)
|
||||
transform_indices.append(idx)
|
||||
|
||||
# 所有规范形式应该相同
|
||||
first_canonical = canonical_forms[0]
|
||||
for canonical in canonical_forms[1:]:
|
||||
np.testing.assert_array_equal(canonical, first_canonical)
|
||||
|
||||
def test_hash_consistency(self):
|
||||
"""测试哈希值的一致性"""
|
||||
board = self.test_board
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
|
||||
# 所有变换的哈希值应该相同
|
||||
hashes = [BoardTransform.compute_hash(transform) for transform in transforms]
|
||||
|
||||
first_hash = hashes[0]
|
||||
for hash_val in hashes[1:]:
|
||||
assert hash_val == first_hash, "所有等价变换的哈希值应该相同"
|
||||
|
||||
def test_symmetric_board(self):
|
||||
"""测试对称棋盘的情况"""
|
||||
# 创建一个完全对称的棋盘
|
||||
symmetric_board = np.array([
|
||||
[1, 2, 2, 1],
|
||||
[2, 3, 3, 2],
|
||||
[2, 3, 3, 2],
|
||||
[1, 2, 2, 1]
|
||||
])
|
||||
|
||||
transforms = BoardTransform.get_all_transforms(symmetric_board)
|
||||
|
||||
# 对于这个特殊的对称棋盘,某些变换可能相同
|
||||
# 但规范形式应该仍然一致
|
||||
canonical, _ = BoardTransform.get_canonical_form(symmetric_board)
|
||||
|
||||
for transform in transforms:
|
||||
transform_canonical, _ = BoardTransform.get_canonical_form(transform)
|
||||
np.testing.assert_array_equal(transform_canonical, canonical)
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""测试边界情况"""
|
||||
# 测试全零矩阵
|
||||
zero_board = np.zeros((4, 4), dtype=int)
|
||||
canonical_zero, _ = BoardTransform.get_canonical_form(zero_board)
|
||||
np.testing.assert_array_equal(canonical_zero, zero_board)
|
||||
|
||||
# 测试单元素矩阵
|
||||
single_element = np.array([[1]])
|
||||
canonical_single, _ = BoardTransform.get_canonical_form(single_element)
|
||||
np.testing.assert_array_equal(canonical_single, single_element)
|
||||
|
||||
# 测试1x4矩阵
|
||||
row_matrix = np.array([[1, 2, 3, 4]])
|
||||
transforms_row = BoardTransform.get_all_transforms(row_matrix)
|
||||
assert len(transforms_row) == 8
|
||||
|
||||
def test_different_board_sizes(self):
|
||||
"""测试不同大小的棋盘"""
|
||||
# 测试3x3棋盘
|
||||
board_3x3 = np.array([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
[7, 8, 9]
|
||||
])
|
||||
|
||||
transforms_3x3 = BoardTransform.get_all_transforms(board_3x3)
|
||||
assert len(transforms_3x3) == 8
|
||||
|
||||
# 验证规范形式一致性
|
||||
hashes_3x3 = [BoardTransform.compute_hash(t) for t in transforms_3x3]
|
||||
assert all(h == hashes_3x3[0] for h in hashes_3x3)
|
||||
|
||||
# 测试2x3矩形棋盘
|
||||
board_2x3 = np.array([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
])
|
||||
|
||||
transforms_2x3 = BoardTransform.get_all_transforms(board_2x3)
|
||||
assert len(transforms_2x3) == 8
|
||||
|
||||
# 验证规范形式一致性
|
||||
hashes_2x3 = [BoardTransform.compute_hash(t) for t in transforms_2x3]
|
||||
assert all(h == hashes_2x3[0] for h in hashes_2x3)
|
||||
|
||||
|
||||
def test_manual_verification():
|
||||
"""手动验证一些关键变换"""
|
||||
# 创建一个简单的测试用例进行手动验证
|
||||
board = np.array([
|
||||
[1, 2],
|
||||
[3, 4]
|
||||
])
|
||||
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
|
||||
# 预期的8种变换结果
|
||||
expected_transforms = [
|
||||
np.array([[1, 2], [3, 4]]), # R0: 原始
|
||||
np.array([[3, 1], [4, 2]]), # R90: 旋转90°
|
||||
np.array([[4, 3], [2, 1]]), # R180: 旋转180°
|
||||
np.array([[2, 4], [1, 3]]), # R270: 旋转270°
|
||||
np.array([[2, 1], [4, 3]]), # F: 水平翻转
|
||||
np.array([[4, 2], [3, 1]]), # F+R90: 翻转后旋转90°
|
||||
np.array([[3, 4], [1, 2]]), # F+R180: 翻转后旋转180°
|
||||
np.array([[1, 3], [2, 4]]) # F+R270: 翻转后旋转270°
|
||||
]
|
||||
|
||||
print("手动验证2x2矩阵的8种变换:")
|
||||
print(f"原始矩阵:\n{board}")
|
||||
|
||||
for i, (actual, expected) in enumerate(zip(transforms, expected_transforms)):
|
||||
print(f"\n变换 {i}:")
|
||||
print(f"实际结果:\n{actual}")
|
||||
print(f"预期结果:\n{expected}")
|
||||
np.testing.assert_array_equal(actual, expected,
|
||||
err_msg=f"变换 {i} 不匹配")
|
||||
|
||||
print("\n所有变换验证通过!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行手动验证
|
||||
test_manual_verification()
|
||||
|
||||
# 运行pytest测试
|
||||
pytest.main([__file__, "-v"])
|
||||
311
tests/test_cache_system.py
Normal file
311
tests/test_cache_system.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
内存缓存系统测试
|
||||
|
||||
验证TrainingDataCache的功能和性能
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import time
|
||||
from training_data import TrainingDataCache, TrainingExample
|
||||
|
||||
|
||||
class TestTrainingDataCache:
|
||||
"""训练数据缓存测试类"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.cache = TrainingDataCache(max_size=5) # 小缓存便于测试
|
||||
|
||||
# 创建测试样本
|
||||
self.sample_examples = []
|
||||
for i in range(10):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i * 100),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
self.sample_examples.append(example)
|
||||
|
||||
def test_basic_operations(self):
|
||||
"""测试基本的存取操作"""
|
||||
# 测试空缓存
|
||||
assert self.cache.size() == 0
|
||||
assert self.cache.get("nonexistent") is None
|
||||
|
||||
# 添加一个项目
|
||||
example = self.sample_examples[0]
|
||||
self.cache.put("key1", example)
|
||||
|
||||
assert self.cache.size() == 1
|
||||
retrieved = self.cache.get("key1")
|
||||
assert retrieved is not None
|
||||
assert retrieved.value == example.value
|
||||
assert retrieved.action == example.action
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""测试LRU淘汰机制"""
|
||||
# 填满缓存
|
||||
for i in range(5):
|
||||
self.cache.put(f"key_{i}", self.sample_examples[i])
|
||||
|
||||
assert self.cache.size() == 5
|
||||
|
||||
# 访问key_1,使其成为最近使用的
|
||||
self.cache.get("key_1")
|
||||
|
||||
# 添加新项目,应该淘汰key_0(最久未使用)
|
||||
self.cache.put("key_5", self.sample_examples[5])
|
||||
|
||||
assert self.cache.size() == 5
|
||||
assert self.cache.get("key_0") is None # 应该被淘汰
|
||||
assert self.cache.get("key_1") is not None # 应该还在
|
||||
assert self.cache.get("key_5") is not None # 新添加的
|
||||
|
||||
def test_update_existing(self):
|
||||
"""测试更新现有项目"""
|
||||
example1 = self.sample_examples[0]
|
||||
example2 = self.sample_examples[1]
|
||||
|
||||
# 添加项目
|
||||
self.cache.put("key1", example1)
|
||||
assert self.cache.get("key1").value == example1.value
|
||||
|
||||
# 更新项目
|
||||
self.cache.put("key1", example2)
|
||||
assert self.cache.size() == 1 # 大小不变
|
||||
assert self.cache.get("key1").value == example2.value # 值已更新
|
||||
|
||||
def test_update_if_better(self):
|
||||
"""测试条件更新功能"""
|
||||
low_value_example = TrainingExample(
|
||||
board_state=np.zeros((4, 4)),
|
||||
action=0,
|
||||
value=100.0,
|
||||
canonical_hash="test_hash"
|
||||
)
|
||||
|
||||
high_value_example = TrainingExample(
|
||||
board_state=np.zeros((4, 4)),
|
||||
action=0,
|
||||
value=200.0,
|
||||
canonical_hash="test_hash"
|
||||
)
|
||||
|
||||
# 首次添加
|
||||
result = self.cache.update_if_better("key1", low_value_example)
|
||||
assert result is True
|
||||
assert self.cache.get("key1").value == 100.0
|
||||
|
||||
# 用更高价值更新
|
||||
result = self.cache.update_if_better("key1", high_value_example)
|
||||
assert result is True
|
||||
assert self.cache.get("key1").value == 200.0
|
||||
|
||||
# 用更低价值尝试更新(应该失败)
|
||||
result = self.cache.update_if_better("key1", low_value_example)
|
||||
assert result is False
|
||||
assert self.cache.get("key1").value == 200.0 # 值不变
|
||||
|
||||
def test_clear(self):
|
||||
"""测试清空缓存"""
|
||||
# 添加一些项目
|
||||
for i in range(3):
|
||||
self.cache.put(f"key_{i}", self.sample_examples[i])
|
||||
|
||||
assert self.cache.size() == 3
|
||||
|
||||
# 清空缓存
|
||||
self.cache.clear()
|
||||
|
||||
assert self.cache.size() == 0
|
||||
for i in range(3):
|
||||
assert self.cache.get(f"key_{i}") is None
|
||||
|
||||
def test_get_all_examples(self):
|
||||
"""测试获取所有样本"""
|
||||
# 添加一些项目
|
||||
added_examples = []
|
||||
for i in range(3):
|
||||
example = self.sample_examples[i]
|
||||
self.cache.put(f"key_{i}", example)
|
||||
added_examples.append(example)
|
||||
|
||||
# 获取所有样本
|
||||
all_examples = self.cache.get_all_examples()
|
||||
|
||||
assert len(all_examples) == 3
|
||||
|
||||
# 验证所有样本都在其中(顺序可能不同)
|
||||
all_values = {ex.value for ex in all_examples}
|
||||
expected_values = {ex.value for ex in added_examples}
|
||||
assert all_values == expected_values
|
||||
|
||||
def test_access_order_tracking(self):
|
||||
"""测试访问顺序跟踪"""
|
||||
# 先填满缓存
|
||||
for i in range(5):
|
||||
self.cache.put(f"key_{i}", self.sample_examples[i])
|
||||
|
||||
# 访问key_1,使其成为最近使用的
|
||||
self.cache.get("key_1")
|
||||
|
||||
# 访问key_3
|
||||
self.cache.get("key_3")
|
||||
|
||||
# 现在访问顺序应该是:key_0(最久), key_2, key_4, key_1, key_3(最新)
|
||||
|
||||
# 添加两个新项目,应该淘汰key_0和key_2
|
||||
self.cache.put("key_5", self.sample_examples[5])
|
||||
self.cache.put("key_6", self.sample_examples[6])
|
||||
|
||||
assert self.cache.get("key_0") is None # 最久的,应该被淘汰
|
||||
assert self.cache.get("key_2") is None # 第二久的,应该被淘汰
|
||||
assert self.cache.get("key_1") is not None # 应该还在
|
||||
assert self.cache.get("key_3") is not None # 应该还在
|
||||
assert self.cache.get("key_4") is not None # 应该还在
|
||||
assert self.cache.get("key_5") is not None # 新添加的
|
||||
assert self.cache.get("key_6") is not None # 新添加的
|
||||
|
||||
|
||||
class TestCachePerformance:
|
||||
"""缓存性能测试"""
|
||||
|
||||
def test_large_cache_performance(self):
|
||||
"""测试大缓存的性能"""
|
||||
large_cache = TrainingDataCache(max_size=10000)
|
||||
|
||||
# 创建大量测试数据
|
||||
examples = []
|
||||
for i in range(5000):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
examples.append(example)
|
||||
|
||||
# 测试插入性能
|
||||
start_time = time.time()
|
||||
for i, example in enumerate(examples):
|
||||
large_cache.put(f"key_{i}", example)
|
||||
insert_time = time.time() - start_time
|
||||
|
||||
print(f"插入5000个项目耗时: {insert_time:.3f}秒")
|
||||
assert insert_time < 1.0, "插入操作应该很快"
|
||||
|
||||
# 测试查询性能
|
||||
start_time = time.time()
|
||||
for i in range(1000): # 随机查询1000次
|
||||
key = f"key_{np.random.randint(0, 5000)}"
|
||||
large_cache.get(key)
|
||||
query_time = time.time() - start_time
|
||||
|
||||
print(f"1000次随机查询耗时: {query_time:.3f}秒")
|
||||
assert query_time < 0.1, "查询操作应该很快"
|
||||
|
||||
# 验证缓存大小
|
||||
assert large_cache.size() == 5000
|
||||
|
||||
def test_memory_usage(self):
|
||||
"""测试内存使用情况"""
|
||||
import sys
|
||||
|
||||
cache = TrainingDataCache(max_size=1000)
|
||||
|
||||
# 测量空缓存的内存使用
|
||||
initial_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)
|
||||
|
||||
# 添加数据
|
||||
for i in range(500):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
cache.put(f"key_{i}", example)
|
||||
|
||||
# 测量填充后的内存使用
|
||||
filled_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)
|
||||
|
||||
print(f"空缓存内存使用: {initial_size} bytes")
|
||||
print(f"500项目缓存内存使用: {filled_size} bytes")
|
||||
print(f"平均每项目内存使用: {(filled_size - initial_size) / 500:.2f} bytes")
|
||||
|
||||
|
||||
def test_cache_thread_safety():
|
||||
"""测试缓存的线程安全性(基础测试)"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
cache = TrainingDataCache(max_size=1000)
|
||||
errors = []
|
||||
|
||||
def worker(worker_id):
|
||||
"""工作线程函数"""
|
||||
try:
|
||||
for i in range(100):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(worker_id * 100 + i),
|
||||
canonical_hash=f"hash_{worker_id}_{i}"
|
||||
)
|
||||
|
||||
key = f"worker_{worker_id}_key_{i}"
|
||||
cache.put(key, example)
|
||||
|
||||
# 随机读取
|
||||
if i % 10 == 0:
|
||||
cache.get(key)
|
||||
|
||||
# 短暂休眠
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(f"Worker {worker_id}: {e}")
|
||||
|
||||
# 创建多个线程
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=worker, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# 等待所有线程完成
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# 检查是否有错误
|
||||
if errors:
|
||||
print("线程安全测试中的错误:")
|
||||
for error in errors:
|
||||
print(f" {error}")
|
||||
|
||||
print(f"最终缓存大小: {cache.size()}")
|
||||
print(f"线程安全测试完成,错误数: {len(errors)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行基本测试
|
||||
print("运行缓存系统测试...")
|
||||
|
||||
# 运行性能测试
|
||||
print("\n运行性能测试...")
|
||||
perf_test = TestCachePerformance()
|
||||
perf_test.test_large_cache_performance()
|
||||
perf_test.test_memory_usage()
|
||||
|
||||
# 运行线程安全测试
|
||||
print("\n运行线程安全测试...")
|
||||
test_cache_thread_safety()
|
||||
|
||||
# 运行pytest测试
|
||||
print("\n运行pytest测试...")
|
||||
pytest.main([__file__, "-v"])
|
||||
289
tests/test_game_engine.py
Normal file
289
tests/test_game_engine.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
2048游戏引擎测试
|
||||
|
||||
验证新游戏引擎的功能和正确性
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from game import Game2048, GameState
|
||||
|
||||
|
||||
class TestGame2048:
|
||||
"""2048游戏引擎测试类"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.game = Game2048(height=4, width=4, seed=42)
|
||||
|
||||
def test_initialization(self):
|
||||
"""测试游戏初始化"""
|
||||
game = Game2048(height=3, width=4, seed=123)
|
||||
|
||||
assert game.height == 3
|
||||
assert game.width == 4
|
||||
assert game.score == 0
|
||||
assert game.moves == 0
|
||||
assert not game.is_over
|
||||
|
||||
# 应该有两个初始数字
|
||||
non_zero_count = np.count_nonzero(game.board)
|
||||
assert non_zero_count == 2
|
||||
|
||||
# 初始数字应该是1或2(对数形式的2或4)
|
||||
non_zero_values = game.board[game.board != 0]
|
||||
assert all(val in [1, 2] for val in non_zero_values)
|
||||
|
||||
def test_move_row_left(self):
|
||||
"""测试行向左移动逻辑"""
|
||||
# 测试简单移动
|
||||
row = np.array([0, 1, 0, 2])
|
||||
result, score = self.game._move_row_left(row)
|
||||
expected = np.array([1, 2, 0, 0])
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
assert score == 0
|
||||
|
||||
# 测试合并
|
||||
row = np.array([1, 1, 2, 2])
|
||||
result, score = self.game._move_row_left(row)
|
||||
expected = np.array([2, 3, 0, 0])
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
# 分数应该是 2^2 + 2^3 = 4 + 8 = 12
|
||||
assert score == 12
|
||||
|
||||
# 测试复杂情况
|
||||
row = np.array([1, 1, 1, 1])
|
||||
result, score = self.game._move_row_left(row)
|
||||
expected = np.array([2, 2, 0, 0])
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
# 分数应该是 2^2 + 2^2 = 4 + 4 = 8
|
||||
assert score == 8
|
||||
|
||||
def test_move_directions(self):
|
||||
"""测试四个方向的移动"""
|
||||
# 创建特定的棋盘状态
|
||||
game = Game2048(height=3, width=3, seed=42)
|
||||
game.board = np.array([
|
||||
[1, 0, 1],
|
||||
[0, 2, 0],
|
||||
[1, 0, 1]
|
||||
])
|
||||
|
||||
initial_score = game.score
|
||||
|
||||
# 测试向左移动
|
||||
game_left = game.copy()
|
||||
success = game_left.move(2) # 左
|
||||
assert success
|
||||
|
||||
# 测试向右移动
|
||||
game_right = game.copy()
|
||||
success = game_right.move(3) # 右
|
||||
assert success
|
||||
|
||||
# 测试向上移动
|
||||
game_up = game.copy()
|
||||
success = game_up.move(0) # 上
|
||||
assert success
|
||||
|
||||
# 测试向下移动
|
||||
game_down = game.copy()
|
||||
success = game_down.move(1) # 下
|
||||
assert success
|
||||
|
||||
# 所有移动都应该改变棋盘状态
|
||||
assert not np.array_equal(game.board, game_left.board)
|
||||
assert not np.array_equal(game.board, game_right.board)
|
||||
assert not np.array_equal(game.board, game_up.board)
|
||||
assert not np.array_equal(game.board, game_down.board)
|
||||
|
||||
def test_score_calculation(self):
|
||||
"""测试分数计算"""
|
||||
game = Game2048(height=2, width=2, seed=42)
|
||||
|
||||
# 设置特定棋盘状态
|
||||
game.board = np.array([
|
||||
[1, 2], # 2, 4
|
||||
[3, 4] # 8, 16
|
||||
])
|
||||
|
||||
# 计算累积分数
|
||||
total_score = game.calculate_total_score()
|
||||
|
||||
# 根据论文公式:V(N) = (log2(N) - 1) * N
|
||||
# V(2) = 0, V(4) = 4, V(8) = 16, V(16) = 48
|
||||
expected = 0 + 4 + 16 + 48
|
||||
assert total_score == expected
|
||||
|
||||
def test_game_over_detection(self):
|
||||
"""测试游戏结束检测"""
|
||||
game = Game2048(height=2, width=2, seed=42)
|
||||
|
||||
# 设置无法移动的棋盘
|
||||
game.board = np.array([
|
||||
[1, 2], # 2, 4
|
||||
[3, 4] # 8, 16
|
||||
])
|
||||
|
||||
game._check_game_over()
|
||||
assert game.is_over
|
||||
|
||||
# 测试可以移动的棋盘
|
||||
game.board = np.array([
|
||||
[1, 1], # 2, 2 (可以合并)
|
||||
[3, 4] # 8, 16
|
||||
])
|
||||
game.is_over = False
|
||||
|
||||
game._check_game_over()
|
||||
assert not game.is_over
|
||||
|
||||
def test_valid_moves(self):
|
||||
"""测试有效移动检测"""
|
||||
game = Game2048(height=2, width=2, seed=42)
|
||||
|
||||
# 设置可以向所有方向移动的棋盘
|
||||
game.board = np.array([
|
||||
[1, 0],
|
||||
[0, 1]
|
||||
])
|
||||
|
||||
valid_moves = game.get_valid_moves()
|
||||
assert len(valid_moves) == 4 # 所有方向都可以移动
|
||||
|
||||
# 设置无法移动的棋盘
|
||||
game.board = np.array([
|
||||
[1, 2],
|
||||
[3, 4]
|
||||
])
|
||||
|
||||
valid_moves = game.get_valid_moves()
|
||||
assert len(valid_moves) == 0 # 无法移动
|
||||
|
||||
def test_board_display(self):
|
||||
"""测试棋盘显示"""
|
||||
game = Game2048(height=2, width=2, seed=42)
|
||||
|
||||
# 设置对数形式的棋盘
|
||||
game.board = np.array([
|
||||
[0, 1], # 0, 2
|
||||
[2, 3] # 4, 8
|
||||
])
|
||||
|
||||
display_board = game.get_board_display()
|
||||
expected = np.array([
|
||||
[0, 2],
|
||||
[4, 8]
|
||||
])
|
||||
|
||||
np.testing.assert_array_equal(display_board, expected)
|
||||
|
||||
def test_max_tile(self):
|
||||
"""测试最大数字获取"""
|
||||
game = Game2048(height=2, width=2, seed=42)
|
||||
|
||||
game.board = np.array([
|
||||
[1, 2], # 2, 4
|
||||
[3, 4] # 8, 16
|
||||
])
|
||||
|
||||
max_tile = game.get_max_tile()
|
||||
assert max_tile == 16
|
||||
|
||||
def test_state_management(self):
|
||||
"""测试游戏状态管理"""
|
||||
game = Game2048(height=2, width=2, seed=42)
|
||||
|
||||
# 获取初始状态
|
||||
initial_state = game.get_state()
|
||||
assert isinstance(initial_state, GameState)
|
||||
assert initial_state.score == game.score
|
||||
assert initial_state.moves == game.moves
|
||||
assert np.array_equal(initial_state.board, game.board)
|
||||
|
||||
# 执行移动
|
||||
move_success = game.move(2) # 左移
|
||||
|
||||
# 获取新状态
|
||||
new_state = game.get_state()
|
||||
|
||||
# 只有移动成功时才检查移动次数
|
||||
if move_success:
|
||||
assert new_state.moves == initial_state.moves + 1
|
||||
assert not np.array_equal(new_state.board, initial_state.board)
|
||||
else:
|
||||
# 如果移动失败,尝试其他方向
|
||||
for direction in range(4):
|
||||
if game.move(direction):
|
||||
new_state = game.get_state()
|
||||
assert new_state.moves == initial_state.moves + 1
|
||||
assert not np.array_equal(new_state.board, initial_state.board)
|
||||
break
|
||||
|
||||
# 恢复状态
|
||||
game.set_state(initial_state)
|
||||
assert game.score == initial_state.score
|
||||
assert game.moves == initial_state.moves
|
||||
np.testing.assert_array_equal(game.board, initial_state.board)
|
||||
|
||||
def test_copy_functionality(self):
|
||||
"""测试游戏复制功能"""
|
||||
game = Game2048(height=3, width=3, seed=42)
|
||||
|
||||
# 执行一些操作
|
||||
game.move(2)
|
||||
game.move(0)
|
||||
|
||||
# 创建副本
|
||||
game_copy = game.copy()
|
||||
|
||||
# 验证副本
|
||||
assert game_copy.height == game.height
|
||||
assert game_copy.width == game.width
|
||||
assert game_copy.score == game.score
|
||||
assert game_copy.moves == game.moves
|
||||
assert game_copy.is_over == game.is_over
|
||||
np.testing.assert_array_equal(game_copy.board, game.board)
|
||||
|
||||
# 修改副本不应影响原游戏
|
||||
game_copy.move(1)
|
||||
assert game_copy.moves != game.moves
|
||||
|
||||
def test_different_board_sizes(self):
|
||||
"""测试不同大小的棋盘"""
|
||||
# 测试3x3棋盘
|
||||
game_3x3 = Game2048(height=3, width=3, seed=42)
|
||||
assert game_3x3.board.shape == (3, 3)
|
||||
|
||||
# 测试2x4矩形棋盘
|
||||
game_2x4 = Game2048(height=2, width=4, seed=42)
|
||||
assert game_2x4.board.shape == (2, 4)
|
||||
|
||||
# 测试移动功能
|
||||
success = game_3x3.move(2)
|
||||
assert isinstance(success, bool)
|
||||
|
||||
success = game_2x4.move(0)
|
||||
assert isinstance(success, bool)
|
||||
|
||||
def test_spawn_probability(self):
|
||||
"""测试数字生成概率"""
|
||||
# 测试只生成2的情况
|
||||
game_only_2 = Game2048(height=4, width=4, spawn_prob_4=0.0, seed=42)
|
||||
|
||||
# 重置并检查生成的数字
|
||||
game_only_2.reset()
|
||||
non_zero_values = game_only_2.board[game_only_2.board != 0]
|
||||
assert all(val == 1 for val in non_zero_values) # 只有1(对数形式的2)
|
||||
|
||||
# 测试只生成4的情况
|
||||
game_only_4 = Game2048(height=4, width=4, spawn_prob_4=1.0, seed=42)
|
||||
game_only_4.reset()
|
||||
non_zero_values = game_only_4.board[game_only_4.board != 0]
|
||||
assert all(val == 2 for val in non_zero_values) # 只有2(对数形式的4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
print("运行2048游戏引擎测试...")
|
||||
pytest.main([__file__, "-v"])
|
||||
210
tests/test_performance_benchmark.py
Normal file
210
tests/test_performance_benchmark.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
性能基准测试
|
||||
|
||||
测试不同MCTS实现的性能对比
|
||||
"""
|
||||
|
||||
import time
|
||||
import torch
|
||||
import pytest
|
||||
from game import Game2048
|
||||
from torch_mcts import TorchMCTS
|
||||
|
||||
|
||||
class TestPerformanceBenchmark:
|
||||
"""性能基准测试类"""
|
||||
|
||||
@pytest.fixture
|
||||
def game(self):
|
||||
"""测试游戏状态"""
|
||||
return Game2048(height=3, width=3, seed=42)
|
||||
|
||||
def test_cpu_mcts_performance(self, game):
|
||||
"""测试CPU MCTS性能"""
|
||||
mcts = TorchMCTS(
|
||||
c_param=1.414,
|
||||
max_simulation_depth=50,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
simulations = 2000
|
||||
start_time = time.time()
|
||||
action, stats = mcts.search(game, simulations)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
speed = simulations / elapsed_time
|
||||
|
||||
# CPU MCTS应该达到基本性能要求
|
||||
assert speed > 500, f"CPU MCTS性能过低: {speed:.1f} 模拟/秒"
|
||||
assert action in game.get_valid_moves()
|
||||
|
||||
def test_auto_device_mcts_performance(self, game):
|
||||
"""测试自动设备选择MCTS性能"""
|
||||
mcts = TorchMCTS(
|
||||
c_param=1.414,
|
||||
max_simulation_depth=50,
|
||||
device="auto"
|
||||
)
|
||||
|
||||
simulations = 2000
|
||||
start_time = time.time()
|
||||
action, stats = mcts.search(game, simulations)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
speed = simulations / elapsed_time
|
||||
|
||||
# 自动设备选择应该有合理性能
|
||||
assert speed > 100, f"自动设备MCTS性能过低: {speed:.1f} 模拟/秒"
|
||||
assert action in game.get_valid_moves()
|
||||
|
||||
if mcts.device.type == "cuda":
|
||||
del mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
|
||||
def test_gpu_mcts_performance(self, game):
|
||||
"""测试GPU MCTS性能"""
|
||||
gpu_mcts = TorchMCTS(
|
||||
max_simulation_depth=50,
|
||||
batch_size=8192,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
simulations = 5000
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
action, stats = gpu_mcts.search(game, simulations)
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
speed = simulations / elapsed_time
|
||||
|
||||
# GPU MCTS应该有显著性能提升
|
||||
assert speed > 200, f"GPU MCTS性能过低: {speed:.1f} 模拟/秒"
|
||||
assert action in game.get_valid_moves()
|
||||
|
||||
del gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
|
||||
def test_performance_comparison(self, game):
|
||||
"""性能对比测试"""
|
||||
simulations = 3000
|
||||
results = {}
|
||||
|
||||
# CPU MCTS
|
||||
cpu_mcts = TorchMCTS(c_param=1.414, max_simulation_depth=50, device="cpu")
|
||||
start_time = time.time()
|
||||
cpu_action, cpu_stats = cpu_mcts.search(game.copy(), simulations)
|
||||
cpu_time = time.time() - start_time
|
||||
results['CPU'] = simulations / cpu_time
|
||||
|
||||
# GPU MCTS
|
||||
gpu_mcts = TorchMCTS(max_simulation_depth=50, batch_size=8192, device="cuda")
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
gpu_action, gpu_stats = gpu_mcts.search(game.copy(), simulations)
|
||||
torch.cuda.synchronize()
|
||||
gpu_time = time.time() - start_time
|
||||
results['GPU'] = simulations / gpu_time
|
||||
|
||||
# 验证性能提升
|
||||
speedup = results['GPU'] / results['CPU']
|
||||
print(f"\n性能对比:")
|
||||
print(f" CPU: {results['CPU']:.1f} 模拟/秒")
|
||||
print(f" GPU: {results['GPU']:.1f} 模拟/秒")
|
||||
print(f" 加速比: {speedup:.1f}x")
|
||||
|
||||
# GPU应该有一定的性能优势(至少不能太慢)
|
||||
assert speedup > 0.1, f"GPU性能严重低于CPU: {speedup:.2f}x"
|
||||
|
||||
# 清理
|
||||
del cpu_mcts, gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_batch_size_scaling(self):
|
||||
"""测试批次大小对性能的影响"""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA不可用")
|
||||
|
||||
game = Game2048(height=3, width=3, seed=42)
|
||||
batch_sizes = [1024, 4096, 16384]
|
||||
simulations = 2000
|
||||
|
||||
results = {}
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
gpu_mcts = TorchMCTS(
|
||||
max_simulation_depth=50,
|
||||
batch_size=batch_size,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
action, stats = gpu_mcts.search(game.copy(), simulations)
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
speed = simulations / elapsed_time
|
||||
results[batch_size] = speed
|
||||
|
||||
del gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 验证批次大小的影响
|
||||
speeds = list(results.values())
|
||||
max_speed = max(speeds)
|
||||
min_speed = min(speeds)
|
||||
|
||||
# 不同批次大小的性能差异应该在合理范围内
|
||||
speed_ratio = max_speed / min_speed
|
||||
assert speed_ratio < 10, f"批次大小性能差异过大: {speed_ratio:.2f}"
|
||||
|
||||
print(f"\n批次大小性能测试:")
|
||||
for batch_size, speed in results.items():
|
||||
print(f" {batch_size:,}: {speed:.1f} 模拟/秒")
|
||||
|
||||
|
||||
def test_memory_efficiency():
|
||||
"""内存效率测试"""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA不可用")
|
||||
|
||||
game = Game2048(height=3, width=3, seed=42)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
initial_memory = torch.cuda.memory_allocated()
|
||||
|
||||
gpu_mcts = TorchMCTS(
|
||||
max_simulation_depth=50,
|
||||
batch_size=32768,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 执行搜索
|
||||
action, stats = gpu_mcts.search(game, 10000)
|
||||
|
||||
peak_memory = torch.cuda.max_memory_allocated()
|
||||
memory_used = (peak_memory - initial_memory) / 1e6 # MB
|
||||
|
||||
# 内存使用应该合理
|
||||
assert memory_used < 500, f"GPU内存使用过多: {memory_used:.1f} MB"
|
||||
|
||||
# 计算内存效率
|
||||
speed = stats['sims_per_second']
|
||||
memory_efficiency = speed / memory_used if memory_used > 0 else 0
|
||||
|
||||
print(f"\n内存效率测试:")
|
||||
print(f" 内存使用: {memory_used:.1f} MB")
|
||||
print(f" 模拟速度: {speed:.1f} 模拟/秒")
|
||||
print(f" 内存效率: {memory_efficiency:.1f} 模拟/秒/MB")
|
||||
|
||||
# 清理
|
||||
del gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
329
tests/test_persistence.py
Normal file
329
tests/test_persistence.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
硬盘持久化系统测试
|
||||
|
||||
验证TrainingDataPersistence的功能和可靠性
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
import os
|
||||
from pathlib import Path
|
||||
from training_data import (
|
||||
TrainingDataPersistence,
|
||||
TrainingDataCache,
|
||||
TrainingExample,
|
||||
TrainingDataManager
|
||||
)
|
||||
|
||||
|
||||
class TestTrainingDataPersistence:
|
||||
"""训练数据持久化测试类"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
# 创建临时目录用于测试
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.persistence = TrainingDataPersistence(self.temp_dir)
|
||||
|
||||
# 创建测试样本
|
||||
self.test_examples = []
|
||||
for i in range(10):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i * 100),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
self.test_examples.append(example)
|
||||
|
||||
def teardown_method(self):
|
||||
"""测试后的清理"""
|
||||
# 删除临时目录
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_save_and_load_cache(self):
|
||||
"""测试缓存的保存和加载"""
|
||||
# 创建缓存并添加数据
|
||||
cache = TrainingDataCache(max_size=100)
|
||||
for i, example in enumerate(self.test_examples[:5]):
|
||||
cache.put(f"key_{i}", example)
|
||||
|
||||
# 保存缓存
|
||||
filename = "test_cache"
|
||||
self.persistence.save_cache(cache, filename)
|
||||
|
||||
# 验证文件存在
|
||||
expected_path = Path(self.temp_dir) / f"{filename}.pkl"
|
||||
assert expected_path.exists()
|
||||
|
||||
# 加载缓存
|
||||
loaded_examples = self.persistence.load_cache(filename)
|
||||
|
||||
# 验证加载的数据
|
||||
assert len(loaded_examples) == 5
|
||||
|
||||
# 验证数据内容
|
||||
loaded_values = {ex.value for ex in loaded_examples}
|
||||
original_values = {ex.value for ex in self.test_examples[:5]}
|
||||
assert loaded_values == original_values
|
||||
|
||||
def test_save_examples_batch(self):
|
||||
"""测试批量保存样本"""
|
||||
batch_name = "test_batch"
|
||||
examples = self.test_examples[:7]
|
||||
|
||||
# 保存批次
|
||||
self.persistence.save_examples_batch(examples, batch_name)
|
||||
|
||||
# 验证文件存在
|
||||
expected_path = Path(self.temp_dir) / f"{batch_name}.pkl"
|
||||
assert expected_path.exists()
|
||||
|
||||
# 加载并验证
|
||||
loaded_examples = self.persistence.load_cache(batch_name)
|
||||
assert len(loaded_examples) == 7
|
||||
|
||||
# 验证数据完整性
|
||||
for original, loaded in zip(examples, loaded_examples):
|
||||
assert original.action == loaded.action
|
||||
assert original.value == loaded.value
|
||||
assert original.canonical_hash == loaded.canonical_hash
|
||||
np.testing.assert_array_equal(original.board_state, loaded.board_state)
|
||||
|
||||
def test_load_nonexistent_file(self):
|
||||
"""测试加载不存在的文件"""
|
||||
loaded_examples = self.persistence.load_cache("nonexistent_file")
|
||||
assert loaded_examples == []
|
||||
|
||||
def test_list_saved_files(self):
|
||||
"""测试列出保存的文件"""
|
||||
# 初始应该没有文件
|
||||
files = self.persistence.list_saved_files()
|
||||
assert len(files) == 0
|
||||
|
||||
# 保存一些文件
|
||||
cache = TrainingDataCache(max_size=100)
|
||||
cache.put("key1", self.test_examples[0])
|
||||
|
||||
self.persistence.save_cache(cache, "file1")
|
||||
self.persistence.save_cache(cache, "file2")
|
||||
self.persistence.save_examples_batch(self.test_examples[:3], "batch1")
|
||||
|
||||
# 检查文件列表
|
||||
files = self.persistence.list_saved_files()
|
||||
assert len(files) == 3
|
||||
assert "file1" in files
|
||||
assert "file2" in files
|
||||
assert "batch1" in files
|
||||
|
||||
def test_large_data_persistence(self):
|
||||
"""测试大数据量的持久化"""
|
||||
# 创建大量测试数据
|
||||
large_examples = []
|
||||
for i in range(1000):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
large_examples.append(example)
|
||||
|
||||
# 保存大批次
|
||||
batch_name = "large_batch"
|
||||
self.persistence.save_examples_batch(large_examples, batch_name)
|
||||
|
||||
# 加载并验证
|
||||
loaded_examples = self.persistence.load_cache(batch_name)
|
||||
assert len(loaded_examples) == 1000
|
||||
|
||||
# 验证一些随机样本
|
||||
for i in [0, 100, 500, 999]:
|
||||
assert loaded_examples[i].value == float(i)
|
||||
assert loaded_examples[i].action == i % 4
|
||||
|
||||
def test_data_integrity(self):
|
||||
"""测试数据完整性"""
|
||||
# 创建包含特殊值的测试数据
|
||||
special_board = np.array([
|
||||
[0, 1, 2, 17], # 包含边界值
|
||||
[3, 4, 5, 6],
|
||||
[7, 8, 9, 10],
|
||||
[11, 12, 13, 14]
|
||||
])
|
||||
|
||||
special_example = TrainingExample(
|
||||
board_state=special_board,
|
||||
action=3,
|
||||
value=12345.67,
|
||||
canonical_hash="special_hash_123"
|
||||
)
|
||||
|
||||
# 保存
|
||||
self.persistence.save_examples_batch([special_example], "special_test")
|
||||
|
||||
# 加载
|
||||
loaded = self.persistence.load_cache("special_test")
|
||||
assert len(loaded) == 1
|
||||
|
||||
loaded_example = loaded[0]
|
||||
|
||||
# 验证所有字段
|
||||
np.testing.assert_array_equal(loaded_example.board_state, special_board)
|
||||
assert loaded_example.action == 3
|
||||
assert abs(loaded_example.value - 12345.67) < 1e-6
|
||||
assert loaded_example.canonical_hash == "special_hash_123"
|
||||
|
||||
|
||||
class TestTrainingDataManager:
|
||||
"""训练数据管理器测试类"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.manager = TrainingDataManager(
|
||||
data_dir=self.temp_dir,
|
||||
cache_size=100,
|
||||
board_size=(4, 4)
|
||||
)
|
||||
|
||||
def teardown_method(self):
|
||||
"""测试后的清理"""
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_add_and_retrieve_examples(self):
|
||||
"""测试添加和检索训练样本"""
|
||||
# 创建测试棋盘
|
||||
board = np.array([
|
||||
[2, 4, 8, 16],
|
||||
[0, 2, 4, 8],
|
||||
[0, 0, 2, 4],
|
||||
[0, 0, 0, 2]
|
||||
])
|
||||
|
||||
# 添加训练样本
|
||||
cache_key = self.manager.add_training_example(board, action=1, value=500.0)
|
||||
assert cache_key is not None
|
||||
|
||||
# 验证缓存统计
|
||||
stats = self.manager.get_cache_stats()
|
||||
assert stats["cache_size"] == 1
|
||||
|
||||
# 获取PyTorch数据集
|
||||
dataset = self.manager.get_pytorch_dataset()
|
||||
assert len(dataset) == 1
|
||||
|
||||
# 验证数据集内容
|
||||
board_tensor, action_tensor, value_tensor = dataset[0]
|
||||
assert action_tensor.item() == 1
|
||||
assert abs(value_tensor.item() - 500.0) < 1e-6
|
||||
|
||||
def test_save_and_load_workflow(self):
|
||||
"""测试完整的保存和加载工作流"""
|
||||
# 添加一些训练样本
|
||||
boards = [
|
||||
np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]]),
|
||||
np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]]),
|
||||
np.array([[8, 16, 32, 64], [4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8]])
|
||||
]
|
||||
|
||||
for i, board in enumerate(boards):
|
||||
for action in range(4):
|
||||
value = (i + 1) * 100 + action * 10
|
||||
self.manager.add_training_example(board, action, value)
|
||||
|
||||
# 保存当前缓存
|
||||
self.manager.save_current_cache("workflow_test")
|
||||
|
||||
# 创建新的管理器
|
||||
new_manager = TrainingDataManager(
|
||||
data_dir=self.temp_dir,
|
||||
cache_size=100,
|
||||
board_size=(4, 4)
|
||||
)
|
||||
|
||||
# 加载数据
|
||||
loaded_count = new_manager.load_from_file("workflow_test")
|
||||
assert loaded_count == 12 # 3个棋盘 × 4个动作
|
||||
|
||||
# 验证数据
|
||||
dataset = new_manager.get_pytorch_dataset()
|
||||
assert len(dataset) == 12
|
||||
|
||||
def test_merge_caches(self):
|
||||
"""测试缓存合并功能"""
|
||||
# 在第一个管理器中添加数据
|
||||
board1 = np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]])
|
||||
self.manager.add_training_example(board1, 0, 100.0)
|
||||
self.manager.add_training_example(board1, 1, 200.0)
|
||||
|
||||
# 创建第二个管理器
|
||||
manager2 = TrainingDataManager(
|
||||
data_dir=self.temp_dir,
|
||||
cache_size=100,
|
||||
board_size=(4, 4)
|
||||
)
|
||||
|
||||
# 在第二个管理器中添加不同的数据
|
||||
board2 = np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]])
|
||||
manager2.add_training_example(board2, 0, 300.0)
|
||||
manager2.add_training_example(board2, 1, 400.0)
|
||||
|
||||
# 合并缓存
|
||||
merged_count = self.manager.merge_caches(manager2)
|
||||
assert merged_count == 2
|
||||
|
||||
# 验证合并后的数据
|
||||
stats = self.manager.get_cache_stats()
|
||||
assert stats["cache_size"] == 4
|
||||
|
||||
dataset = self.manager.get_pytorch_dataset()
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_pytorch_integration(self):
|
||||
"""测试PyTorch集成"""
|
||||
# 添加测试数据
|
||||
for i in range(10):
|
||||
board = np.random.randint(0, 16, size=(4, 4))
|
||||
# 确保至少有一些非零值
|
||||
board[0, 0] = 2 ** (i % 4 + 1)
|
||||
|
||||
action = i % 4
|
||||
value = float(i * 50)
|
||||
self.manager.add_training_example(board, action, value)
|
||||
|
||||
# 获取DataLoader
|
||||
dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
|
||||
|
||||
# 验证批次
|
||||
batch_count = 0
|
||||
total_samples = 0
|
||||
|
||||
for boards, actions, values in dataloader:
|
||||
batch_count += 1
|
||||
batch_size = boards.shape[0]
|
||||
total_samples += batch_size
|
||||
|
||||
# 验证张量形状
|
||||
assert boards.shape == (batch_size, 18, 4, 4) # max_tile_value + 1 = 18
|
||||
assert actions.shape == (batch_size,)
|
||||
assert values.shape == (batch_size,)
|
||||
|
||||
# 验证数据类型
|
||||
assert boards.dtype == torch.float32
|
||||
assert actions.dtype == torch.long
|
||||
assert values.dtype == torch.float32
|
||||
|
||||
assert total_samples == 10
|
||||
assert batch_count == 4 # ceil(10/3) = 4
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
print("运行持久化系统测试...")
|
||||
pytest.main([__file__, "-v"])
|
||||
295
tests/test_torch_mcts.py
Normal file
295
tests/test_torch_mcts.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
PyTorch MCTS测试
|
||||
|
||||
测试统一的PyTorch MCTS实现
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from game import Game2048
|
||||
from torch_mcts import TorchMCTS
|
||||
from training_data import TrainingDataManager
|
||||
|
||||
|
||||
class TestTorchMCTS:
|
||||
"""PyTorch MCTS测试类"""
|
||||
|
||||
@pytest.fixture
|
||||
def game(self):
|
||||
"""测试游戏状态"""
|
||||
return Game2048(height=3, width=3, seed=42)
|
||||
|
||||
@pytest.fixture
|
||||
def cpu_mcts(self):
|
||||
"""CPU MCTS实例"""
|
||||
return TorchMCTS(
|
||||
c_param=1.414,
|
||||
max_simulation_depth=30,
|
||||
batch_size=1024,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def gpu_mcts(self):
|
||||
"""GPU MCTS实例"""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA不可用")
|
||||
|
||||
return TorchMCTS(
|
||||
c_param=1.414,
|
||||
max_simulation_depth=30,
|
||||
batch_size=4096,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
def test_cpu_mcts_basic_functionality(self, game, cpu_mcts):
|
||||
"""测试CPU MCTS基本功能"""
|
||||
# 执行搜索
|
||||
action, stats = cpu_mcts.search(game, 1000)
|
||||
|
||||
# 验证结果
|
||||
assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
|
||||
assert 'action_visits' in stats
|
||||
assert 'action_avg_values' in stats
|
||||
assert 'sims_per_second' in stats
|
||||
assert stats['device'] == 'cpu'
|
||||
|
||||
# 验证访问次数
|
||||
total_visits = sum(stats['action_visits'].values())
|
||||
assert total_visits == 1000, f"访问次数不匹配: {total_visits}"
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
|
||||
def test_gpu_mcts_basic_functionality(self, game, gpu_mcts):
|
||||
"""测试GPU MCTS基本功能"""
|
||||
# 执行搜索
|
||||
action, stats = gpu_mcts.search(game, 2000)
|
||||
|
||||
# 验证结果
|
||||
assert action in game.get_valid_moves(), f"选择了无效动作: {action}"
|
||||
assert 'action_visits' in stats
|
||||
assert 'action_avg_values' in stats
|
||||
assert 'sims_per_second' in stats
|
||||
assert stats['device'] == 'cuda'
|
||||
|
||||
# 验证访问次数
|
||||
total_visits = sum(stats['action_visits'].values())
|
||||
assert total_visits == 2000, f"访问次数不匹配: {total_visits}"
|
||||
|
||||
def test_action_distribution_quality(self, game, cpu_mcts):
|
||||
"""测试动作分布质量"""
|
||||
action, stats = cpu_mcts.search(game, 5000)
|
||||
|
||||
action_visits = stats['action_visits']
|
||||
visit_values = list(action_visits.values())
|
||||
|
||||
# 检查分布不应该完全均匀(MCTS应该有偏向性)
|
||||
assert len(set(visit_values)) > 1, "动作分布完全均匀,不符合MCTS预期"
|
||||
|
||||
# 检查最佳动作应该有最多访问次数
|
||||
best_action_visits = action_visits[action]
|
||||
assert best_action_visits == max(visit_values), "最佳动作访问次数不是最多"
|
||||
|
||||
# 检查价值的合理性
|
||||
action_values = stats['action_avg_values']
|
||||
for act, value in action_values.items():
|
||||
assert value > 0, f"动作{act}的价值应该为正: {value}"
|
||||
assert value < 100000, f"动作{act}的价值过大: {value}"
|
||||
|
||||
def test_device_auto_selection(self, game):
|
||||
"""测试设备自动选择"""
|
||||
mcts = TorchMCTS(device="auto", batch_size=1024)
|
||||
|
||||
# 验证设备选择
|
||||
if torch.cuda.is_available():
|
||||
assert mcts.device.type == "cuda"
|
||||
else:
|
||||
assert mcts.device.type == "cpu"
|
||||
|
||||
# 执行搜索验证功能
|
||||
action, stats = mcts.search(game, 1000)
|
||||
assert action in game.get_valid_moves()
|
||||
|
||||
if mcts.device.type == "cuda":
|
||||
del mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_batch_size_auto_selection(self, game):
|
||||
"""测试批次大小自动选择"""
|
||||
# CPU自动选择
|
||||
cpu_mcts = TorchMCTS(device="cpu", batch_size=None)
|
||||
assert cpu_mcts.batch_size == 4096 # CPU默认批次大小
|
||||
|
||||
# GPU自动选择(如果可用)
|
||||
if torch.cuda.is_available():
|
||||
gpu_mcts = TorchMCTS(device="cuda", batch_size=None)
|
||||
assert gpu_mcts.batch_size == 32768 # GPU默认批次大小
|
||||
del gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_performance_cpu(self, game, cpu_mcts):
|
||||
"""测试CPU性能"""
|
||||
simulations = 2000
|
||||
|
||||
start_time = time.time()
|
||||
action, stats = cpu_mcts.search(game, simulations)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
speed = simulations / elapsed_time
|
||||
|
||||
# CPU应该达到基本性能要求
|
||||
assert speed > 100, f"CPU性能过低: {speed:.1f} 模拟/秒"
|
||||
|
||||
# 验证统计信息准确性
|
||||
assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
|
||||
def test_performance_gpu(self, game, gpu_mcts):
|
||||
"""测试GPU性能"""
|
||||
simulations = 5000
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
action, stats = gpu_mcts.search(game, simulations)
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
speed = simulations / elapsed_time
|
||||
|
||||
# GPU应该有合理的性能
|
||||
assert speed > 50, f"GPU性能过低: {speed:.1f} 模拟/秒"
|
||||
|
||||
# 验证统计信息准确性
|
||||
assert abs(stats['sims_per_second'] - speed) < speed * 0.2, "统计信息不准确"
|
||||
|
||||
def test_training_data_collection(self, game):
|
||||
"""测试训练数据收集"""
|
||||
# 创建训练数据管理器
|
||||
training_manager = TrainingDataManager(
|
||||
data_dir="data/test_torch_training",
|
||||
cache_size=5000,
|
||||
board_size=(3, 3)
|
||||
)
|
||||
|
||||
mcts = TorchMCTS(
|
||||
max_simulation_depth=30,
|
||||
batch_size=1024,
|
||||
device="cpu",
|
||||
training_manager=training_manager
|
||||
)
|
||||
|
||||
# 执行搜索
|
||||
action, stats = mcts.search(game, 2000)
|
||||
|
||||
# 验证训练数据收集
|
||||
cache_stats = training_manager.get_cache_stats()
|
||||
assert cache_stats['cache_size'] > 0, "未收集到训练数据"
|
||||
|
||||
# 验证数据质量
|
||||
assert cache_stats['cache_size'] <= 2000, "收集的样本数超出预期"
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
|
||||
def test_memory_management(self, game):
|
||||
"""测试GPU内存管理"""
|
||||
torch.cuda.empty_cache()
|
||||
initial_memory = torch.cuda.memory_allocated()
|
||||
|
||||
gpu_mcts = TorchMCTS(
|
||||
max_simulation_depth=30,
|
||||
batch_size=8192,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 执行搜索
|
||||
action, stats = gpu_mcts.search(game, 3000)
|
||||
|
||||
# 检查内存使用
|
||||
peak_memory = torch.cuda.max_memory_allocated()
|
||||
memory_used = (peak_memory - initial_memory) / 1e6 # MB
|
||||
|
||||
assert memory_used < 200, f"GPU内存使用过多: {memory_used:.1f} MB"
|
||||
|
||||
# 清理并验证内存释放
|
||||
del gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
final_memory = torch.cuda.memory_allocated()
|
||||
assert final_memory <= initial_memory * 1.1, "GPU内存未正确释放"
|
||||
|
||||
def test_device_switching(self, game):
|
||||
"""测试设备动态切换"""
|
||||
mcts = TorchMCTS(device="cpu", batch_size=1024)
|
||||
|
||||
# 初始为CPU
|
||||
assert mcts.device.type == "cpu"
|
||||
action1, stats1 = mcts.search(game.copy(), 1000)
|
||||
assert stats1['device'] == 'cpu'
|
||||
|
||||
# 切换到GPU(如果可用)
|
||||
if torch.cuda.is_available():
|
||||
mcts.set_device("cuda")
|
||||
assert mcts.device.type == "cuda"
|
||||
|
||||
action2, stats2 = mcts.search(game.copy(), 1000)
|
||||
assert stats2['device'] == 'cuda'
|
||||
|
||||
# 切换回CPU
|
||||
mcts.set_device("cpu")
|
||||
assert mcts.device.type == "cpu"
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_consistency_across_devices(self, game):
|
||||
"""测试不同设备间的一致性"""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA不可用")
|
||||
|
||||
# 使用相同的随机种子
|
||||
np.random.seed(42)
|
||||
cpu_mcts = TorchMCTS(device="cpu", batch_size=2048)
|
||||
cpu_action, cpu_stats = cpu_mcts.search(game.copy(), 3000)
|
||||
|
||||
np.random.seed(42)
|
||||
gpu_mcts = TorchMCTS(device="cuda", batch_size=2048)
|
||||
gpu_action, gpu_stats = gpu_mcts.search(game.copy(), 3000)
|
||||
|
||||
# 由于随机性,动作可能不完全一致,但应该在合理范围内
|
||||
# 这里主要验证两个设备都能正常工作
|
||||
assert cpu_action in game.get_valid_moves()
|
||||
assert gpu_action in game.get_valid_moves()
|
||||
|
||||
# 验证访问次数总和
|
||||
cpu_total = sum(cpu_stats['action_visits'].values())
|
||||
gpu_total = sum(gpu_stats['action_visits'].values())
|
||||
assert cpu_total == gpu_total == 3000
|
||||
|
||||
del gpu_mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA不可用")
|
||||
def test_batch_size_optimization():
|
||||
"""测试批次大小优化"""
|
||||
game = Game2048(height=3, width=3, seed=42)
|
||||
|
||||
mcts = TorchMCTS(device="cuda", batch_size=4096)
|
||||
|
||||
# 执行批次大小优化
|
||||
optimal_size = mcts.optimize_batch_size(game, test_simulations=1000)
|
||||
|
||||
# 验证优化结果
|
||||
assert optimal_size > 0
|
||||
assert mcts.batch_size == optimal_size
|
||||
|
||||
# 验证优化后的性能
|
||||
action, stats = mcts.search(game, 2000)
|
||||
assert action in game.get_valid_moves()
|
||||
assert stats['sims_per_second'] > 0
|
||||
|
||||
del mcts
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
209
tests/test_training_data.py
Normal file
209
tests/test_training_data.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
训练数据模块测试
|
||||
|
||||
测试棋盘变换、缓存系统、持久化等核心功能
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
from training_data import (
|
||||
BoardTransform,
|
||||
ScoreCalculator,
|
||||
TrainingDataCache,
|
||||
TrainingExample,
|
||||
TrainingDataManager
|
||||
)
|
||||
|
||||
|
||||
class TestBoardTransform:
|
||||
"""棋盘变换测试"""
|
||||
|
||||
def test_log_transform(self):
|
||||
"""测试对数变换"""
|
||||
board = np.array([
|
||||
[2, 4, 8, 16],
|
||||
[0, 2, 4, 8],
|
||||
[0, 0, 2, 4],
|
||||
[0, 0, 0, 2]
|
||||
])
|
||||
|
||||
expected = np.array([
|
||||
[1, 2, 3, 4],
|
||||
[0, 1, 2, 3],
|
||||
[0, 0, 1, 2],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
|
||||
result = BoardTransform.log_transform(board)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
# 测试逆变换
|
||||
restored = BoardTransform.inverse_log_transform(result)
|
||||
np.testing.assert_array_equal(restored, board)
|
||||
|
||||
def test_canonical_form(self):
|
||||
"""测试规范形式"""
|
||||
board = np.array([
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
])
|
||||
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
assert len(transforms) == 8
|
||||
|
||||
# 所有变换的规范形式应该相同
|
||||
canonical_forms = []
|
||||
for transform in transforms:
|
||||
canonical, _ = BoardTransform.get_canonical_form(transform)
|
||||
canonical_forms.append(canonical)
|
||||
|
||||
first_canonical = canonical_forms[0]
|
||||
for canonical in canonical_forms[1:]:
|
||||
np.testing.assert_array_equal(canonical, first_canonical)
|
||||
|
||||
def test_hash_consistency(self):
|
||||
"""测试哈希一致性"""
|
||||
board = np.array([[1, 2], [3, 4]])
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
|
||||
hashes = [BoardTransform.compute_hash(t) for t in transforms]
|
||||
first_hash = hashes[0]
|
||||
for hash_val in hashes[1:]:
|
||||
assert hash_val == first_hash
|
||||
|
||||
|
||||
class TestScoreCalculator:
|
||||
"""分数计算测试"""
|
||||
|
||||
def test_tile_value_calculation(self):
|
||||
"""测试瓦片价值计算"""
|
||||
# V(2) = 0, V(4) = 4, V(8) = 16, V(16) = 48
|
||||
assert ScoreCalculator.calculate_tile_value(1) == 0 # 2^1 = 2
|
||||
assert ScoreCalculator.calculate_tile_value(2) == 4 # 2^2 = 4
|
||||
assert ScoreCalculator.calculate_tile_value(3) == 16 # 2^3 = 8
|
||||
assert ScoreCalculator.calculate_tile_value(4) == 48 # 2^4 = 16
|
||||
|
||||
def test_board_score_calculation(self):
|
||||
"""测试棋盘分数计算"""
|
||||
log_board = np.array([
|
||||
[1, 2], # 2, 4
|
||||
[3, 4] # 8, 16
|
||||
])
|
||||
|
||||
total_score = ScoreCalculator.calculate_board_score(log_board)
|
||||
expected = 0 + 4 + 16 + 48 # V(2) + V(4) + V(8) + V(16)
|
||||
assert total_score == expected
|
||||
|
||||
|
||||
class TestTrainingDataCache:
|
||||
"""训练数据缓存测试"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.cache = TrainingDataCache(max_size=5)
|
||||
|
||||
# 创建测试样本
|
||||
self.sample_examples = []
|
||||
for i in range(10):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i * 100),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
self.sample_examples.append(example)
|
||||
|
||||
def test_basic_operations(self):
|
||||
"""测试基本操作"""
|
||||
assert self.cache.size() == 0
|
||||
assert self.cache.get("nonexistent") is None
|
||||
|
||||
example = self.sample_examples[0]
|
||||
self.cache.put("key1", example)
|
||||
|
||||
assert self.cache.size() == 1
|
||||
retrieved = self.cache.get("key1")
|
||||
assert retrieved is not None
|
||||
assert retrieved.value == example.value
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""测试LRU淘汰"""
|
||||
# 填满缓存
|
||||
for i in range(5):
|
||||
self.cache.put(f"key_{i}", self.sample_examples[i])
|
||||
|
||||
assert self.cache.size() == 5
|
||||
|
||||
# 访问key_1
|
||||
self.cache.get("key_1")
|
||||
|
||||
# 添加新项目,应该淘汰key_0
|
||||
self.cache.put("key_5", self.sample_examples[5])
|
||||
|
||||
assert self.cache.size() == 5
|
||||
assert self.cache.get("key_0") is None
|
||||
assert self.cache.get("key_1") is not None
|
||||
assert self.cache.get("key_5") is not None
|
||||
|
||||
|
||||
class TestTrainingDataManager:
|
||||
"""训练数据管理器测试"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.manager = TrainingDataManager(
|
||||
data_dir=self.temp_dir,
|
||||
cache_size=100,
|
||||
board_size=(4, 4)
|
||||
)
|
||||
|
||||
def teardown_method(self):
|
||||
"""测试后的清理"""
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_add_and_retrieve_examples(self):
|
||||
"""测试添加和检索样本"""
|
||||
board = np.array([
|
||||
[2, 4, 8, 16],
|
||||
[0, 2, 4, 8],
|
||||
[0, 0, 2, 4],
|
||||
[0, 0, 0, 2]
|
||||
])
|
||||
|
||||
cache_key = self.manager.add_training_example(board, action=1, value=500.0)
|
||||
assert cache_key is not None
|
||||
|
||||
stats = self.manager.get_cache_stats()
|
||||
assert stats["cache_size"] == 1
|
||||
|
||||
dataset = self.manager.get_pytorch_dataset()
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_pytorch_integration(self):
|
||||
"""测试PyTorch集成"""
|
||||
for i in range(5):
|
||||
board = np.random.randint(0, 16, size=(4, 4))
|
||||
board[0, 0] = 2 ** (i % 4 + 1)
|
||||
|
||||
self.manager.add_training_example(board, i % 4, float(i * 50))
|
||||
|
||||
dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
|
||||
|
||||
for boards, actions, values in dataloader:
|
||||
assert boards.shape[0] <= 3 # 批次大小
|
||||
assert boards.shape[1] == 18 # 通道数 (max_tile_value + 1)
|
||||
assert boards.shape[2:] == (4, 4) # 棋盘大小
|
||||
break # 只测试第一个批次
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user