"""
Disk persistence system tests.

Verify the functionality and reliability of TrainingDataPersistence.
"""
import os
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pytest
import torch

from training_data import (
    TrainingDataPersistence,
    TrainingDataCache,
    TrainingExample,
    TrainingDataManager
)

||
class TestTrainingDataPersistence:
    """Tests for TrainingDataPersistence save/load behaviour and data integrity."""

    def setup_method(self):
        """Create an isolated temp directory, a persistence instance and samples."""
        # A fresh temp dir keeps each test isolated from real training data.
        self.temp_dir = tempfile.mkdtemp()
        self.persistence = TrainingDataPersistence(self.temp_dir)

        # Ten examples with distinct value/hash so round-trips are verifiable.
        self.test_examples = []
        for i in range(10):
            board = np.random.randint(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i * 100),
                canonical_hash=f"hash_{i}"
            )
            self.test_examples.append(example)

    def teardown_method(self):
        """Delete the temporary directory created by setup_method."""
        shutil.rmtree(self.temp_dir)

    def test_save_and_load_cache(self):
        """A cache saved to disk must load back with the same examples."""
        # Build a cache with five of the prepared examples.
        cache = TrainingDataCache(max_size=100)
        for i, example in enumerate(self.test_examples[:5]):
            cache.put(f"key_{i}", example)

        filename = "test_cache"
        self.persistence.save_cache(cache, filename)

        # BUG FIX: assert the real on-disk name "<filename>.pkl"; the original
        # checked a corrupted literal path that could never exist (the sibling
        # batch test below shows the intended "<name>.pkl" convention).
        expected_path = Path(self.temp_dir) / f"{filename}.pkl"
        assert expected_path.exists()

        loaded_examples = self.persistence.load_cache(filename)
        assert len(loaded_examples) == 5

        # Compare values as sets: load order is not part of the contract here.
        loaded_values = {ex.value for ex in loaded_examples}
        original_values = {ex.value for ex in self.test_examples[:5]}
        assert loaded_values == original_values

    def test_save_examples_batch(self):
        """A saved batch must round-trip every field of every example."""
        batch_name = "test_batch"
        examples = self.test_examples[:7]

        self.persistence.save_examples_batch(examples, batch_name)

        # The batch is stored as "<batch_name>.pkl" in the data directory.
        expected_path = Path(self.temp_dir) / f"{batch_name}.pkl"
        assert expected_path.exists()

        loaded_examples = self.persistence.load_cache(batch_name)
        assert len(loaded_examples) == 7

        # Field-by-field integrity check (assumes batch saves preserve order,
        # unlike the cache round-trip above).
        for original, loaded in zip(examples, loaded_examples):
            assert original.action == loaded.action
            assert original.value == loaded.value
            assert original.canonical_hash == loaded.canonical_hash
            np.testing.assert_array_equal(original.board_state, loaded.board_state)

    def test_load_nonexistent_file(self):
        """Loading a missing file must return an empty list, not raise."""
        loaded_examples = self.persistence.load_cache("nonexistent_file")
        assert loaded_examples == []

    def test_list_saved_files(self):
        """list_saved_files must report every saved cache/batch by name."""
        # Fresh directory: nothing saved yet.
        files = self.persistence.list_saved_files()
        assert len(files) == 0

        # Save a couple of caches and one batch.
        cache = TrainingDataCache(max_size=100)
        cache.put("key1", self.test_examples[0])

        self.persistence.save_cache(cache, "file1")
        self.persistence.save_cache(cache, "file2")
        self.persistence.save_examples_batch(self.test_examples[:3], "batch1")

        # All three names must be listed.
        files = self.persistence.list_saved_files()
        assert len(files) == 3
        assert "file1" in files
        assert "file2" in files
        assert "batch1" in files

    def test_large_data_persistence(self):
        """Persistence must handle a 1000-example batch without loss."""
        large_examples = []
        for i in range(1000):
            board = np.random.randint(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}"
            )
            large_examples.append(example)

        batch_name = "large_batch"
        self.persistence.save_examples_batch(large_examples, batch_name)

        loaded_examples = self.persistence.load_cache(batch_name)
        assert len(loaded_examples) == 1000

        # Spot-check a few positions spread across the whole range.
        for i in [0, 100, 500, 999]:
            assert loaded_examples[i].value == float(i)
            assert loaded_examples[i].action == i % 4

    def test_data_integrity(self):
        """Boundary tile values and all fields must survive a round-trip."""
        special_board = np.array([
            [0, 1, 2, 17],  # includes boundary values
            [3, 4, 5, 6],
            [7, 8, 9, 10],
            [11, 12, 13, 14]
        ])

        special_example = TrainingExample(
            board_state=special_board,
            action=3,
            value=12345.67,
            canonical_hash="special_hash_123"
        )

        self.persistence.save_examples_batch([special_example], "special_test")

        loaded = self.persistence.load_cache("special_test")
        assert len(loaded) == 1

        loaded_example = loaded[0]

        # Verify every field, including exact array contents.
        np.testing.assert_array_equal(loaded_example.board_state, special_board)
        assert loaded_example.action == 3
        assert abs(loaded_example.value - 12345.67) < 1e-6
        assert loaded_example.canonical_hash == "special_hash_123"
