增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

311
tests/test_cache_system.py Normal file
View File

@@ -0,0 +1,311 @@
"""
内存缓存系统测试
验证TrainingDataCache的功能和性能
"""
import numpy as np
import pytest
import time
from training_data import TrainingDataCache, TrainingExample
class TestTrainingDataCache:
    """Functional tests for ``TrainingDataCache``: storage, eviction and updates."""

    def setup_method(self):
        """Create a fresh 5-slot cache and ten sample examples before each test."""
        # A tiny capacity makes eviction easy to trigger.
        self.cache = TrainingDataCache(max_size=5)
        # Ten examples with random 4x4 boards and distinct values 0, 100, ..., 900.
        self.sample_examples = [
            TrainingExample(
                board_state=np.random.randint(0, 17, size=(4, 4)),
                action=idx % 4,
                value=float(idx * 100),
                canonical_hash=f"hash_{idx}",
            )
            for idx in range(10)
        ]

    def test_basic_operations(self):
        """An empty cache reports size 0; a stored item can be read back."""
        cache = self.cache

        # Empty cache.
        assert cache.size() == 0
        assert cache.get("nonexistent") is None

        # Store a single item and fetch it again.
        stored = self.sample_examples[0]
        cache.put("key1", stored)
        assert cache.size() == 1

        fetched = cache.get("key1")
        assert fetched is not None
        assert fetched.value == stored.value
        assert fetched.action == stored.action

    def test_lru_eviction(self):
        """Inserting into a full cache drops the least recently used key."""
        # Fill the cache to capacity.
        for idx, sample in enumerate(self.sample_examples[:5]):
            self.cache.put(f"key_{idx}", sample)
        assert self.cache.size() == 5

        # Touch key_1 so that it becomes the most recently used entry.
        self.cache.get("key_1")

        # key_0 is now the oldest entry and is expected to be evicted.
        self.cache.put("key_5", self.sample_examples[5])
        assert self.cache.size() == 5
        assert self.cache.get("key_0") is None  # evicted
        assert self.cache.get("key_1") is not None  # still present
        assert self.cache.get("key_5") is not None  # newly added

    def test_update_existing(self):
        """Putting under an existing key overwrites the value without growing."""
        first, second = self.sample_examples[0], self.sample_examples[1]

        # Initial insert.
        self.cache.put("key1", first)
        assert self.cache.get("key1").value == first.value

        # Overwrite under the same key.
        self.cache.put("key1", second)
        assert self.cache.size() == 1  # size unchanged
        assert self.cache.get("key1").value == second.value  # value replaced

    def test_update_if_better(self):
        """``update_if_better`` accepts new and higher-valued entries, rejects lower."""

        def make_example(value):
            # Two examples that differ only in their value.
            return TrainingExample(
                board_state=np.zeros((4, 4)),
                action=0,
                value=value,
                canonical_hash="test_hash",
            )

        weaker = make_example(100.0)
        stronger = make_example(200.0)

        # First insert for this key.
        assert self.cache.update_if_better("key1", weaker) is True
        assert self.cache.get("key1").value == 100.0

        # A higher value replaces the stored one.
        assert self.cache.update_if_better("key1", stronger) is True
        assert self.cache.get("key1").value == 200.0

        # A lower value is refused and the stored value stays.
        assert self.cache.update_if_better("key1", weaker) is False
        assert self.cache.get("key1").value == 200.0

    def test_clear(self):
        """``clear`` removes every entry."""
        keys = [f"key_{idx}" for idx in range(3)]

        # Add a few items.
        for key, sample in zip(keys, self.sample_examples):
            self.cache.put(key, sample)
        assert self.cache.size() == 3

        # Empty the cache and confirm nothing is left.
        self.cache.clear()
        assert self.cache.size() == 0
        for key in keys:
            assert self.cache.get(key) is None

    def test_get_all_examples(self):
        """``get_all_examples`` returns every stored example."""
        added = self.sample_examples[:3]
        for idx, sample in enumerate(added):
            self.cache.put(f"key_{idx}", sample)

        everything = self.cache.get_all_examples()
        assert len(everything) == 3

        # Compare as sets because the return order is not checked here.
        assert {ex.value for ex in everything} == {ex.value for ex in added}

    def test_access_order_tracking(self):
        """Reads refresh recency, so untouched keys are evicted first."""
        # Fill the cache to capacity.
        for idx, sample in enumerate(self.sample_examples[:5]):
            self.cache.put(f"key_{idx}", sample)

        # Touch key_1 and then key_3.
        for touched in ("key_1", "key_3"):
            self.cache.get(touched)

        # Expected order now (oldest -> newest): key_0, key_2, key_4, key_1, key_3.
        # Two further inserts should therefore push out key_0 and key_2.
        for idx in (5, 6):
            self.cache.put(f"key_{idx}", self.sample_examples[idx])

        for evicted in ("key_0", "key_2"):
            assert self.cache.get(evicted) is None

        for kept in ("key_1", "key_3", "key_4", "key_5", "key_6"):
            assert self.cache.get(kept) is not None
