Files
deep2048/tests/test_cache_system.py
2025-07-23 07:04:10 +08:00

312 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
内存缓存系统测试
验证TrainingDataCache的功能和性能
"""
import numpy as np
import pytest
import time
from training_data import TrainingDataCache, TrainingExample
class TestTrainingDataCache:
    """Functional tests for ``TrainingDataCache``.

    Each test runs against a fresh cache of capacity 5 and checks one aspect
    of the expected contract: basic get/put, eviction of the least recently
    used key, overwriting, conditional updates, clearing and bulk retrieval.
    """

    def setup_method(self):
        """Create a small cache and ten sample examples before every test."""
        # A deliberately tiny capacity makes eviction easy to trigger.
        self.cache = TrainingDataCache(max_size=5)

        # Sample ``idx`` has action ``idx % 4``, value ``idx * 100`` and a
        # random 4x4 integer board with entries drawn from [0, 17).
        self.sample_examples = [
            TrainingExample(
                board_state=np.random.randint(0, 17, size=(4, 4)),
                action=idx % 4,
                value=float(idx * 100),
                canonical_hash=f"hash_{idx}",
            )
            for idx in range(10)
        ]

    def test_basic_operations(self):
        """A new cache is empty; a stored example can be read back."""
        cache = self.cache

        # Empty cache: size zero and a miss returns None.
        assert cache.size() == 0
        assert cache.get("nonexistent") is None

        # Store one item and read it back.
        stored = self.sample_examples[0]
        cache.put("key1", stored)
        assert cache.size() == 1

        fetched = cache.get("key1")
        assert fetched is not None
        assert fetched.value == stored.value
        assert fetched.action == stored.action

    def test_lru_eviction(self):
        """Inserting into a full cache is expected to evict the LRU key."""
        # Fill the cache to capacity.
        for idx, sample in enumerate(self.sample_examples[:5]):
            self.cache.put(f"key_{idx}", sample)
        assert self.cache.size() == 5

        # Touch key_1 so that it counts as recently used.
        self.cache.get("key_1")

        # One more insert should push out key_0, the least recently used.
        self.cache.put("key_5", self.sample_examples[5])
        assert self.cache.size() == 5
        assert self.cache.get("key_0") is None  # expected to be evicted
        assert self.cache.get("key_1") is not None  # expected to survive
        assert self.cache.get("key_5") is not None  # the new entry

    def test_update_existing(self):
        """Putting under an existing key replaces the value, not the count."""
        first, second = self.sample_examples[0], self.sample_examples[1]

        # Initial insert.
        self.cache.put("key1", first)
        assert self.cache.get("key1").value == first.value

        # Overwrite under the same key.
        self.cache.put("key1", second)
        assert self.cache.size() == 1  # size is unchanged
        assert self.cache.get("key1").value == second.value  # value replaced

    def test_update_if_better(self):
        """``update_if_better`` accepts higher values and rejects lower ones."""
        low_value_example, high_value_example = (
            TrainingExample(
                board_state=np.zeros((4, 4)),
                action=0,
                value=val,
                canonical_hash="test_hash",
            )
            for val in (100.0, 200.0)
        )

        # First insert under a new key is accepted.
        accepted = self.cache.update_if_better("key1", low_value_example)
        assert accepted is True
        assert self.cache.get("key1").value == 100.0

        # A higher value replaces the stored one.
        accepted = self.cache.update_if_better("key1", high_value_example)
        assert accepted is True
        assert self.cache.get("key1").value == 200.0

        # A lower value is rejected and the stored value stays put.
        accepted = self.cache.update_if_better("key1", low_value_example)
        assert accepted is False
        assert self.cache.get("key1").value == 200.0

    def test_clear(self):
        """``clear`` empties the cache so earlier keys no longer resolve."""
        keys = [f"key_{idx}" for idx in range(3)]

        # Store a few items.
        for key, sample in zip(keys, self.sample_examples):
            self.cache.put(key, sample)
        assert self.cache.size() == 3

        # Clear and verify nothing is left.
        self.cache.clear()
        assert self.cache.size() == 0
        for key in keys:
            assert self.cache.get(key) is None

    def test_get_all_examples(self):
        """``get_all_examples`` returns every stored example."""
        # Store the first three samples.
        stored = self.sample_examples[:3]
        for idx, sample in enumerate(stored):
            self.cache.put(f"key_{idx}", sample)

        fetched = self.cache.get_all_examples()
        assert len(fetched) == 3

        # Compare as sets of values because the order may differ.
        assert {ex.value for ex in fetched} == {ex.value for ex in stored}

    def test_access_order_tracking(self):
        """Reads are expected to refresh recency and steer later evictions."""
        # Fill the cache to capacity.
        for idx, sample in enumerate(self.sample_examples[:5]):
            self.cache.put(f"key_{idx}", sample)

        # Touch key_1 and then key_3.
        for touched in ("key_1", "key_3"):
            self.cache.get(touched)

        # Expected order, oldest to newest: key_0, key_2, key_4, key_1, key_3.
        # Two further inserts should therefore evict key_0 and key_2.
        self.cache.put("key_5", self.sample_examples[5])
        self.cache.put("key_6", self.sample_examples[6])

        for evicted in ("key_0", "key_2"):
            assert self.cache.get(evicted) is None
        for kept in ("key_1", "key_3", "key_4", "key_5", "key_6"):
            assert self.cache.get(kept) is not None
class TestCachePerformance:
    """Performance-oriented checks for ``TrainingDataCache``."""

    def test_large_cache_performance(self):
        """Time 5000 inserts and 1000 random lookups on a large cache.

        Intervals are measured with ``time.perf_counter``, which is monotonic
        and high-resolution, rather than the adjustable wall clock. Keys and
        examples are built before the clock starts so that the timed regions
        contain only the cache calls themselves.
        """
        large_cache = TrainingDataCache(max_size=10000)

        # Build the test data, paired with its key, outside the timed region.
        keyed_examples = [
            (
                f"key_{i}",
                TrainingExample(
                    board_state=np.random.randint(0, 17, size=(4, 4)),
                    action=i % 4,
                    value=float(i),
                    canonical_hash=f"hash_{i}",
                ),
            )
            for i in range(5000)
        ]

        # Insert performance.
        start_time = time.perf_counter()
        for key, example in keyed_examples:
            large_cache.put(key, example)
        insert_time = time.perf_counter() - start_time
        print(f"插入5000个项目耗时: {insert_time:.3f}")
        assert insert_time < 1.0, "插入操作应该很快"

        # Lookup performance: 1000 random keys drawn before timing starts, so
        # random-number generation is not charged to the cache.
        query_keys = [
            f"key_{idx}" for idx in np.random.randint(0, 5000, size=1000)
        ]
        start_time = time.perf_counter()
        for key in query_keys:
            large_cache.get(key)
        query_time = time.perf_counter() - start_time
        print(f"1000次随机查询耗时: {query_time:.3f}")
        assert query_time < 0.1, "查询操作应该很快"

        # 5000 distinct keys fit within the 10000 capacity.
        assert large_cache.size() == 5000

    def test_memory_usage(self):
        """Report the container overhead of a cache holding 500 items.

        ``sys.getsizeof`` is shallow: the figures cover only the
        ``cache.cache`` and ``cache.access_order`` containers, not the
        examples they reference, so they are printed for information only.
        """
        import sys

        cache = TrainingDataCache(max_size=1000)

        # Container sizes while the cache is still empty.
        initial_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        # Fill with 500 uniquely keyed examples.
        for i in range(500):
            example = TrainingExample(
                board_state=np.random.randint(0, 17, size=(4, 4)),
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            cache.put(f"key_{i}", example)

        # Container sizes after filling.
        filled_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        print(f"空缓存内存使用: {initial_size} bytes")
        print(f"500项目缓存内存使用: {filled_size} bytes")
        print(f"平均每项目内存使用: {(filled_size - initial_size) / 500:.2f} bytes")

        # 500 distinct keys fit within the 1000 capacity, so all must be kept.
        assert cache.size() == 500
def test_cache_thread_safety():
    """Exercise one shared cache from several threads (basic smoke test).

    Five worker threads each insert 100 uniquely keyed examples into a single
    ``TrainingDataCache`` and read back every tenth key. The test fails if any
    worker raised an exception or if the final size differs from the 500
    distinct keys inserted (the capacity of 1000 rules out eviction).
    """
    import threading
    import time

    num_workers = 5
    puts_per_worker = 100

    cache = TrainingDataCache(max_size=1000)
    errors = []

    def worker(worker_id):
        """Insert this worker's examples, reading back every tenth key."""
        try:
            for i in range(puts_per_worker):
                board = np.random.randint(0, 17, size=(4, 4))
                example = TrainingExample(
                    board_state=board,
                    action=i % 4,
                    value=float(worker_id * 100 + i),
                    canonical_hash=f"hash_{worker_id}_{i}",
                )

                key = f"worker_{worker_id}_key_{i}"
                cache.put(key, example)

                # Occasional read of the key just written.
                if i % 10 == 0:
                    cache.get(key)

                # Brief sleep to encourage interleaving between threads.
                time.sleep(0.001)
        except Exception as e:
            # Broad catch is deliberate: an exception raised in a thread would
            # otherwise be lost, so record it for the assertion below.
            errors.append(f"Worker {worker_id}: {e}")

    # Start all workers, then wait for every one of them to finish.
    threads = [
        threading.Thread(target=worker, args=(worker_id,))
        for worker_id in range(num_workers)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Report what happened before asserting, so the output is always visible.
    if errors:
        print("线程安全测试中的错误:")
        for error in errors:
            print(f" {error}")

    print(f"最终缓存大小: {cache.size()}")
    print(f"线程安全测试完成,错误数: {len(errors)}")

    # Without these assertions the test could never fail.
    assert not errors, f"worker threads raised errors: {errors}"
    assert cache.size() == num_workers * puts_per_worker
if __name__ == "__main__":
    # Script entry point: run the performance and thread-safety checks
    # directly for their printed output, then hand the whole file to pytest.
    print("运行缓存系统测试...")

    # Performance tests.
    print("\n运行性能测试...")
    perf_test = TestCachePerformance()
    perf_test.test_large_cache_performance()
    perf_test.test_memory_usage()

    # Thread-safety test.
    print("\n运行线程安全测试...")
    test_cache_thread_safety()

    # Full pytest run. Propagate pytest's exit code so that a failing run
    # gives the script a non-zero exit status instead of always exiting 0.
    print("\n运行pytest测试...")
    raise SystemExit(pytest.main([__file__, "-v"]))