""" 训练数据模块测试 测试棋盘变换、缓存系统、持久化等核心功能 """ import numpy as np import pytest import tempfile import shutil from pathlib import Path import torch from training_data import ( BoardTransform, ScoreCalculator, TrainingDataCache, TrainingExample, TrainingDataManager ) class TestBoardTransform: """棋盘变换测试""" def test_log_transform(self): """测试对数变换""" board = np.array([ [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2] ]) expected = np.array([ [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2], [0, 0, 0, 1] ]) result = BoardTransform.log_transform(board) np.testing.assert_array_equal(result, expected) # 测试逆变换 restored = BoardTransform.inverse_log_transform(result) np.testing.assert_array_equal(restored, board) def test_canonical_form(self): """测试规范形式""" board = np.array([ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16] ]) transforms = BoardTransform.get_all_transforms(board) assert len(transforms) == 8 # 所有变换的规范形式应该相同 canonical_forms = [] for transform in transforms: canonical, _ = BoardTransform.get_canonical_form(transform) canonical_forms.append(canonical) first_canonical = canonical_forms[0] for canonical in canonical_forms[1:]: np.testing.assert_array_equal(canonical, first_canonical) def test_hash_consistency(self): """测试哈希一致性""" board = np.array([[1, 2], [3, 4]]) transforms = BoardTransform.get_all_transforms(board) hashes = [BoardTransform.compute_hash(t) for t in transforms] first_hash = hashes[0] for hash_val in hashes[1:]: assert hash_val == first_hash class TestScoreCalculator: """分数计算测试""" def test_tile_value_calculation(self): """测试瓦片价值计算""" # V(2) = 0, V(4) = 4, V(8) = 16, V(16) = 48 assert ScoreCalculator.calculate_tile_value(1) == 0 # 2^1 = 2 assert ScoreCalculator.calculate_tile_value(2) == 4 # 2^2 = 4 assert ScoreCalculator.calculate_tile_value(3) == 16 # 2^3 = 8 assert ScoreCalculator.calculate_tile_value(4) == 48 # 2^4 = 16 def test_board_score_calculation(self): """测试棋盘分数计算""" log_board = np.array([ [1, 2], # 2, 4 [3, 4] # 8, 16 ]) total_score = ScoreCalculator.calculate_board_score(log_board) expected = 0 + 4 + 16 + 48 # V(2) + V(4) + V(8) + V(16) assert total_score == expected class TestTrainingDataCache: """训练数据缓存测试""" def setup_method(self): """测试前的设置""" self.cache = TrainingDataCache(max_size=5) # 创建测试样本 self.sample_examples = [] for i in range(10): board = np.random.randint(0, 17, size=(4, 4)) example = TrainingExample( board_state=board, action=i % 4, value=float(i * 100), canonical_hash=f"hash_{i}" ) self.sample_examples.append(example) def test_basic_operations(self): """测试基本操作""" assert self.cache.size() == 0 assert self.cache.get("nonexistent") is None example = self.sample_examples[0] self.cache.put("key1", example) assert self.cache.size() == 1 retrieved = self.cache.get("key1") assert retrieved is not None assert retrieved.value == example.value def test_lru_eviction(self): """测试LRU淘汰""" # 填满缓存 for i in range(5): self.cache.put(f"key_{i}", self.sample_examples[i]) assert self.cache.size() == 5 # 访问key_1 self.cache.get("key_1") # 添加新项目,应该淘汰key_0 self.cache.put("key_5", self.sample_examples[5]) assert self.cache.size() == 5 assert self.cache.get("key_0") is None assert self.cache.get("key_1") is not None assert self.cache.get("key_5") is not None class TestTrainingDataManager: """训练数据管理器测试""" def setup_method(self): """测试前的设置""" self.temp_dir = tempfile.mkdtemp() self.manager = TrainingDataManager( data_dir=self.temp_dir, cache_size=100, board_size=(4, 4) ) def teardown_method(self): """测试后的清理""" shutil.rmtree(self.temp_dir) def test_add_and_retrieve_examples(self): """测试添加和检索样本""" board = np.array([ [2, 4, 8, 16], [0, 2, 4, 8], [0, 0, 2, 4], [0, 0, 0, 2] ]) cache_key = self.manager.add_training_example(board, action=1, value=500.0) assert cache_key is not None stats = self.manager.get_cache_stats() assert stats["cache_size"] == 1 dataset = self.manager.get_pytorch_dataset() assert len(dataset) == 1 def test_pytorch_integration(self): """测试PyTorch集成""" for i in range(5): board = np.random.randint(0, 16, size=(4, 4)) board[0, 0] = 2 ** (i % 4 + 1) self.manager.add_training_example(board, i % 4, float(i * 50)) dataloader = self.manager.get_dataloader(batch_size=3, shuffle=False) for boards, actions, values in dataloader: assert boards.shape[0] <= 3 # 批次大小 assert boards.shape[1] == 18 # 通道数 (max_tile_value + 1) assert boards.shape[2:] == (4, 4) # 棋盘大小 break # 只测试第一个批次 if __name__ == "__main__": pytest.main([__file__, "-v"])