||
class TestTrainingDataManager:
    """End-to-end tests for the TrainingDataManager workflow."""

    def setup_method(self):
        """Build a manager backed by a throwaway temp directory."""
        self.temp_dir = tempfile.mkdtemp()
        self.manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4)
        )

    def teardown_method(self):
        """Discard the temp directory and everything saved into it."""
        shutil.rmtree(self.temp_dir)

    def test_add_and_retrieve_examples(self):
        """An added example must show up in cache stats and the dataset."""
        sample_board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2]
        ])

        # Adding an example returns a non-None cache key.
        key = self.manager.add_training_example(sample_board, action=1, value=500.0)
        assert key is not None

        # Exactly one entry should now be cached.
        assert self.manager.get_cache_stats()["cache_size"] == 1

        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 1

        # The stored action/value must round-trip through the dataset.
        _, action_tensor, value_tensor = dataset[0]
        assert action_tensor.item() == 1
        assert abs(value_tensor.item() - 500.0) < 1e-6

    def test_save_and_load_workflow(self):
        """Data saved by one manager must be loadable by a fresh one."""
        boards = [
            np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]]),
            np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]]),
            np.array([[8, 16, 32, 64], [4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8]])
        ]

        # Give every (board, action) pair its own distinct value.
        for idx, state in enumerate(boards):
            for move in range(4):
                self.manager.add_training_example(state, move, (idx + 1) * 100 + move * 10)

        self.manager.save_current_cache("workflow_test")

        # A brand-new manager over the same directory should see the file.
        fresh_manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4)
        )

        # 3 boards x 4 actions = 12 examples expected back.
        assert fresh_manager.load_from_file("workflow_test") == 12
        assert len(fresh_manager.get_pytorch_dataset()) == 12

    def test_merge_caches(self):
        """Merging pulls another manager's examples into this one."""
        first_board = np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]])
        self.manager.add_training_example(first_board, 0, 100.0)
        self.manager.add_training_example(first_board, 1, 200.0)

        # A second manager with two different examples.
        other = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4)
        )

        second_board = np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]])
        other.add_training_example(second_board, 0, 300.0)
        other.add_training_example(second_board, 1, 400.0)

        # Both of the other manager's examples should be merged in.
        assert self.manager.merge_caches(other) == 2

        # 2 local + 2 merged examples afterwards.
        assert self.manager.get_cache_stats()["cache_size"] == 4
        assert len(self.manager.get_pytorch_dataset()) == 4

    def test_pytorch_integration(self):
        """DataLoader batches must have the expected shapes and dtypes."""
        for idx in range(10):
            grid = np.random.randint(0, 16, size=(4, 4))
            grid[0, 0] = 2 ** (idx % 4 + 1)  # guarantee at least one non-zero tile
            self.manager.add_training_example(grid, idx % 4, float(idx * 50))

        loader = self.manager.get_dataloader(batch_size=3, shuffle=False)

        seen_batches = 0
        seen_samples = 0

        for board_batch, action_batch, value_batch in loader:
            seen_batches += 1
            count = board_batch.shape[0]
            seen_samples += count

            # One channel per tile exponent: max_tile_value + 1 = 18.
            assert board_batch.shape == (count, 18, 4, 4)
            assert action_batch.shape == (count,)
            assert value_batch.shape == (count,)

            assert board_batch.dtype == torch.float32
            assert action_batch.dtype == torch.long
            assert value_batch.dtype == torch.float32

        assert seen_samples == 10
        assert seen_batches == 4  # ceil(10 / 3)
||
if __name__ == "__main__":
    # Invoke pytest on this module when executed as a script.
    print("运行持久化系统测试...")
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)