增加L0训练阶段的MCTS部分
This commit is contained in:
311
tests/test_cache_system.py
Normal file
311
tests/test_cache_system.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
内存缓存系统测试
|
||||
|
||||
验证TrainingDataCache的功能和性能
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import time
|
||||
from training_data import TrainingDataCache, TrainingExample
|
||||
|
||||
|
||||
class TestTrainingDataCache:
    """Functional tests for ``TrainingDataCache`` using a five-slot cache."""

    def setup_method(self):
        """Create a fresh small cache and ten sample examples for each test."""
        # A deliberately tiny capacity makes eviction easy to trigger.
        self.cache = TrainingDataCache(max_size=5)

        # Ten examples with random 4x4 boards and distinct values/hashes.
        self.sample_examples = [
            TrainingExample(
                board_state=np.random.randint(0, 17, size=(4, 4)),
                action=idx % 4,
                value=float(idx * 100),
                canonical_hash=f"hash_{idx}",
            )
            for idx in range(10)
        ]

    def test_basic_operations(self):
        """An empty cache reports size 0; a stored item can be read back."""
        # Empty cache: nothing stored, unknown keys yield None.
        assert self.cache.size() == 0
        assert self.cache.get("nonexistent") is None

        # Store a single item and read it back.
        stored = self.sample_examples[0]
        self.cache.put("key1", stored)
        assert self.cache.size() == 1

        fetched = self.cache.get("key1")
        assert fetched is not None
        assert fetched.value == stored.value
        assert fetched.action == stored.action

    def test_lru_eviction(self):
        """Inserting into a full cache evicts the least recently used key."""
        # Fill the cache to capacity.
        for idx, sample in enumerate(self.sample_examples[:5]):
            self.cache.put(f"key_{idx}", sample)
        assert self.cache.size() == 5

        # Touch key_1 so that it becomes the most recently used entry.
        self.cache.get("key_1")

        # One more insert should push out key_0, the oldest untouched entry.
        self.cache.put("key_5", self.sample_examples[5])

        assert self.cache.size() == 5
        assert self.cache.get("key_0") is None  # evicted
        assert self.cache.get("key_1") is not None  # still present
        assert self.cache.get("key_5") is not None  # newly added

    def test_update_existing(self):
        """Putting under an existing key replaces the value, not the count."""
        first, second = self.sample_examples[0], self.sample_examples[1]

        # Initial insert.
        self.cache.put("key1", first)
        assert self.cache.get("key1").value == first.value

        # Overwrite the same key.
        self.cache.put("key1", second)
        assert self.cache.size() == 1  # size unchanged
        assert self.cache.get("key1").value == second.value  # value replaced

    def test_update_if_better(self):
        """``update_if_better`` accepts higher values and rejects lower ones."""
        low_value_example = TrainingExample(
            board_state=np.zeros((4, 4)),
            action=0,
            value=100.0,
            canonical_hash="test_hash",
        )
        high_value_example = TrainingExample(
            board_state=np.zeros((4, 4)),
            action=0,
            value=200.0,
            canonical_hash="test_hash",
        )

        # First insert under a new key is accepted.
        accepted = self.cache.update_if_better("key1", low_value_example)
        assert accepted is True
        assert self.cache.get("key1").value == 100.0

        # A higher-valued example replaces the stored one.
        accepted = self.cache.update_if_better("key1", high_value_example)
        assert accepted is True
        assert self.cache.get("key1").value == 200.0

        # A lower-valued example is rejected and the stored value is kept.
        accepted = self.cache.update_if_better("key1", low_value_example)
        assert accepted is False
        assert self.cache.get("key1").value == 200.0  # unchanged

    def test_clear(self):
        """``clear`` removes every entry."""
        keys = [f"key_{idx}" for idx in range(3)]

        # Add a few items.
        for key, sample in zip(keys, self.sample_examples):
            self.cache.put(key, sample)
        assert self.cache.size() == 3

        # Empty the cache.
        self.cache.clear()

        assert self.cache.size() == 0
        for key in keys:
            assert self.cache.get(key) is None

    def test_get_all_examples(self):
        """``get_all_examples`` returns every stored example."""
        # Add a few items.
        added = self.sample_examples[:3]
        for idx, sample in enumerate(added):
            self.cache.put(f"key_{idx}", sample)

        # Fetch everything back.
        everything = self.cache.get_all_examples()
        assert len(everything) == 3

        # Compare as sets of values because the order may differ.
        returned_values = {item.value for item in everything}
        wanted_values = {item.value for item in added}
        assert returned_values == wanted_values

    def test_access_order_tracking(self):
        """Reads refresh recency, so untouched keys are evicted first."""
        # Fill the cache to capacity.
        for idx, sample in enumerate(self.sample_examples[:5]):
            self.cache.put(f"key_{idx}", sample)

        # Touch key_1, then key_3, making them the most recently used.
        for touched in ("key_1", "key_3"):
            self.cache.get(touched)

        # Expected order now:
        # key_0 (oldest), key_2, key_4, key_1, key_3 (newest).

        # Two inserts should evict key_0 and key_2.
        for idx in (5, 6):
            self.cache.put(f"key_{idx}", self.sample_examples[idx])

        # The two oldest entries are gone.
        for evicted in ("key_0", "key_2"):
            assert self.cache.get(evicted) is None

        # Refreshed, untouched-but-newer and freshly added entries remain.
        for kept in ("key_1", "key_3", "key_4", "key_5", "key_6"):
            assert self.cache.get(kept) is not None
