"""
训练数据模块测试

测试棋盘变换、缓存系统、持久化等核心功能

Tests for the training-data module: board transforms, the cache system,
persistence and other core functionality.
"""

import shutil
import tempfile
from pathlib import Path

import numpy as np
import pytest
import torch

from training_data import (
    BoardTransform,
    ScoreCalculator,
    TrainingDataCache,
    TrainingDataManager,
    TrainingExample,
)


class TestBoardTransform:
    """Tests for BoardTransform: log/inverse-log mapping and symmetry handling."""

    def test_log_transform(self):
        """log_transform maps tiles to exponents; inverse_log_transform undoes it."""
        exponents = np.array([
            [1, 2, 3, 4],
            [0, 1, 2, 3],
            [0, 0, 1, 2],
            [0, 0, 0, 1],
        ])
        # Raw board: empty cells stay 0, every other cell holds 2**exponent.
        board = np.where(exponents > 0, 2 ** exponents, 0)

        logged = BoardTransform.log_transform(board)
        np.testing.assert_array_equal(logged, exponents)

        # Round trip: applying the inverse to the log board restores the tiles.
        np.testing.assert_array_equal(
            BoardTransform.inverse_log_transform(logged), board
        )

    def test_canonical_form(self):
        """All 8 symmetries of a board collapse to one canonical form."""
        board = np.arange(1, 17).reshape(4, 4)

        symmetries = BoardTransform.get_all_transforms(board)
        assert len(symmetries) == 8

        # get_canonical_form returns a pair; only the board half is compared.
        canonicals = [
            canon
            for canon, _ in map(BoardTransform.get_canonical_form, symmetries)
        ]
        reference, *others = canonicals
        for candidate in others:
            np.testing.assert_array_equal(candidate, reference)

    def test_hash_consistency(self):
        """compute_hash yields the same value for every symmetry of a board."""
        tiny = np.array([[1, 2], [3, 4]])

        digests = [
            BoardTransform.compute_hash(sym)
            for sym in BoardTransform.get_all_transforms(tiny)
        ]
        reference, *others = digests
        for digest in others:
            assert digest == reference


class TestScoreCalculator:
    """Tests for ScoreCalculator tile-value and whole-board scoring."""

    # Log-exponent -> expected value: V(2)=0, V(4)=4, V(8)=16, V(16)=48.
    _TILE_VALUES = {1: 0, 2: 4, 3: 16, 4: 48}

    def test_tile_value_calculation(self):
        """Each exponent maps to its expected tile value."""
        for exponent, expected in self._TILE_VALUES.items():
            assert ScoreCalculator.calculate_tile_value(exponent) == expected

    def test_board_score_calculation(self):
        """A board's score is the sum of its tiles' values."""
        # Log board holding tiles 2, 4 (top row) and 8, 16 (bottom row).
        log_board = np.array([[1, 2], [3, 4]])

        score = ScoreCalculator.calculate_board_score(log_board)
        assert score == sum(self._TILE_VALUES.values())


class TestTrainingDataCache:
    """Tests for TrainingDataCache: basic get/put/size and LRU eviction."""

    def setup_method(self):
        """Create a cache holding at most 5 entries plus 10 sample examples."""
        self.cache = TrainingDataCache(max_size=5)

        # Board contents do not affect any assertion, but a seeded generator
        # keeps the fixtures identical on every run so failures reproduce.
        rng = np.random.default_rng(0)
        self.sample_examples = []
        for i in range(10):
            board = rng.integers(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i * 100),
                canonical_hash=f"hash_{i}"
            )
            self.sample_examples.append(example)

    def test_basic_operations(self):
        """An empty cache misses; a stored example can be fetched by key."""
        assert self.cache.size() == 0
        assert self.cache.get("nonexistent") is None

        example = self.sample_examples[0]
        self.cache.put("key1", example)

        assert self.cache.size() == 1
        retrieved = self.cache.get("key1")
        assert retrieved is not None
        assert retrieved.value == example.value

    def test_lru_eviction(self):
        """Inserting into a full cache evicts the least recently *used* key."""
        # Fill the cache to capacity.
        for i in range(5):
            self.cache.put(f"key_{i}", self.sample_examples[i])

        assert self.cache.size() == 5

        # Touch key_0, the oldest insertion. Under LRU it becomes the most
        # recently used entry, so key_1 is now the eviction candidate. A FIFO
        # cache would still evict key_0, so this step tells the policies apart.
        self.cache.get("key_0")

        # Adding a new item to the full cache should evict key_1.
        self.cache.put("key_5", self.sample_examples[5])

        assert self.cache.size() == 5
        assert self.cache.get("key_1") is None
        assert self.cache.get("key_0") is not None
        assert self.cache.get("key_5") is not None


class TestTrainingDataManager:
    """Tests for TrainingDataManager: example storage and PyTorch loading."""

    def setup_method(self):
        """Create a fresh temporary data directory and manager for each test."""
        self.temp_dir = tempfile.mkdtemp()
        try:
            self.manager = TrainingDataManager(
                data_dir=self.temp_dir,
                cache_size=100,
                board_size=(4, 4)
            )
        except Exception:
            # pytest does not run teardown_method when setup_method raises,
            # so remove the directory here instead of leaking it.
            shutil.rmtree(self.temp_dir, ignore_errors=True)
            raise

    def teardown_method(self):
        """Remove the temporary data directory created in setup_method."""
        shutil.rmtree(self.temp_dir)

    @staticmethod
    def _random_board(rng, max_exponent=11):
        """Return a random 4x4 board of raw tile values.

        Every cell is either 0 (empty) or 2**k with 1 <= k <= max_exponent,
        i.e. a valid raw tile value like the ones used in
        test_add_and_retrieve_examples.
        """
        exponents = rng.integers(0, max_exponent + 1, size=(4, 4))
        # Exponent 0 encodes an empty cell; anything else is a 2**k tile.
        return np.where(exponents > 0, 2 ** exponents, 0)

    def test_add_and_retrieve_examples(self):
        """Adding one example populates both the cache and the dataset."""
        board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2]
        ])

        cache_key = self.manager.add_training_example(board, action=1, value=500.0)
        assert cache_key is not None

        stats = self.manager.get_cache_stats()
        assert stats["cache_size"] == 1

        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 1

    def test_pytorch_integration(self):
        """The dataloader yields (boards, actions, values) batches of the right shape."""
        rng = np.random.default_rng(0)
        for i in range(5):
            # Use valid raw tiles (0 or powers of two). Plain randint would
            # produce values such as 3 or 5, which are not legal tiles.
            board = self._random_board(rng)
            board[0, 0] = 2 ** (i % 4 + 1)

            self.manager.add_training_example(board, i % 4, float(i * 50))

        dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)

        # Fetch the first batch explicitly; a `for ... break` loop would let
        # the test pass vacuously if the dataloader yielded nothing.
        first_batch = next(iter(dataloader), None)
        assert first_batch is not None, "dataloader produced no batches"
        boards, actions, values = first_batch

        assert boards.shape[0] <= 3  # batch size
        assert boards.shape[1] == 18  # channel count (max_tile_value + 1)
        assert boards.shape[2:] == (4, 4)  # board size
        # Actions and values must stay aligned with the boards in the batch.
        assert len(actions) == boards.shape[0]
        assert len(values) == boards.shape[0]


if __name__ == "__main__":
    # Propagate pytest's exit status so `python <this file>` reports failures
    # to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))