"""
In-memory cache system tests.

Verifies the functionality and performance of TrainingDataCache.
"""

import sys
import threading
import time

import numpy as np
import pytest

from training_data import TrainingDataCache, TrainingExample


def _random_example(rng, action, value, canonical_hash):
    """Build a TrainingExample with a random 4x4 board.

    Board entries are drawn from [0, 17) using the supplied numpy Generator,
    so callers control reproducibility through the generator's seed.
    """
    board = rng.integers(0, 17, size=(4, 4))
    return TrainingExample(
        board_state=board,
        action=action,
        value=value,
        canonical_hash=canonical_hash,
    )


class TestTrainingDataCache:
    """Functional tests for TrainingDataCache."""

    def setup_method(self):
        """Create a small cache and ten sample examples before each test."""
        # A small capacity makes eviction easy to trigger.
        self.cache = TrainingDataCache(max_size=5)

        # Seeded generator so every run uses the same boards.
        rng = np.random.default_rng(0)
        self.sample_examples = [
            _random_example(
                rng,
                action=i % 4,
                value=float(i * 100),
                canonical_hash=f"hash_{i}",
            )
            for i in range(10)
        ]

    def test_basic_operations(self):
        """Basic put/get round trip on an empty cache."""
        # Empty cache
        assert self.cache.size() == 0
        assert self.cache.get("nonexistent") is None

        # Add one item
        example = self.sample_examples[0]
        self.cache.put("key1", example)
        assert self.cache.size() == 1

        retrieved = self.cache.get("key1")
        assert retrieved is not None
        assert retrieved.value == example.value
        assert retrieved.action == example.action

    def test_lru_eviction(self):
        """A get() refreshes recency, so the untouched oldest key is evicted."""
        # Fill the cache to capacity
        for i in range(5):
            self.cache.put(f"key_{i}", self.sample_examples[i])
        assert self.cache.size() == 5

        # Touch key_1 so it becomes the most recently used
        self.cache.get("key_1")

        # Inserting a new item should evict key_0 (least recently used)
        self.cache.put("key_5", self.sample_examples[5])
        assert self.cache.size() == 5
        assert self.cache.get("key_0") is None  # evicted
        assert self.cache.get("key_1") is not None  # still present
        assert self.cache.get("key_5") is not None  # newly added

    def test_update_existing(self):
        """Putting an existing key replaces its value without growing the cache."""
        example1 = self.sample_examples[0]
        example2 = self.sample_examples[1]

        self.cache.put("key1", example1)
        assert self.cache.get("key1").value == example1.value

        self.cache.put("key1", example2)
        assert self.cache.size() == 1  # size unchanged
        assert self.cache.get("key1").value == example2.value  # value replaced

    def test_update_if_better(self):
        """update_if_better only replaces an entry with a higher-value example."""
        low_value_example = TrainingExample(
            board_state=np.zeros((4, 4)),
            action=0,
            value=100.0,
            canonical_hash="test_hash",
        )
        high_value_example = TrainingExample(
            board_state=np.zeros((4, 4)),
            action=0,
            value=200.0,
            canonical_hash="test_hash",
        )

        # First insertion always succeeds
        result = self.cache.update_if_better("key1", low_value_example)
        assert result is True
        assert self.cache.get("key1").value == 100.0

        # A higher value replaces the entry
        result = self.cache.update_if_better("key1", high_value_example)
        assert result is True
        assert self.cache.get("key1").value == 200.0

        # A lower value is rejected
        result = self.cache.update_if_better("key1", low_value_example)
        assert result is False
        assert self.cache.get("key1").value == 200.0  # unchanged

    def test_clear(self):
        """clear() empties the cache and drops every stored key."""
        for i in range(3):
            self.cache.put(f"key_{i}", self.sample_examples[i])
        assert self.cache.size() == 3

        self.cache.clear()
        assert self.cache.size() == 0
        for i in range(3):
            assert self.cache.get(f"key_{i}") is None

    def test_get_all_examples(self):
        """get_all_examples() returns every stored example (order not checked)."""
        added_examples = self.sample_examples[:3]
        for i, example in enumerate(added_examples):
            self.cache.put(f"key_{i}", example)

        all_examples = self.cache.get_all_examples()
        assert len(all_examples) == 3

        # Compare as sets because the returned order is not part of the check.
        all_values = {ex.value for ex in all_examples}
        expected_values = {ex.value for ex in added_examples}
        assert all_values == expected_values

    def test_access_order_tracking(self):
        """Several get() calls reorder recency; the two oldest keys are evicted."""
        for i in range(5):
            self.cache.put(f"key_{i}", self.sample_examples[i])

        self.cache.get("key_1")
        self.cache.get("key_3")

        # Recency is now: key_0 (oldest), key_2, key_4, key_1, key_3 (newest).
        # Two insertions should therefore evict key_0 and key_2.
        self.cache.put("key_5", self.sample_examples[5])
        self.cache.put("key_6", self.sample_examples[6])

        assert self.cache.get("key_0") is None  # oldest, evicted
        assert self.cache.get("key_2") is None  # second oldest, evicted
        assert self.cache.get("key_1") is not None  # still present
        assert self.cache.get("key_3") is not None  # still present
        assert self.cache.get("key_4") is not None  # still present
        assert self.cache.get("key_5") is not None  # newly added
        assert self.cache.get("key_6") is not None  # newly added


