"""
Tests for the disk persistence system.

Verifies the functionality and reliability of TrainingDataPersistence, and the
save / load / merge / PyTorch workflow of TrainingDataManager built on top of it.
"""

import numpy as np
import torch
import pytest
import tempfile
import shutil
import os
from pathlib import Path

from training_data import (
    TrainingDataPersistence,
    TrainingDataCache,
    TrainingExample,
    TrainingDataManager,
)

# Fixed seed so randomly generated boards are identical on every run; a failure
# triggered by a particular board can then be reproduced exactly.
_RNG_SEED = 0


def _make_random_examples(count, value_step=1, seed=_RNG_SEED):
    """Build ``count`` TrainingExample objects with random 4x4 boards.

    Example ``i`` has ``action == i % 4``, ``value == float(i * value_step)``
    and ``canonical_hash == f"hash_{i}"``. Board entries are drawn uniformly
    from the half-open range ``[0, 17)``.
    """
    rng = np.random.default_rng(seed)
    return [
        TrainingExample(
            board_state=rng.integers(0, 17, size=(4, 4)),
            action=i % 4,
            value=float(i * value_step),
            canonical_hash=f"hash_{i}",
        )
        for i in range(count)
    ]


class TestTrainingDataPersistence:
    """Tests for TrainingDataPersistence."""

    def setup_method(self):
        """Create a fresh temporary directory and ten sample examples."""
        # Temporary directory that the persistence layer writes into.
        self.temp_dir = tempfile.mkdtemp()
        self.persistence = TrainingDataPersistence(self.temp_dir)

        # Values are 0.0, 100.0, ..., 900.0 -- all distinct, which
        # test_save_and_load_cache relies on when comparing value sets.
        self.test_examples = _make_random_examples(10, value_step=100)

    def teardown_method(self):
        """Delete the temporary directory created in setup_method."""
        shutil.rmtree(self.temp_dir)

    def test_save_and_load_cache(self):
        """A saved cache is written to ``<name>.pkl`` and loads back its examples."""
        cache = TrainingDataCache(max_size=100)
        for i, example in enumerate(self.test_examples[:5]):
            cache.put(f"key_{i}", example)

        filename = "test_cache"
        self.persistence.save_cache(cache, filename)

        expected_path = Path(self.temp_dir) / f"{filename}.pkl"
        assert expected_path.exists()

        loaded_examples = self.persistence.load_cache(filename)
        assert len(loaded_examples) == 5

        # Compare as sets so the check does not depend on the order in which
        # the cache yields its examples.
        loaded_values = {ex.value for ex in loaded_examples}
        original_values = {ex.value for ex in self.test_examples[:5]}
        assert loaded_values == original_values

    def test_save_examples_batch(self):
        """A saved batch round-trips every field, in the original order."""
        batch_name = "test_batch"
        examples = self.test_examples[:7]

        self.persistence.save_examples_batch(examples, batch_name)

        expected_path = Path(self.temp_dir) / f"{batch_name}.pkl"
        assert expected_path.exists()

        loaded_examples = self.persistence.load_cache(batch_name)
        # Length is asserted first, so the zip below cannot silently truncate.
        assert len(loaded_examples) == 7

        for original, loaded in zip(examples, loaded_examples):
            assert original.action == loaded.action
            assert original.value == loaded.value
            assert original.canonical_hash == loaded.canonical_hash
            np.testing.assert_array_equal(original.board_state, loaded.board_state)

    def test_load_nonexistent_file(self):
        """Loading a name that was never saved returns an empty list."""
        loaded_examples = self.persistence.load_cache("nonexistent_file")
        assert loaded_examples == []

    def test_list_saved_files(self):
        """list_saved_files reports every saved cache and batch by name."""
        # Nothing has been saved yet.
        files = self.persistence.list_saved_files()
        assert len(files) == 0

        cache = TrainingDataCache(max_size=100)
        cache.put("key1", self.test_examples[0])
        self.persistence.save_cache(cache, "file1")
        self.persistence.save_cache(cache, "file2")
        self.persistence.save_examples_batch(self.test_examples[:3], "batch1")

        files = self.persistence.list_saved_files()
        assert len(files) == 3
        assert "file1" in files
        assert "file2" in files
        assert "batch1" in files

    def test_large_data_persistence(self):
        """A 1000-example batch round-trips with order and values intact."""
        large_examples = _make_random_examples(1000)

        batch_name = "large_batch"
        self.persistence.save_examples_batch(large_examples, batch_name)

        loaded_examples = self.persistence.load_cache(batch_name)
        assert len(loaded_examples) == 1000

        # Spot-check a few positions, including both ends.
        for i in [0, 100, 500, 999]:
            assert loaded_examples[i].value == float(i)
            assert loaded_examples[i].action == i % 4

    def test_data_integrity(self):
        """A hand-built example with edge values round-trips field by field."""
        special_board = np.array([
            [0, 1, 2, 17],  # includes boundary values
            [3, 4, 5, 6],
            [7, 8, 9, 10],
            [11, 12, 13, 14],
        ])

        special_example = TrainingExample(
            board_state=special_board,
            action=3,
            value=12345.67,
            canonical_hash="special_hash_123",
        )

        self.persistence.save_examples_batch([special_example], "special_test")

        loaded = self.persistence.load_cache("special_test")
        assert len(loaded) == 1

        loaded_example = loaded[0]
        np.testing.assert_array_equal(loaded_example.board_state, special_board)
        assert loaded_example.action == 3
        assert loaded_example.value == pytest.approx(12345.67, abs=1e-6)
        assert loaded_example.canonical_hash == "special_hash_123"


