Files
deep2048/tests/test_persistence.py
2025-07-23 07:04:10 +08:00

330 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
硬盘持久化系统测试
验证TrainingDataPersistence的功能和可靠性
"""
import numpy as np
import torch
import pytest
import tempfile
import shutil
import os
from pathlib import Path
from training_data import (
TrainingDataPersistence,
TrainingDataCache,
TrainingExample,
TrainingDataManager
)
class TestTrainingDataPersistence:
    """Tests for TrainingDataPersistence: round-tripping caches and batches to disk."""

    def setup_method(self):
        """Per-test setup: fresh temp dir, persistence instance, and 10 sample examples."""
        # Isolated scratch directory so tests never touch real training data.
        self.temp_dir = tempfile.mkdtemp()
        self.persistence = TrainingDataPersistence(self.temp_dir)
        # Build a pool of test examples with distinct values/hashes.
        self.test_examples = []
        for i in range(10):
            board = np.random.randint(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i * 100),
                canonical_hash=f"hash_{i}"
            )
            self.test_examples.append(example)

    def teardown_method(self):
        """Per-test cleanup: remove the scratch directory and its contents."""
        shutil.rmtree(self.temp_dir)

    def test_save_and_load_cache(self):
        """Saving a cache and loading it back yields the same examples."""
        # Create a cache and populate it with the first 5 examples.
        cache = TrainingDataCache(max_size=100)
        for i, example in enumerate(self.test_examples[:5]):
            cache.put(f"key_{i}", example)
        # Save the cache.
        filename = "test_cache"
        self.persistence.save_cache(cache, filename)
        # BUG FIX: the expected path must be built from the filename we
        # actually saved under (a stray "(unknown)" placeholder had been
        # substituted here, which would always fail the existence check).
        expected_path = Path(self.temp_dir) / f"{filename}.pkl"
        assert expected_path.exists()
        # Load the cache back.
        loaded_examples = self.persistence.load_cache(filename)
        assert len(loaded_examples) == 5
        # Compare as value sets: load order is not guaranteed to match insert order.
        loaded_values = {ex.value for ex in loaded_examples}
        original_values = {ex.value for ex in self.test_examples[:5]}
        assert loaded_values == original_values

    def test_save_examples_batch(self):
        """Batch-saving examples produces a loadable file with identical contents."""
        batch_name = "test_batch"
        examples = self.test_examples[:7]
        # Save the batch.
        self.persistence.save_examples_batch(examples, batch_name)
        # Verify the pickle file exists on disk.
        expected_path = Path(self.temp_dir) / f"{batch_name}.pkl"
        assert expected_path.exists()
        # Load and verify count.
        loaded_examples = self.persistence.load_cache(batch_name)
        assert len(loaded_examples) == 7
        # Field-by-field integrity check against the originals.
        for original, loaded in zip(examples, loaded_examples):
            assert original.action == loaded.action
            assert original.value == loaded.value
            assert original.canonical_hash == loaded.canonical_hash
            np.testing.assert_array_equal(original.board_state, loaded.board_state)

    def test_load_nonexistent_file(self):
        """Loading a missing file returns an empty list rather than raising."""
        loaded_examples = self.persistence.load_cache("nonexistent_file")
        assert loaded_examples == []

    def test_list_saved_files(self):
        """list_saved_files reports exactly the files written so far."""
        # A fresh directory should list no files.
        files = self.persistence.list_saved_files()
        assert len(files) == 0
        # Save a few files via both APIs.
        cache = TrainingDataCache(max_size=100)
        cache.put("key1", self.test_examples[0])
        self.persistence.save_cache(cache, "file1")
        self.persistence.save_cache(cache, "file2")
        self.persistence.save_examples_batch(self.test_examples[:3], "batch1")
        # All three names should now be listed.
        files = self.persistence.list_saved_files()
        assert len(files) == 3
        assert "file1" in files
        assert "file2" in files
        assert "batch1" in files

    def test_large_data_persistence(self):
        """Persistence survives a 1000-example batch without loss."""
        # Generate a large batch of test data.
        large_examples = []
        for i in range(1000):
            board = np.random.randint(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}"
            )
            large_examples.append(example)
        # Save the large batch.
        batch_name = "large_batch"
        self.persistence.save_examples_batch(large_examples, batch_name)
        # Load and verify count.
        loaded_examples = self.persistence.load_cache(batch_name)
        assert len(loaded_examples) == 1000
        # Spot-check a few positions across the range.
        for i in [0, 100, 500, 999]:
            assert loaded_examples[i].value == float(i)
            assert loaded_examples[i].action == i % 4

    def test_data_integrity(self):
        """Every field of an example with boundary values survives a round trip."""
        # Board includes a boundary tile value (17).
        special_board = np.array([
            [0, 1, 2, 17],
            [3, 4, 5, 6],
            [7, 8, 9, 10],
            [11, 12, 13, 14]
        ])
        special_example = TrainingExample(
            board_state=special_board,
            action=3,
            value=12345.67,
            canonical_hash="special_hash_123"
        )
        # Save.
        self.persistence.save_examples_batch([special_example], "special_test")
        # Load.
        loaded = self.persistence.load_cache("special_test")
        assert len(loaded) == 1
        loaded_example = loaded[0]
        # Verify all fields (float compared with a tolerance).
        np.testing.assert_array_equal(loaded_example.board_state, special_board)
        assert loaded_example.action == 3
        assert abs(loaded_example.value - 12345.67) < 1e-6
        assert loaded_example.canonical_hash == "special_hash_123"
class TestTrainingDataManager:
    """Tests for TrainingDataManager: caching, merging, and PyTorch integration."""

    def setup_method(self):
        """Per-test setup: a manager backed by a fresh temporary directory."""
        self.temp_dir = tempfile.mkdtemp()
        self.manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4)
        )

    def teardown_method(self):
        """Per-test cleanup: discard the temporary directory."""
        shutil.rmtree(self.temp_dir)

    def test_add_and_retrieve_examples(self):
        """One added example shows up in cache stats and the PyTorch dataset."""
        sample_board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2]
        ])
        # Adding an example returns a cache key.
        key = self.manager.add_training_example(sample_board, action=1, value=500.0)
        assert key is not None
        # The cache now holds exactly one entry.
        assert self.manager.get_cache_stats()["cache_size"] == 1
        # The dataset exposes that single example.
        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 1
        # Check the stored action and value (board tensor unused here).
        _, action_tensor, value_tensor = dataset[0]
        assert action_tensor.item() == 1
        assert abs(value_tensor.item() - 500.0) < 1e-6

    def test_save_and_load_workflow(self):
        """Data saved by one manager is fully loadable by a fresh manager."""
        boards = [
            np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]]),
            np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]]),
            np.array([[8, 16, 32, 64], [4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8]])
        ]
        # Four actions per board, with distinct values per (board, action) pair.
        for idx, grid in enumerate(boards):
            for move in range(4):
                self.manager.add_training_example(grid, move, (idx + 1) * 100 + move * 10)
        # Persist the current cache to disk.
        self.manager.save_current_cache("workflow_test")
        # A brand-new manager over the same directory.
        fresh_manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4)
        )
        # It should load all 12 examples (3 boards x 4 actions).
        loaded_count = fresh_manager.load_from_file("workflow_test")
        assert loaded_count == 12
        assert len(fresh_manager.get_pytorch_dataset()) == 12

    def test_merge_caches(self):
        """Merging pulls the other manager's examples into this one's cache."""
        # Two examples in the primary manager.
        first_board = np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]])
        self.manager.add_training_example(first_board, 0, 100.0)
        self.manager.add_training_example(first_board, 1, 200.0)
        # A second manager holding two different examples.
        manager2 = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4)
        )
        second_board = np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]])
        manager2.add_training_example(second_board, 0, 300.0)
        manager2.add_training_example(second_board, 1, 400.0)
        # Merge: both of manager2's examples are new to self.manager.
        assert self.manager.merge_caches(manager2) == 2
        # The merged cache and dataset both contain all four examples.
        assert self.manager.get_cache_stats()["cache_size"] == 4
        assert len(self.manager.get_pytorch_dataset()) == 4

    def test_pytorch_integration(self):
        """DataLoader batches carry the expected shapes and dtypes."""
        for idx in range(10):
            grid = np.random.randint(0, 16, size=(4, 4))
            # Guarantee at least one non-zero tile per board.
            grid[0, 0] = 2 ** (idx % 4 + 1)
            self.manager.add_training_example(grid, idx % 4, float(idx * 50))
        loader = self.manager.get_dataloader(batch_size=3, shuffle=False)
        seen_batches = 0
        seen_samples = 0
        for boards, actions, values in loader:
            seen_batches += 1
            n = boards.shape[0]
            seen_samples += n
            # One-hot channel count: max_tile_value + 1 = 18.
            assert boards.shape == (n, 18, 4, 4)
            assert actions.shape == (n,)
            assert values.shape == (n,)
            # Dtypes expected by the training loop.
            assert boards.dtype == torch.float32
            assert actions.dtype == torch.long
            assert values.dtype == torch.float32
        assert seen_samples == 10
        assert seen_batches == 4  # ceil(10 / 3)
if __name__ == "__main__":
    # Allow running this module directly, outside a pytest invocation.
    print("运行持久化系统测试...")
    args = [__file__, "-v"]
    pytest.main(args)