增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

100
tests/run_all_tests.py Normal file
View File

@@ -0,0 +1,100 @@
"""
统一测试运行器
运行所有测试并生成报告
"""
import pytest
import sys
import time
from pathlib import Path
def run_all_tests():
"""运行所有测试"""
print("Deep2048 项目测试套件")
print("=" * 50)
test_dir = Path(__file__).parent
# 测试文件列表
test_files = [
"test_training_data.py",
"test_game_engine.py",
"test_torch_mcts.py",
"test_board_compression.py",
"test_cache_system.py",
"test_persistence.py",
"test_performance_benchmark.py"
]
# 检查测试文件是否存在
existing_tests = []
for test_file in test_files:
test_path = test_dir / test_file
if test_path.exists():
existing_tests.append(str(test_path))
else:
print(f"警告: 测试文件不存在 {test_file}")
if not existing_tests:
print("错误: 没有找到测试文件")
return False
print(f"找到 {len(existing_tests)} 个测试文件")
# 运行测试
start_time = time.time()
# pytest参数
args = [
"-v", # 详细输出
"--tb=short", # 简短的错误回溯
"--durations=10", # 显示最慢的10个测试
] + existing_tests
result = pytest.main(args)
elapsed_time = time.time() - start_time
print(f"\n测试完成,用时: {elapsed_time:.2f}")
if result == 0:
print("✅ 所有测试通过!")
return True
else:
print("❌ 部分测试失败")
return False
def run_quick_tests():
"""运行快速测试(跳过性能测试)"""
print("快速测试模式")
print("=" * 30)
test_dir = Path(__file__).parent
args = [
"-v",
"-k", "not performance and not slow", # 跳过性能测试
str(test_dir)
]
result = pytest.main(args)
return result == 0
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="运行Deep2048测试套件")
parser.add_argument("--quick", action="store_true", help="快速测试模式")
args = parser.parse_args()
if args.quick:
success = run_quick_tests()
else:
success = run_all_tests()
sys.exit(0 if success else 1)