class TestCachePerformance:
    """Coarse performance checks for TrainingDataCache."""

    def test_large_cache_performance(self):
        """Insert 5000 items and run 1000 random lookups within time budgets."""
        large_cache = TrainingDataCache(max_size=10000)

        rng = np.random.default_rng(0)
        examples = [
            _random_example(
                rng,
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            for i in range(5000)
        ]

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for i, example in enumerate(examples):
            large_cache.put(f"key_{i}", example)
        insert_time = time.perf_counter() - start_time

        print(f"插入5000个项目耗时: {insert_time:.3f}秒")
        assert insert_time < 1.0, "插入操作应该很快"

        # Generate the random keys up front so only get() is timed, not RNG
        # draws and string formatting.
        query_keys = [f"key_{k}" for k in rng.integers(0, 5000, size=1000)]
        start_time = time.perf_counter()
        for key in query_keys:
            large_cache.get(key)
        query_time = time.perf_counter() - start_time

        print(f"1000次随机查询耗时: {query_time:.3f}秒")
        assert query_time < 0.1, "查询操作应该很快"

        # max_size (10000) exceeds the item count, so nothing is evicted.
        assert large_cache.size() == 5000

    def test_memory_usage(self):
        """Report approximate container overhead for 500 cached items.

        Note: sys.getsizeof is shallow. It measures only the `cache` and
        `access_order` container objects, not the examples they reference,
        so the printed figures are a lower bound.
        """
        cache = TrainingDataCache(max_size=1000)

        initial_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        rng = np.random.default_rng(0)
        for i in range(500):
            example = _random_example(
                rng,
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            cache.put(f"key_{i}", example)

        filled_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        print(f"空缓存内存使用: {initial_size} bytes")
        print(f"500项目缓存内存使用: {filled_size} bytes")
        print(f"平均每项目内存使用: {(filled_size - initial_size) / 500:.2f} bytes")

        # Capacity (1000) exceeds the item count, so all 500 must be retained.
        assert cache.size() == 500


def test_cache_thread_safety():
    """Basic concurrency smoke test: concurrent put/get must not raise."""
    cache = TrainingDataCache(max_size=1000)
    errors = []

    def worker(worker_id):
        """Insert 100 unique keys, read back every 10th, sleeping briefly."""
        # One generator per thread avoids sharing numpy's global RNG state.
        rng = np.random.default_rng(worker_id)
        try:
            for i in range(100):
                example = _random_example(
                    rng,
                    action=i % 4,
                    value=float(worker_id * 100 + i),
                    canonical_hash=f"hash_{worker_id}_{i}",
                )
                key = f"worker_{worker_id}_key_{i}"
                cache.put(key, example)

                # Occasional read-back
                if i % 10 == 0:
                    cache.get(key)

                # Short sleep to encourage thread interleaving
                time.sleep(0.001)
        except Exception as e:
            # Record the failure so the main thread can report and assert on it.
            errors.append(f"Worker {worker_id}: {e}")

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    if errors:
        print("线程安全测试中的错误:")
        for error in errors:
            print(f"  {error}")

    print(f"最终缓存大小: {cache.size()}")
    print(f"线程安全测试完成,错误数: {len(errors)}")

    # Without this assertion the test could never fail.
    assert not errors, f"{len(errors)} worker thread(s) raised: {errors}"


if __name__ == "__main__":
    # The basic functional tests (TestTrainingDataCache) are collected and run
    # by pytest at the end; this header only announces the whole session.
    print("运行缓存系统测试...")

    # Performance tests
    print("\n运行性能测试...")
    perf_test = TestCachePerformance()
    perf_test.test_large_cache_performance()
    perf_test.test_memory_usage()

    # Thread-safety test
    print("\n运行线程安全测试...")
    test_cache_thread_safety()

    # Run the full suite under pytest and propagate its exit status.
    print("\n运行pytest测试...")
    sys.exit(pytest.main([__file__, "-v"]))