增加L0训练阶段的MCTS部分

This commit is contained in:
hisatri
2025-07-23 07:04:10 +08:00
parent 88bed2a1ef
commit 4410defbe5
23 changed files with 5205 additions and 0 deletions

362
torch_mcts.py Normal file
View File

@@ -0,0 +1,362 @@
"""
PyTorch MCTS实现
统一的MCTS模块支持CPU和GPU加速基于PyTorch实现
"""
import torch
import time
import numpy as np
from typing import Tuple, Dict, List, Optional
from game import Game2048
from training_data import TrainingDataManager
class TorchMCTS:
    """
    Unified MCTS implementation built on PyTorch.

    Features:
    1. Supports both CPU and GPU acceleration.
    2. Automatically selects a device and batch size.
    3. Batched random-rollout MCTS search.
    4. Keeps per-action statistics as device tensors to limit transfers.
    """

    def __init__(self,
                 c_param: float = 1.414,
                 max_simulation_depth: int = 50,
                 batch_size: int = None,
                 device: str = "auto",
                 training_manager: Optional[TrainingDataManager] = None):
        """
        Initialize TorchMCTS.

        Args:
            c_param: UCT exploration constant (stored; not used by the
                current pure-rollout search).
            max_simulation_depth: Maximum depth of a random rollout.
            batch_size: Simulations per batch; None selects automatically
                based on the device.
            device: Compute device ("auto", "cpu", "cuda", or any string
                accepted by torch.device).
            training_manager: Optional manager that receives training
                examples after each search.
        """
        self.c_param = c_param
        self.max_simulation_depth = max_simulation_depth
        self.training_manager = training_manager
        # Device must be resolved before the batch size, which depends on it.
        self.device = self._select_device(device)
        self.batch_size = self._select_batch_size(batch_size)
        # Cumulative statistics across all search() calls.
        self.total_simulations = 0
        self.total_time = 0.0
        print(f"TorchMCTS初始化:")
        print(f" 设备: {self.device}")
        print(f" 批次大小: {self.batch_size:,}")
        if self.device.type == "cuda":
            print(f" GPU: {torch.cuda.get_device_name()}")

    def _select_device(self, device: str) -> torch.device:
        """Resolve the requested device string to a torch.device.

        "auto" prefers CUDA when available; "cuda" falls back to CPU with a
        warning when CUDA is unavailable; anything else is passed through.
        """
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device == "cuda":
            if torch.cuda.is_available():
                return torch.device("cuda")
            print("⚠️ CUDA不可用回退到CPU")
            return torch.device("cpu")
        return torch.device(device)

    def _select_batch_size(self, batch_size: Optional[int]) -> int:
        """Pick a batch size: an explicit value wins, else a per-device default."""
        if batch_size is not None:
            return batch_size
        # GPU: large batches to exploit parallelism.
        # CPU: moderate batches to limit memory pressure.
        return 32768 if self.device.type == "cuda" else 4096

    def search(self, game: Game2048, num_simulations: int) -> Tuple[int, Dict]:
        """
        Run an MCTS search from the given game state.

        Args:
            game: Current game state.
            num_simulations: Total number of simulations to run.

        Returns:
            (best action, statistics dict), or (-1, {}) when no move is valid.
            The best action is the most-visited valid action.
        """
        start_time = time.time()
        valid_actions = game.get_valid_moves()
        if not valid_actions:
            return -1, {}
        # Per-action statistics, kept on the selected device.
        action_visits = torch.zeros(4, device=self.device, dtype=torch.long)
        action_values = torch.zeros(4, device=self.device, dtype=torch.float32)
        # Ceiling division: the last batch may be smaller.
        num_batches = (num_simulations + self.batch_size - 1) // self.batch_size
        if self.device.type == "cuda":
            torch.cuda.synchronize()  # make wall-clock timing meaningful
        for batch_idx in range(num_batches):
            current_batch_size = min(
                self.batch_size, num_simulations - batch_idx * self.batch_size)
            batch_actions, batch_values = self._batch_simulate(
                game, valid_actions, current_batch_size)
            self._accumulate_stats(
                batch_actions, batch_values, action_visits, action_values)
        if self.device.type == "cuda":
            torch.cuda.synchronize()
        # Choose the most-visited action among the valid ones.
        valid_action_tensor = torch.tensor(valid_actions, device=self.device)
        valid_visits = action_visits[valid_action_tensor]
        best_action = valid_actions[torch.argmax(valid_visits).item()]
        if self.training_manager:
            self._write_training_data(
                game, action_visits, action_values, valid_actions)
        elapsed_time = time.time() - start_time
        self.total_simulations += num_simulations
        self.total_time += elapsed_time
        stats = {
            'action_visits': {i: action_visits[i].item() for i in valid_actions},
            'action_avg_values': {
                # max(1, visits) guards against division by zero for
                # valid actions that happened to receive no samples.
                i: (action_values[i] / max(1, action_visits[i].item())).item()
                for i in valid_actions
            },
            'search_time': elapsed_time,
            'sims_per_second': num_simulations / elapsed_time if elapsed_time > 0 else 0,
            'batch_size': self.batch_size,
            'num_batches': num_batches,
            'device': str(self.device),
        }
        return best_action, stats

    def _batch_simulate(self, game: Game2048, valid_actions: List[int],
                        batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample first actions uniformly from the valid actions and roll out.

        Args:
            game: Initial game state.
            valid_actions: Valid actions for the current state.
            batch_size: Number of rollouts in this batch.

        Returns:
            (first-action tensor, final-score tensor), both on self.device.
        """
        valid_actions_tensor = torch.tensor(valid_actions, device=self.device)
        action_indices = torch.randint(
            0, len(valid_actions), (batch_size,), device=self.device)
        batch_actions = valid_actions_tensor[action_indices]
        batch_values = self._batch_rollout(game, batch_actions)
        return batch_actions, batch_values

    def _batch_rollout(self, initial_game: Game2048,
                       first_actions: torch.Tensor) -> torch.Tensor:
        """
        Run one random rollout per first action and return the final scores.

        Args:
            initial_game: Starting game state (copied per rollout).
            first_actions: First move of each rollout.

        Returns:
            float32 tensor of final scores on self.device.
        """
        # Game logic runs on the CPU (it is plain Python); only the scores
        # travel back to the device, in a single transfer at the end.
        first_actions_cpu = first_actions.cpu().numpy()
        scores = []
        for raw_action in first_actions_cpu:
            game_copy = initial_game.copy()
            if not game_copy.move(int(raw_action)):
                # Invalid first move: evaluate as the unchanged score.
                scores.append(game_copy.score)
                continue
            # Random rollout, capped by the configured depth.
            depth = 0
            while not game_copy.is_over and depth < self.max_simulation_depth:
                valid_moves = game_copy.get_valid_moves()
                if not valid_moves:
                    break
                action = np.random.choice(valid_moves)
                if not game_copy.move(action):
                    break
                depth += 1
            scores.append(game_copy.score)
        # One bulk allocation instead of a per-rollout tensor write.
        return torch.tensor(scores, device=self.device, dtype=torch.float32)

    def _accumulate_stats(self, batch_actions: torch.Tensor, batch_values: torch.Tensor,
                          action_visits: torch.Tensor, action_values: torch.Tensor):
        """
        Fold one batch of (action, value) samples into the running tensors.

        Args:
            batch_actions: First action of each rollout in the batch.
            batch_values: Final score of each rollout.
            action_visits: Per-action visit counts (mutated in place).
            action_values: Per-action score sums (mutated in place).
        """
        # Only 4 possible actions, so a small masked loop is sufficient.
        for action in range(4):
            mask = (batch_actions == action)
            action_visits[action] += mask.sum()
            action_values[action] += batch_values[mask].sum()

    def _write_training_data(self, game: Game2048, action_visits: torch.Tensor,
                             action_values: torch.Tensor, valid_actions: List[int]):
        """Emit training examples, sampled proportionally to visit counts
        (capped at roughly 1000 examples per action)."""
        if not self.training_manager:
            return
        # Move statistics to the CPU once before the Python-level loop.
        visits_cpu = action_visits.cpu().numpy()
        values_cpu = action_values.cpu().numpy()
        for action in valid_actions:
            visits = int(visits_cpu[action])
            if visits > 0:
                avg_value = float(values_cpu[action] / visits)
                # visits * ratio ≈ min(visits, 1000): cap the sample count.
                sample_ratio = min(1.0, 1000.0 / visits)
                sample_count = max(1, int(visits * sample_ratio))
                for _ in range(sample_count):
                    self.training_manager.add_training_example(
                        board_state=game.board,
                        action=action,
                        value=avg_value
                    )

    def get_statistics(self) -> Dict[str, float]:
        """Return cumulative statistics across all search() calls.

        The zero-simulation case uses the same key names as the populated
        case (bug fix: it previously returned a different key set, so
        callers reading "total_simulations" raised KeyError before the
        first search).
        """
        if self.total_simulations == 0:
            return {
                "total_simulations": 0,
                "total_time": 0.0,
                "avg_time_per_sim": 0.0,
                "sims_per_second": 0.0,
                "device": str(self.device),
                "batch_size": self.batch_size,
            }
        avg_time = self.total_time / self.total_simulations
        sims_per_sec = (self.total_simulations / self.total_time
                        if self.total_time > 0 else 0)
        stats = {
            "total_simulations": self.total_simulations,
            "total_time": self.total_time,
            "avg_time_per_sim": avg_time,
            "sims_per_second": sims_per_sec,
            "device": str(self.device),
            "batch_size": self.batch_size,
        }
        if self.device.type == "cuda":
            stats["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e6  # MB
            stats["gpu_memory_reserved"] = torch.cuda.memory_reserved() / 1e6  # MB
        return stats

    def set_device(self, device: str):
        """Switch to a new device at runtime and re-derive the batch size."""
        new_device = self._select_device(device)
        if new_device != self.device:
            self.device = new_device
            self.batch_size = self._select_batch_size(None)
            print(f"设备切换到: {self.device}, 批次大小: {self.batch_size:,}")

    def optimize_batch_size(self, game: Game2048, test_simulations: int = 2000) -> int:
        """
        Benchmark candidate batch sizes and keep the fastest one.

        Args:
            game: Game state used for the benchmark searches.
            test_simulations: Simulations per candidate size.

        Returns:
            The selected (fastest) batch size; also stored on self.
        """
        print("🔧 自动优化批次大小...")
        if self.device.type == "cuda":
            test_sizes = [4096, 8192, 16384, 32768, 65536]
        else:
            test_sizes = [1024, 2048, 4096, 8192]
        # Bug fix: capture the original size once, and restore it in a
        # finally-block so a failing candidate cannot leak into later
        # iterations or remain set after an exception.
        original_size = self.batch_size
        best_size = original_size
        best_speed = 0
        for size in test_sizes:
            try:
                self.batch_size = size
                if self.device.type == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                start_time = time.time()
                action, stats = self.search(game.copy(), test_simulations)
                elapsed_time = time.time() - start_time
                speed = test_simulations / elapsed_time
                print(f" 批次 {size:,}: {speed:.0f} 模拟/秒")
                if speed > best_speed:
                    best_speed = speed
                    best_size = size
            except Exception as e:
                print(f" 批次 {size:,}: 失败 ({e})")
            finally:
                self.batch_size = original_size
        self.batch_size = best_size
        print(f"🎯 最优批次大小: {best_size:,} ({best_speed:.0f} 模拟/秒)")
        return best_size