class TestCachePerformance:
    """Coarse performance and memory smoke tests for ``TrainingDataCache``."""

    def test_large_cache_performance(self):
        """Time 5000 inserts and 1000 random lookups against fixed budgets.

        Inserts must finish in under 1.0 s and lookups in under 0.1 s. The
        cache capacity (10000) exceeds the number of inserted items, so the
        final size is expected to be exactly 5000.
        """
        large_cache = TrainingDataCache(max_size=10000)

        # Build the test data up front so construction is not timed.
        examples = [
            TrainingExample(
                board_state=np.random.randint(0, 17, size=(4, 4)),
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            for i in range(5000)
        ]

        # perf_counter() is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        for i, example in enumerate(examples):
            large_cache.put(f"key_{i}", example)
        insert_time = time.perf_counter() - start_time

        print(f"插入5000个项目耗时: {insert_time:.3f}")
        assert insert_time < 1.0, "插入操作应该很快"

        # Pick the 1000 random keys before starting the timer so that only
        # the cache lookups themselves are measured.
        query_keys = [f"key_{k}" for k in np.random.randint(0, 5000, size=1000)]

        start_time = time.perf_counter()
        for key in query_keys:
            large_cache.get(key)
        query_time = time.perf_counter() - start_time

        print(f"1000次随机查询耗时: {query_time:.3f}")
        assert query_time < 0.1, "查询操作应该很快"

        # Nothing should have been evicted.
        assert large_cache.size() == 5000

    def test_memory_usage(self):
        """Print the container overhead of an empty cache and one with 500 items.

        This test is informational only and makes no assertions.
        """
        import sys

        cache = TrainingDataCache(max_size=1000)

        # NOTE: sys.getsizeof is shallow. It reports the size of the two
        # container objects only, not of the examples they reference.
        initial_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        # Fill the cache with 500 examples.
        for i in range(500):
            board = np.random.randint(0, 17, size=(4, 4))
            example = TrainingExample(
                board_state=board,
                action=i % 4,
                value=float(i),
                canonical_hash=f"hash_{i}",
            )
            cache.put(f"key_{i}", example)

        # Measure the same containers again after filling.
        filled_size = sys.getsizeof(cache.cache) + sys.getsizeof(cache.access_order)

        print(f"空缓存内存使用: {initial_size} bytes")
        print(f"500项目缓存内存使用: {filled_size} bytes")
        print(f"平均每项目内存使用: {(filled_size - initial_size) / 500:.2f} bytes")
def test_cache_thread_safety():
    """Basic concurrency smoke test: five threads write and read one cache.

    Each worker performs 100 ``put`` calls under unique keys and reads back
    every tenth key. Exceptions raised inside a worker are collected, printed,
    and finally asserted on, so that a failing worker fails the test.
    """
    import threading
    import time

    cache = TrainingDataCache(max_size=1000)
    errors = []

    def worker(worker_id):
        """Insert 100 examples for ``worker_id``, reading back every tenth one."""
        try:
            for i in range(100):
                board = np.random.randint(0, 17, size=(4, 4))
                example = TrainingExample(
                    board_state=board,
                    action=i % 4,
                    value=float(worker_id * 100 + i),
                    canonical_hash=f"hash_{worker_id}_{i}",
                )
                key = f"worker_{worker_id}_key_{i}"
                cache.put(key, example)

                # Read back every tenth key.
                if i % 10 == 0:
                    cache.get(key)

                # Short sleep between iterations.
                time.sleep(0.001)
        except Exception as e:
            # Exceptions do not propagate out of a thread, so record them
            # here for the main thread to report and assert on.
            errors.append(f"Worker {worker_id}: {e}")

    # Start all worker threads.
    threads = []
    for i in range(5):
        thread = threading.Thread(target=worker, args=(i,))
        threads.append(thread)
        thread.start()

    # Wait for every worker to finish.
    for thread in threads:
        thread.join()

    # Report any collected errors.
    if errors:
        print("线程安全测试中的错误:")
        for error in errors:
            print(f"  {error}")

    print(f"最终缓存大小: {cache.size()}")
    print(f"线程安全测试完成,错误数: {len(errors)}")

    # Without this assertion the test could never fail.
    assert not errors, f"worker threads raised errors: {errors}"
if __name__ == "__main__":
    # Script entry point: run the performance and thread-safety checks
    # directly for their printed output, then hand over to pytest.
    print("运行缓存系统测试...")

    # Performance tests.
    print("\n运行性能测试...")
    perf_test = TestCachePerformance()
    perf_test.test_large_cache_performance()
    perf_test.test_memory_usage()

    # Thread-safety test.
    print("\n运行线程安全测试...")
    test_cache_thread_safety()

    # Full pytest run. Propagate pytest's exit code so that a failing run
    # gives the script a non-zero exit status.
    print("\n运行pytest测试...")
    raise SystemExit(pytest.main([__file__, "-v"]))