增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

209
tests/test_training_data.py Normal file
View File

@@ -0,0 +1,209 @@
"""
训练数据模块测试
测试棋盘变换、缓存系统、持久化等核心功能
"""
import numpy as np
import pytest
import tempfile
import shutil
from pathlib import Path
import torch
from training_data import (
BoardTransform,
ScoreCalculator,
TrainingDataCache,
TrainingExample,
TrainingDataManager
)
class TestBoardTransform:
    """Tests for the BoardTransform helpers."""

    def test_log_transform(self):
        """Log transform maps tile 2**k to k (0 stays 0) and can be inverted."""
        raw_tiles = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2],
        ])
        log_tiles = np.array([
            [1, 2, 3, 4],
            [0, 1, 2, 3],
            [0, 0, 1, 2],
            [0, 0, 0, 1],
        ])

        transformed = BoardTransform.log_transform(raw_tiles)
        np.testing.assert_array_equal(transformed, log_tiles)

        # Round trip: the inverse transform must give back the raw tiles.
        np.testing.assert_array_equal(
            BoardTransform.inverse_log_transform(transformed),
            raw_tiles,
        )

    def test_canonical_form(self):
        """Every transform of a board must reduce to the same canonical form."""
        # 4x4 board holding 1..16 in row-major order.
        board = np.arange(1, 17).reshape(4, 4)

        variants = BoardTransform.get_all_transforms(board)
        assert len(variants) == 8

        # Keep only the canonical board; the second tuple element is unused.
        canonicals = [
            canonical
            for canonical, _ in map(BoardTransform.get_canonical_form, variants)
        ]

        reference = canonicals[0]
        for other in canonicals[1:]:
            np.testing.assert_array_equal(other, reference)

    def test_hash_consistency(self):
        """All transforms of one board must produce the same hash."""
        board = np.array([[1, 2], [3, 4]])
        digests = [
            BoardTransform.compute_hash(variant)
            for variant in BoardTransform.get_all_transforms(board)
        ]

        reference = digests[0]
        assert all(digest == reference for digest in digests[1:])
class TestScoreCalculator:
    """Tests for the ScoreCalculator helpers."""

    def test_tile_value_calculation(self):
        """Tile values for exponents 1..4 (tiles 2, 4, 8, 16)."""
        # exponent k (tile 2**k) -> expected value; V(2)=0, V(4)=4, V(8)=16,
        # V(16)=48. These numbers coincide with (k - 1) * 2**k.
        expected_by_exponent = {1: 0, 2: 4, 3: 16, 4: 48}
        for exponent, expected_value in expected_by_exponent.items():
            assert ScoreCalculator.calculate_tile_value(exponent) == expected_value

    def test_board_score_calculation(self):
        """Board score is the sum of the per-tile values on a log-space board."""
        # Exponents 1..4 stand for the tiles 2, 4, 8 and 16.
        log_board = np.array([[1, 2], [3, 4]])
        # V(2) + V(4) + V(8) + V(16)
        expected_total = sum((0, 4, 16, 48))
        assert ScoreCalculator.calculate_board_score(log_board) == expected_total
class TestTrainingDataCache:
    """Tests for TrainingDataCache: basic get/put and LRU eviction."""

    def setup_method(self):
        """Create a cache with capacity 5 and ten sample examples."""
        self.cache = TrainingDataCache(max_size=5)

        # Seeded, local RNG: keeps the fixtures reproducible between runs and
        # avoids touching NumPy's global random state.
        rng = np.random.RandomState(0)
        self.sample_examples = [
            TrainingExample(
                board_state=rng.randint(0, 17, size=(4, 4)),
                action=i % 4,
                value=float(i * 100),
                canonical_hash=f"hash_{i}",
            )
            for i in range(10)
        ]

    def test_basic_operations(self):
        """An empty cache misses; a stored example can be read back."""
        assert self.cache.size() == 0
        assert self.cache.get("nonexistent") is None

        example = self.sample_examples[0]
        self.cache.put("key1", example)
        assert self.cache.size() == 1

        retrieved = self.cache.get("key1")
        assert retrieved is not None
        assert retrieved.value == example.value

    def test_lru_eviction(self):
        """Inserting into a full cache evicts the least recently used entry."""
        # Fill the cache to capacity.
        for i in range(5):
            self.cache.put(f"key_{i}", self.sample_examples[i])
        assert self.cache.size() == 5

        # Touch key_1 so it becomes the most recently used entry.
        self.cache.get("key_1")

        # Adding a new entry should evict key_0.
        self.cache.put("key_5", self.sample_examples[5])
        assert self.cache.size() == 5
        assert self.cache.get("key_0") is None
        assert self.cache.get("key_1") is not None
        assert self.cache.get("key_5") is not None

        # The eviction above cannot tell LRU from FIFO, because key_0 was both
        # the oldest insertion and the least recently used entry. Here the two
        # policies diverge: key_1 is the oldest remaining insertion but was
        # read recently, so an LRU cache must drop key_2 and keep key_1.
        self.cache.put("key_6", self.sample_examples[6])
        assert self.cache.size() == 5
        assert self.cache.get("key_2") is None
        assert self.cache.get("key_1") is not None
        assert self.cache.get("key_6") is not None
class TestTrainingDataManager:
    """Tests for TrainingDataManager using a throwaway data directory."""

    def setup_method(self):
        """Create a temporary data directory and a manager that uses it."""
        self.temp_dir = tempfile.mkdtemp()
        try:
            self.manager = TrainingDataManager(
                data_dir=self.temp_dir,
                cache_size=100,
                board_size=(4, 4),
            )
        except Exception:
            # pytest skips teardown_method when setup_method fails, so remove
            # the directory here instead of leaking it.
            shutil.rmtree(self.temp_dir, ignore_errors=True)
            raise

    def teardown_method(self):
        """Remove the temporary data directory."""
        # Best-effort cleanup: a leftover temp dir should not fail the test.
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_add_and_retrieve_examples(self):
        """Adding one example yields one cache entry and a dataset of length 1."""
        board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2],
        ])

        cache_key = self.manager.add_training_example(board, action=1, value=500.0)
        assert cache_key is not None

        stats = self.manager.get_cache_stats()
        assert stats["cache_size"] == 1

        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 1

    def test_pytorch_integration(self):
        """The dataloader yields batches of the expected shape."""
        # Seeded, local RNG so that a failing board can be reproduced.
        rng = np.random.RandomState(0)
        for i in range(5):
            # NOTE(review): these values are not restricted to powers of two,
            # unlike the board in test_add_and_retrieve_examples -- confirm
            # that add_training_example is meant to accept such boards.
            board = rng.randint(0, 16, size=(4, 4))
            board[0, 0] = 2 ** (i % 4 + 1)
            self.manager.add_training_example(board, i % 4, float(i * 50))

        dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)

        # Fetch the first batch explicitly: looping with `break` would let the
        # test pass without running any assertion if the loader were empty.
        first_batch = next(iter(dataloader), None)
        assert first_batch is not None, "dataloader yielded no batches"

        boards, actions, values = first_batch
        assert boards.shape[0] <= 3  # batch size
        assert boards.shape[1] == 18  # channel count (max_tile_value + 1)
        assert boards.shape[2:] == (4, 4)  # board size
if __name__ == "__main__":
    # pytest.main returns the exit code; propagate it so that running this
    # file directly reports failures to the shell or CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))