class TestTrainingDataManager:
    """Tests for TrainingDataManager."""

    def setup_method(self):
        """Create a fresh temporary directory and a manager backed by it."""
        self.temp_dir = tempfile.mkdtemp()
        self.manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4),
        )

    def teardown_method(self):
        """Delete the temporary directory created in setup_method."""
        shutil.rmtree(self.temp_dir)

    def test_add_and_retrieve_examples(self):
        """An added example is counted by the cache and exposed by the dataset."""
        board = np.array([
            [2, 4, 8, 16],
            [0, 2, 4, 8],
            [0, 0, 2, 4],
            [0, 0, 0, 2],
        ])

        cache_key = self.manager.add_training_example(board, action=1, value=500.0)
        assert cache_key is not None

        stats = self.manager.get_cache_stats()
        assert stats["cache_size"] == 1

        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 1

        # Each dataset item is a (board, action, value) tensor triple.
        board_tensor, action_tensor, value_tensor = dataset[0]
        assert action_tensor.item() == 1
        assert value_tensor.item() == pytest.approx(500.0, abs=1e-6)

    def test_save_and_load_workflow(self):
        """Examples saved by one manager can be loaded by a fresh manager."""
        boards = [
            np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]]),
            np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]]),
            np.array([[8, 16, 32, 64], [4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8]]),
        ]

        for i, board in enumerate(boards):
            for action in range(4):
                value = (i + 1) * 100 + action * 10
                self.manager.add_training_example(board, action, value)

        self.manager.save_current_cache("workflow_test")

        # A brand-new manager pointing at the same directory.
        new_manager = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4),
        )

        loaded_count = new_manager.load_from_file("workflow_test")
        assert loaded_count == 12  # 3 boards x 4 actions

        dataset = new_manager.get_pytorch_dataset()
        assert len(dataset) == 12

    def test_merge_caches(self):
        """merge_caches pulls another manager's distinct examples into this one."""
        board1 = np.array([[2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2]])
        self.manager.add_training_example(board1, 0, 100.0)
        self.manager.add_training_example(board1, 1, 200.0)

        manager2 = TrainingDataManager(
            data_dir=self.temp_dir,
            cache_size=100,
            board_size=(4, 4),
        )

        # Different board, so none of these should collide with manager 1.
        board2 = np.array([[4, 8, 16, 32], [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4]])
        manager2.add_training_example(board2, 0, 300.0)
        manager2.add_training_example(board2, 1, 400.0)

        merged_count = self.manager.merge_caches(manager2)
        assert merged_count == 2

        stats = self.manager.get_cache_stats()
        assert stats["cache_size"] == 4

        dataset = self.manager.get_pytorch_dataset()
        assert len(dataset) == 4

    def test_pytorch_integration(self):
        """get_dataloader yields correctly shaped and typed batches."""
        rng = np.random.default_rng(_RNG_SEED)
        for i in range(10):
            board = rng.integers(0, 16, size=(4, 4))
            # Make sure every board has at least one non-zero tile.
            board[0, 0] = 2 ** (i % 4 + 1)
            action = i % 4
            value = float(i * 50)
            self.manager.add_training_example(board, action, value)

        dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False)

        batch_count = 0
        total_samples = 0

        for boards, actions, values in dataloader:
            batch_count += 1
            batch_size = boards.shape[0]
            total_samples += batch_size

            # Tensor shapes.
            assert boards.shape == (batch_size, 18, 4, 4)  # max_tile_value + 1 = 18
            assert actions.shape == (batch_size,)
            assert values.shape == (batch_size,)

            # Tensor dtypes.
            assert boards.dtype == torch.float32
            assert actions.dtype == torch.long
            assert values.dtype == torch.float32

        assert total_samples == 10
        assert batch_count == 4  # ceil(10 / 3) = 4


if __name__ == "__main__":
    print("运行持久化系统测试...")
    # Propagate pytest's exit status so script runs (e.g. in CI) fail on failure.
    raise SystemExit(pytest.main([__file__, "-v"]))