增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

329
tests/test_persistence.py Normal file
View File

@@ -0,0 +1,329 @@
"""
硬盘持久化系统测试
验证TrainingDataPersistence的功能和可靠性
"""
import numpy as np
import torch
import pytest
import tempfile
import shutil
import os
from pathlib import Path
from training_data import (
TrainingDataPersistence,
TrainingDataCache,
TrainingExample,
TrainingDataManager
)
class TestTrainingDataPersistence:
"""训练数据持久化测试类"""
def setup_method(self):
"""测试前的设置"""
# 创建临时目录用于测试
self.temp_dir = tempfile.mkdtemp()
self.persistence = TrainingDataPersistence(self.temp_dir)
# 创建测试样本
self.test_examples = []
for i in range(10):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i * 100),
canonical_hash=f"hash_{i}"
)
self.test_examples.append(example)
def teardown_method(self):
"""测试后的清理"""
# 删除临时目录
shutil.rmtree(self.temp_dir)
def test_save_and_load_cache(self):
"""测试缓存的保存和加载"""
# 创建缓存并添加数据
cache = TrainingDataCache(max_size=100)
for i, example in enumerate(self.test_examples[:5]):
cache.put(f"key_{i}", example)
# 保存缓存
filename = "test_cache"
self.persistence.save_cache(cache, filename)
# 验证文件存在
expected_path = Path(self.temp_dir) / f"{filename}.pkl"
assert expected_path.exists()
# 加载缓存
loaded_examples = self.persistence.load_cache(filename)
# 验证加载的数据
assert len(loaded_examples) == 5
# 验证数据内容
loaded_values = {ex.value for ex in loaded_examples}
original_values = {ex.value for ex in self.test_examples[:5]}
assert loaded_values == original_values
def test_save_examples_batch(self):
"""测试批量保存样本"""
batch_name = "test_batch"
examples = self.test_examples[:7]
# 保存批次
self.persistence.save_examples_batch(examples, batch_name)
# 验证文件存在
expected_path = Path(self.temp_dir) / f"{batch_name}.pkl"
assert expected_path.exists()
# 加载并验证
loaded_examples = self.persistence.load_cache(batch_name)
assert len(loaded_examples) == 7
# 验证数据完整性
for original, loaded in zip(examples, loaded_examples):
assert original.action == loaded.action
assert original.value == loaded.value
assert original.canonical_hash == loaded.canonical_hash
np.testing.assert_array_equal(original.board_state, loaded.board_state)
def test_load_nonexistent_file(self):
"""测试加载不存在的文件"""
loaded_examples = self.persistence.load_cache("nonexistent_file")
assert loaded_examples == []
def test_list_saved_files(self):
"""测试列出保存的文件"""
# 初始应该没有文件
files = self.persistence.list_saved_files()
assert len(files) == 0
# 保存一些文件
cache = TrainingDataCache(max_size=100)
cache.put("key1", self.test_examples[0])
self.persistence.save_cache(cache, "file1")
self.persistence.save_cache(cache, "file2")
self.persistence.save_examples_batch(self.test_examples[:3], "batch1")
# 检查文件列表
files = self.persistence.list_saved_files()
assert len(files) == 3
assert "file1" in files
assert "file2" in files
assert "batch1" in files
def test_large_data_persistence(self):
"""测试大数据量的持久化"""
# 创建大量测试数据
large_examples = []
for i in range(1000):
board = np.random.randint(0, 17, size=(4, 4))
example = TrainingExample(
board_state=board,
action=i % 4,
value=float(i),
canonical_hash=f"hash_{i}"
)
large_examples.append(example)
# 保存大批次
batch_name = "large_batch"
self.persistence.save_examples_batch(large_examples, batch_name)
# 加载并验证
loaded_examples = self.persistence.load_cache(batch_name)
assert len(loaded_examples) == 1000
# 验证一些随机样本
for i in [0, 100, 500, 999]:
assert loaded_examples[i].value == float(i)
assert loaded_examples[i].action == i % 4
def test_data_integrity(self):
"""测试数据完整性"""
# 创建包含特殊值的测试数据
special_board = np.array([
[0, 1, 2, 17], # 包含边界值
[3, 4, 5, 6],
[7, 8, 9, 10],
[11, 12, 13, 14]
])
special_example = TrainingExample(
board_state=special_board,
action=3,
value=12345.67,
canonical_hash="special_hash_123"
)
# 保存
self.persistence.save_examples_batch([special_example], "special_test")
# 加载
loaded = self.persistence.load_cache("special_test")
assert len(loaded) == 1
loaded_example = loaded[0]
# 验证所有字段
np.testing.assert_array_equal(loaded_example.board_state, special_board)
assert loaded_example.action == 3
assert abs(loaded_example.value - 12345.67) < 1e-6
assert loaded_example.canonical_hash == "special_hash_123"
class TestTrainingDataManager:
"""训练数据管理器测试类"""
def setup_method(self):
"""测试前的设置"""
self.temp_dir = tempfile.mkdtemp()
self.manager = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
def teardown_method(self):
"""测试后的清理"""
shutil.rmtree(self.temp_dir)
def test_add_and_retrieve_examples(self):
"""测试添加和检索训练样本"""
# 创建测试棋盘
board = np.array([
[2, 4, 8, 16],
[0, 2, 4, 8],
[0, 0, 2, 4],
[0, 0, 0, 2]
])
# 添加训练样本
cache_key = self.manager.add_training_example(board, action=1, value=500.0)
assert cache_key is not None
# 验证缓存统计
stats = self.manager.get_cache_stats()
assert stats["cache_size"] == 1
# 获取PyTorch数据集
dataset = self.manager.get_pytorch_dataset()
assert len(dataset) == 1
# 验证数据集内容
board_tensor, action_tensor, value_tensor = dataset[0]
assert action_tensor.item() == 1
assert abs(value_tensor.item() - 500.0) < 1e-6
def test_save_and_load_workflow(self):
"""测试完整的保存和加载工作流"""
# 添加一些训练样本
boards = [
np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]]),
np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]]),
np.array([[8, 16, 32, 64], [4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8]])
]
for i, board in enumerate(boards):
for action in range(4):
value = (i + 1) * 100 + action * 10
self.manager.add_training_example(board, action, value)
# 保存当前缓存
self.manager.save_current_cache("workflow_test")
# 创建新的管理器
new_manager = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
# 加载数据
loaded_count = new_manager.load_from_file("workflow_test")
assert loaded_count == 12 # 3个棋盘 × 4个动作
# 验证数据
dataset = new_manager.get_pytorch_dataset()
assert len(dataset) == 12
def test_merge_caches(self):
"""测试缓存合并功能"""
# 在第一个管理器中添加数据
board1 = np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]])
self.manager.add_training_example(board1, 0, 100.0)
self.manager.add_training_example(board1, 1, 200.0)
# 创建第二个管理器
manager2 = TrainingDataManager(
data_dir=self.temp_dir,
cache_size=100,
board_size=(4, 4)
)
# 在第二个管理器中添加不同的数据
board2 = np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]])
manager2.add_training_example(board2, 0, 300.0)
manager2.add_training_example(board2, 1, 400.0)
# 合并缓存
merged_count = self.manager.merge_caches(manager2)
assert merged_count == 2
# 验证合并后的数据
stats = self.manager.get_cache_stats()
assert stats["cache_size"] == 4
dataset = self.manager.get_pytorch_dataset()
assert len(dataset) == 4
def test_pytorch_integration(self):
"""测试PyTorch集成"""
# 添加测试数据
for i in range(10):
board = np.random.randint(0, 16, size=(4, 4))
# 确保至少有一些非零值
board[0, 0] = 2 ** (i % 4 + 1)
action = i % 4
value = float(i * 50)
self.manager.add_training_example(board, action, value)
# 获取DataLoader
dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)
# 验证批次
batch_count = 0
total_samples = 0
for boards, actions, values in dataloader:
batch_count += 1
batch_size = boards.shape[0]
total_samples += batch_size
# 验证张量形状
assert boards.shape == (batch_size, 18, 4, 4) # max_tile_value + 1 = 18
assert actions.shape == (batch_size,)
assert values.shape == (batch_size,)
# 验证数据类型
assert boards.dtype == torch.float32
assert actions.dtype == torch.long
assert values.dtype == torch.float32
assert total_samples == 10
assert batch_count == 4 # ceil(10/3) = 4
if __name__ == "__main__":
# 运行测试
print("运行持久化系统测试...")
pytest.main([__file__, "-v"])