增加L0训练阶段的MCTS部分
This commit is contained in:
209
tests/test_training_data.py
Normal file
209
tests/test_training_data.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
训练数据模块测试
|
||||
|
||||
测试棋盘变换、缓存系统、持久化等核心功能
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
from training_data import (
|
||||
BoardTransform,
|
||||
ScoreCalculator,
|
||||
TrainingDataCache,
|
||||
TrainingExample,
|
||||
TrainingDataManager
|
||||
)
|
||||
|
||||
|
||||
class TestBoardTransform:
|
||||
"""棋盘变换测试"""
|
||||
|
||||
def test_log_transform(self):
|
||||
"""测试对数变换"""
|
||||
board = np.array([
|
||||
[2, 4, 8, 16],
|
||||
[0, 2, 4, 8],
|
||||
[0, 0, 2, 4],
|
||||
[0, 0, 0, 2]
|
||||
])
|
||||
|
||||
expected = np.array([
|
||||
[1, 2, 3, 4],
|
||||
[0, 1, 2, 3],
|
||||
[0, 0, 1, 2],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
|
||||
result = BoardTransform.log_transform(board)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
# 测试逆变换
|
||||
restored = BoardTransform.inverse_log_transform(result)
|
||||
np.testing.assert_array_equal(restored, board)
|
||||
|
||||
def test_canonical_form(self):
|
||||
"""测试规范形式"""
|
||||
board = np.array([
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
])
|
||||
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
assert len(transforms) == 8
|
||||
|
||||
# 所有变换的规范形式应该相同
|
||||
canonical_forms = []
|
||||
for transform in transforms:
|
||||
canonical, _ = BoardTransform.get_canonical_form(transform)
|
||||
canonical_forms.append(canonical)
|
||||
|
||||
first_canonical = canonical_forms[0]
|
||||
for canonical in canonical_forms[1:]:
|
||||
np.testing.assert_array_equal(canonical, first_canonical)
|
||||
|
||||
def test_hash_consistency(self):
|
||||
"""测试哈希一致性"""
|
||||
board = np.array([[1, 2], [3, 4]])
|
||||
transforms = BoardTransform.get_all_transforms(board)
|
||||
|
||||
hashes = [BoardTransform.compute_hash(t) for t in transforms]
|
||||
first_hash = hashes[0]
|
||||
for hash_val in hashes[1:]:
|
||||
assert hash_val == first_hash
|
||||
|
||||
|
||||
class TestScoreCalculator:
|
||||
"""分数计算测试"""
|
||||
|
||||
def test_tile_value_calculation(self):
|
||||
"""测试瓦片价值计算"""
|
||||
# V(2) = 0, V(4) = 4, V(8) = 16, V(16) = 48
|
||||
assert ScoreCalculator.calculate_tile_value(1) == 0 # 2^1 = 2
|
||||
assert ScoreCalculator.calculate_tile_value(2) == 4 # 2^2 = 4
|
||||
assert ScoreCalculator.calculate_tile_value(3) == 16 # 2^3 = 8
|
||||
assert ScoreCalculator.calculate_tile_value(4) == 48 # 2^4 = 16
|
||||
|
||||
def test_board_score_calculation(self):
|
||||
"""测试棋盘分数计算"""
|
||||
log_board = np.array([
|
||||
[1, 2], # 2, 4
|
||||
[3, 4] # 8, 16
|
||||
])
|
||||
|
||||
total_score = ScoreCalculator.calculate_board_score(log_board)
|
||||
expected = 0 + 4 + 16 + 48 # V(2) + V(4) + V(8) + V(16)
|
||||
assert total_score == expected
|
||||
|
||||
|
||||
class TestTrainingDataCache:
|
||||
"""训练数据缓存测试"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.cache = TrainingDataCache(max_size=5)
|
||||
|
||||
# 创建测试样本
|
||||
self.sample_examples = []
|
||||
for i in range(10):
|
||||
board = np.random.randint(0, 17, size=(4, 4))
|
||||
example = TrainingExample(
|
||||
board_state=board,
|
||||
action=i % 4,
|
||||
value=float(i * 100),
|
||||
canonical_hash=f"hash_{i}"
|
||||
)
|
||||
self.sample_examples.append(example)
|
||||
|
||||
def test_basic_operations(self):
|
||||
"""测试基本操作"""
|
||||
assert self.cache.size() == 0
|
||||
assert self.cache.get("nonexistent") is None
|
||||
|
||||
example = self.sample_examples[0]
|
||||
self.cache.put("key1", example)
|
||||
|
||||
assert self.cache.size() == 1
|
||||
retrieved = self.cache.get("key1")
|
||||
assert retrieved is not None
|
||||
assert retrieved.value == example.value
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""测试LRU淘汰"""
|
||||
# 填满缓存
|
||||
for i in range(5):
|
||||
self.cache.put(f"key_{i}", self.sample_examples[i])
|
||||
|
||||
assert self.cache.size() == 5
|
||||
|
||||
# 访问key_1
|
||||
self.cache.get("key_1")
|
||||
|
||||
# 添加新项目,应该淘汰key_0
|
||||
self.cache.put("key_5", self.sample_examples[5])
|
||||
|
||||
assert self.cache.size() == 5
|
||||
assert self.cache.get("key_0") is None
|
||||
assert self.cache.get("key_1") is not None
|
||||
assert self.cache.get("key_5") is not None
|
||||
|
||||
|
||||
class TestTrainingDataManager:
|
||||
"""训练数据管理器测试"""
|
||||
|
||||
def setup_method(self):
|
||||
"""测试前的设置"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.manager = TrainingDataManager(
|
||||
data_dir=self.temp_dir,
|
||||
cache_size=100,
|
||||
board_size=(4, 4)
|
||||
)
|
||||
|
||||
def teardown_method(self):
|
||||
"""测试后的清理"""
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_add_and_retrieve_examples(self):
|
||||
"""测试添加和检索样本"""
|
||||
board = np.array([
|
||||
[2, 4, 8, 16],
|
||||
[0, 2, 4, 8],
|
||||
[0, 0, 2, 4],
|
||||
[0, 0, 0, 2]
|
||||
])
|
||||
|
||||
cache_key = self.manager.add_training_example(board, action=1, value=500.0)
|
||||
assert cache_key is not None
|
||||
|
||||
stats = self.manager.get_cache_stats()
|
||||
assert stats["cache_size"] == 1
|
||||
|
||||
dataset = self.manager.get_pytorch_dataset()
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_pytorch_integration(self):
|
||||
"""测试PyTorch集成"""
|
||||
for i in range(5):
|
||||
board = np.random.randint(0, 16, size=(4, 4))
|
||||
board[0, 0] = 2 ** (i % 4 + 1)
|
||||
|
||||
self.manager.add_training_example(board, i % 4, float(i * 50))
|
||||
|
||||
dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
|
||||
|
||||
for boards, actions, values in dataloader:
|
||||
assert boards.shape[0] <= 3 # 批次大小
|
||||
assert boards.shape[1] == 18 # 通道数 (max_tile_value + 1)
|
||||
assert boards.shape[2:] == (4, 4) # 棋盘大小
|
||||
break # 只测试第一个批次
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user