|
||||
|
||||
|
||||
class TestCachePerformance:
    """Performance-oriented checks for ``TrainingDataCache``."""

    def test_large_cache_performance(self):
        """Time 5000 inserts and 1000 random lookups on a 10000-slot cache.

        Intervals are measured with ``time.perf_counter`` (monotonic, high
        resolution) rather than the wall clock, so a system clock adjustment
        during the run cannot distort or negate the measured duration.
        """
        large_cache = TrainingDataCache(max_size=10000)

        # Build the test data up front so its cost is not part of the timing.
        examples = [
            TrainingExample(
                board_state=np.random.randint(0, 17, size=(4, 4)),
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            for i in range(5000)
        ]

        # Insert timing.
        start_time = time.perf_counter()
        for i, example in enumerate(examples):
            large_cache.put(f"key_{i}", example)
        insert_time = time.perf_counter() - start_time

        print(f"插入5000个项目耗时: {insert_time:.3f}秒")
        assert insert_time < 1.0, "插入操作应该很快"

        # Lookup timing: 1000 lookups of randomly chosen existing keys.
        start_time = time.perf_counter()
        for _ in range(1000):
            key = f"key_{np.random.randint(0, 5000)}"
            large_cache.get(key)
        query_time = time.perf_counter() - start_time

        print(f"1000次随机查询耗时: {query_time:.3f}秒")
        assert query_time < 0.1, "查询操作应该很快"

        # 5000 distinct keys fit below the capacity, so nothing was evicted.
        assert large_cache.size() == 5000

    def test_memory_usage(self):
        """Report the container overhead of a cache holding 500 entries.

        ``sys.getsizeof`` is shallow: it measures only the two container
        objects, not the examples they reference, so the printed numbers are
        container overhead rather than total memory use.
        """
        import sys

        cache = TrainingDataCache(max_size=1000)

        def container_size():
            # Shallow size of the cache's two internal containers.
            return sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        # Baseline for the empty cache.
        initial_size = container_size()

        # Fill with 500 distinct entries (below the 1000-entry capacity).
        for i in range(500):
            board = np.random.randint(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            cache.put(f"key_{i}", example)

        # Measurement after filling.
        filled_size = container_size()

        print(f"空缓存内存使用: {initial_size} bytes")
        print(f"500项目缓存内存使用: {filled_size} bytes")
        print(f"平均每项目内存使用: {(filled_size - initial_size) / 500:.2f} bytes")

        # Make sure the figures above describe the intended number of entries;
        # without this the test could never fail.
        assert cache.size() == 500
|
||||
|
||||
|
||||
def test_cache_thread_safety():
    """Basic thread-safety smoke test for ``TrainingDataCache``.

    Five threads each insert 100 entries under worker-specific keys and read
    back every tenth one. Any exception raised inside a worker is recorded
    and, after all threads have joined, causes the test to fail. Previously
    the errors were only printed, so the test passed even when workers
    crashed.
    """
    import threading

    cache = TrainingDataCache(max_size=1000)
    errors = []

    def worker(worker_id):
        """Insert 100 entries for ``worker_id``, reading back every tenth."""
        try:
            for i in range(100):
                board = np.random.randint(0, 17, size=(4, 4))
                example = TrainingExample(
                    board_state=board,
                    action=i % 4,
                    value=float(worker_id * 100 + i),
                    canonical_hash=f"hash_{worker_id}_{i}",
                )

                key = f"worker_{worker_id}_key_{i}"
                cache.put(key, example)

                # Read back every tenth entry.
                if i % 10 == 0:
                    cache.get(key)

                # Short sleep so the threads interleave.
                time.sleep(0.001)
        except Exception as e:
            # Exceptions do not propagate out of a thread; record them so the
            # main thread can report and assert on them.
            errors.append(f"Worker {worker_id}: {e}")

    # Start the worker threads.
    threads = []
    for i in range(5):
        thread = threading.Thread(target=worker, args=(i,))
        threads.append(thread)
        thread.start()

    # Wait for all of them to finish.
    for thread in threads:
        thread.join()

    # Print diagnostics before asserting so they appear in the output.
    if errors:
        print("线程安全测试中的错误:")
        for error in errors:
            print(f" {error}")

    print(f"最终缓存大小: {cache.size()}")
    print(f"线程安全测试完成,错误数: {len(errors)}")

    assert not errors, f"{len(errors)} worker thread(s) raised: {errors}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the print-based checks directly, then hand the
    # whole file to pytest.
    print("运行缓存系统测试...")

    # Performance checks.
    print("\n运行性能测试...")
    perf_test = TestCachePerformance()
    perf_test.test_large_cache_performance()
    perf_test.test_memory_usage()

    # Thread-safety check.
    print("\n运行线程安全测试...")
    test_cache_thread_safety()

    # Full pytest run. Exit with pytest's return code so a failing run gives
    # a non-zero process exit status instead of always exiting with 0.
    print("\n运行pytest测试...")
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user