Files
deep2048/tests/test_training_data.py
2025-07-23 07:04:10 +08:00

210 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
训练数据模块测试
测试棋盘变换、缓存系统、持久化等核心功能
"""
import numpy as np
import pytest
import tempfile
import shutil
from pathlib import Path
import torch
from training_data import (
BoardTransform,
ScoreCalculator,
TrainingDataCache,
TrainingExample,
TrainingDataManager
)
class TestBoardTransform:
    """Tests for board transformations: log encoding, symmetries and hashing."""

    def test_log_transform(self):
        """log_transform maps tiles to exponents; inverse_log_transform undoes it."""
        raw_board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2],
        ])
        # Each tile 2**k is expected to become k; empty cells (0) stay 0.
        exponent_board = np.array([
            [1, 2, 3, 4],
            [0, 1, 2, 3],
            [0, 0, 1, 2],
            [0, 0, 0, 1],
        ])

        encoded = BoardTransform.log_transform(raw_board)
        np.testing.assert_array_equal(encoded, exponent_board)

        # Round trip: decoding the encoded board must give back the original tiles.
        decoded = BoardTransform.inverse_log_transform(encoded)
        np.testing.assert_array_equal(decoded, raw_board)

    def test_canonical_form(self):
        """Every symmetric variant of a board reduces to the same canonical form."""
        board = np.arange(1, 17).reshape(4, 4)

        symmetries = BoardTransform.get_all_transforms(board)
        assert len(symmetries) == 8

        # Compute every canonical form first, then compare them all to the first one.
        canonicals = [BoardTransform.get_canonical_form(variant)[0] for variant in symmetries]
        reference, *others = canonicals
        for canonical in others:
            np.testing.assert_array_equal(canonical, reference)

    def test_hash_consistency(self):
        """All symmetric variants of a board hash to the same value."""
        board = np.array([[1, 2], [3, 4]])
        hash_values = [
            BoardTransform.compute_hash(variant)
            for variant in BoardTransform.get_all_transforms(board)
        ]

        reference = hash_values[0]
        for value in hash_values[1:]:
            assert value == reference
class TestScoreCalculator:
    """Tests for per-tile and whole-board score computation."""

    def test_tile_value_calculation(self):
        """Tile values follow V(2)=0, V(4)=4, V(8)=16, V(16)=48."""
        # Keys are log2 exponents: 1 -> tile 2, 2 -> tile 4, 3 -> tile 8, 4 -> tile 16.
        expected_by_exponent = {1: 0, 2: 4, 3: 16, 4: 48}
        for exponent, expected in expected_by_exponent.items():
            assert ScoreCalculator.calculate_tile_value(exponent) == expected

    def test_board_score_calculation(self):
        """A board's score equals the sum of its tile values."""
        # Exponent board holding tiles 2, 4 (top row) and 8, 16 (bottom row).
        exponents = np.array([[1, 2], [3, 4]])
        expected_total = sum((0, 4, 16, 48))  # V(2) + V(4) + V(8) + V(16)
        assert ScoreCalculator.calculate_board_score(exponents) == expected_total
class TestTrainingDataCache:
    """Tests for the bounded LRU cache of training examples."""

    def setup_method(self):
        """Create a 5-slot cache and ten distinct sample examples."""
        self.cache = TrainingDataCache(max_size=5)
        # Seeded generator so every run builds the same boards (reproducible failures).
        rng = np.random.default_rng(0)
        self.sample_examples = [
            TrainingExample(
                board_state=rng.integers(0, 17, size=(4, 4)),
                action=i % 4,
                value=float(i * 100),
                canonical_hash=f"hash_{i}",
            )
            for i in range(10)
        ]

    def test_basic_operations(self):
        """put/get/size behave correctly on an empty and a one-item cache."""
        assert self.cache.size() == 0
        assert self.cache.get("nonexistent") is None

        example = self.sample_examples[0]
        self.cache.put("key1", example)
        assert self.cache.size() == 1

        retrieved = self.cache.get("key1")
        assert retrieved is not None
        assert retrieved.value == example.value
        assert retrieved.action == example.action

    def test_lru_eviction(self):
        """A read refreshes recency, so the least recently *used* key is evicted.

        The oldest key is read just before the overflowing insert. A FIFO cache
        would still evict it; an LRU cache must evict the next-oldest key instead.
        """
        # Fill the cache to capacity: key_0 is the oldest insertion.
        for i in range(5):
            self.cache.put(f"key_{i}", self.sample_examples[i])
        assert self.cache.size() == 5

        # Touch the oldest entry so it becomes the most recently used one.
        assert self.cache.get("key_0") is not None

        # Overflow by one: key_1 is now the least recently used and must be evicted.
        self.cache.put("key_5", self.sample_examples[5])
        assert self.cache.size() == 5
        assert self.cache.get("key_1") is None
        assert self.cache.get("key_0") is not None
        assert self.cache.get("key_5") is not None
class TestTrainingDataManager:
    """Tests for TrainingDataManager storage and PyTorch integration."""

    def setup_method(self):
        """Create a manager backed by a fresh temporary directory."""
        self.temp_dir = tempfile.mkdtemp()
        self.manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4),
        )

    def teardown_method(self):
        """Remove the temporary data directory."""
        shutil.rmtree(self.temp_dir)

    def test_add_and_retrieve_examples(self):
        """Adding one raw-tile board yields one cached entry and one dataset item."""
        board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2],
        ])
        cache_key = self.manager.add_training_example(board, action=1, value=500.0)
        assert cache_key is not None

        stats = self.manager.get_cache_stats()
        assert stats["cache_size"] == 1

        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 1

    def test_pytorch_integration(self):
        """The DataLoader yields batched board tensors with matching labels."""
        rng = np.random.default_rng(0)
        for i in range(5):
            # Build a valid raw board: every cell is 0 (empty) or a power of two,
            # the same raw-tile convention used in test_add_and_retrieve_examples.
            exponents = rng.integers(0, 12, size=(4, 4))
            board = np.where(exponents > 0, 2 ** exponents, 0)
            board[0, 0] = 2 ** (i % 4 + 1)
            self.manager.add_training_example(board, i % 4, float(i * 50))

        dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
        # Fetch the first batch explicitly; an empty loader must fail the test
        # rather than let a for/break loop pass without asserting anything.
        first_batch = next(iter(dataloader), None)
        assert first_batch is not None, "dataloader yielded no batches"
        boards, actions, values = first_batch

        assert boards.shape[0] <= 3  # batch size
        assert boards.shape[1] == 18  # channel count (max_tile_value + 1)
        assert tuple(boards.shape[2:]) == (4, 4)  # board size
        assert actions.shape[0] == boards.shape[0]
        assert values.shape[0] == boards.shape[0]
if __name__ == "__main__":
    # Propagate pytest's exit status so shells/CI see failures when this file is
    # executed directly; previously the return code was discarded (always exit 0).
    raise SystemExit(pytest.main([__file__, "-v"]))