增加L0训练阶段的MCTS部分
This commit is contained in:
362
torch_mcts.py
Normal file
362
torch_mcts.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
PyTorch MCTS实现
|
||||
|
||||
统一的MCTS模块,支持CPU和GPU加速,基于PyTorch实现
|
||||
"""
|
||||
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Tuple, Dict, List, Optional
|
||||
from game import Game2048
|
||||
from training_data import TrainingDataManager
|
||||
|
||||
|
||||
class TorchMCTS:
    """
    Batched Monte Carlo move search for 2048 that keeps its statistics in
    PyTorch tensors.

    NOTE(review): despite the name, this class does not build a search tree
    and ``c_param`` is stored but never read. Every simulation picks its
    first move uniformly at random among the currently valid moves and then
    plays random moves until the game ends or ``max_simulation_depth`` is
    reached. The per-action statistics live on ``self.device``; the game
    rollouts themselves run sequentially in Python on the host.
    """

    # Size of the per-action statistics tensors (one slot per action id).
    NUM_ACTIONS = 4
    # Upper bound on training examples emitted per action and per search.
    MAX_SAMPLES_PER_ACTION = 1000
    # Batch sizes used when the caller does not specify one.
    AUTO_BATCH_SIZE_CUDA = 32768
    AUTO_BATCH_SIZE_CPU = 4096

    def __init__(self,
                 c_param: float = 1.414,
                 max_simulation_depth: int = 50,
                 batch_size: Optional[int] = None,
                 device: str = "auto",
                 training_manager: Optional["TrainingDataManager"] = None):
        """
        Initialise the searcher and print the chosen configuration.

        Args:
            c_param: UCT exploration constant (stored; not used by the
                visible search code).
            max_simulation_depth: Maximum number of random moves played
                after the first move of each simulation.
            batch_size: Simulations per batch; ``None`` selects a default
                based on the device.
            device: ``"auto"``, ``"cpu"``, ``"cuda"`` or any string accepted
                by ``torch.device``.
            training_manager: Optional sink whose ``add_training_example``
                method receives samples after each search.

        Raises:
            ValueError: If ``batch_size`` is given and is not positive.
        """
        self.c_param = c_param
        self.max_simulation_depth = max_simulation_depth
        self.training_manager = training_manager

        # The device must be resolved first: the automatic batch size
        # depends on it.
        self.device = self._select_device(device)
        self.batch_size = self._select_batch_size(batch_size)

        # Cumulative counters reported by get_statistics().
        self.total_simulations = 0
        self.total_time = 0.0

        print("TorchMCTS初始化:")
        print(f" 设备: {self.device}")
        print(f" 批次大小: {self.batch_size:,}")
        if self.device.type == "cuda":
            print(f" GPU: {torch.cuda.get_device_name(self.device)}")

    def _select_device(self, device: str) -> torch.device:
        """
        Resolve a device string to a ``torch.device``.

        ``"auto"`` prefers CUDA when available. ``"cuda"`` falls back to the
        CPU (with a printed warning) when CUDA is unavailable. Any other
        string is handed to ``torch.device`` unchanged.
        """
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device == "cuda" and not torch.cuda.is_available():
            print("⚠️ CUDA不可用,回退到CPU")
            return torch.device("cpu")
        return torch.device(device)

    def _select_batch_size(self, batch_size: Optional[int]) -> int:
        """
        Return the batch size to use.

        Args:
            batch_size: Explicit batch size, or ``None`` to pick a default
                from ``self.device``.

        Raises:
            ValueError: If an explicit batch size is not positive; a zero
                batch size would otherwise cause a division by zero in
                ``search``.
        """
        if batch_size is not None:
            if batch_size <= 0:
                raise ValueError(f"batch_size must be positive, got {batch_size}")
            return batch_size

        if self.device.type == "cuda":
            return self.AUTO_BATCH_SIZE_CUDA
        return self.AUTO_BATCH_SIZE_CPU

    def search(self, game: "Game2048", num_simulations: int) -> Tuple[int, Dict]:
        """
        Run ``num_simulations`` random rollouts from ``game`` and pick a move.

        The returned action is the valid action with the highest mean
        rollout score. First moves are sampled uniformly, so visit counts
        carry no preference information and are only used to break ties.

        Args:
            game: Current game state; it is copied for every rollout and
                not modified here.
            num_simulations: Number of rollouts; negative values are
                treated as zero.

        Returns:
            ``(best_action, stats)``. When the game has no valid move the
            result is ``(-1, {})``.
        """
        start_time = time.perf_counter()

        valid_actions = game.get_valid_moves()
        if not valid_actions:
            return -1, {}

        num_simulations = max(0, num_simulations)

        # Per-action accumulators kept on the configured device.
        action_visits = torch.zeros(self.NUM_ACTIONS, device=self.device, dtype=torch.long)
        action_values = torch.zeros(self.NUM_ACTIONS, device=self.device, dtype=torch.float32)

        # Ceiling division: the last batch may be smaller than the others.
        num_batches = (num_simulations + self.batch_size - 1) // self.batch_size

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        remaining = num_simulations
        for _ in range(num_batches):
            current_batch_size = min(self.batch_size, remaining)
            remaining -= current_batch_size
            batch_actions, batch_values = self._batch_simulate(
                game, valid_actions, current_batch_size
            )
            self._accumulate_stats(batch_actions, batch_values, action_visits, action_values)

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        # Bring the statistics to the host once instead of calling .item()
        # per action.
        visits = action_visits.tolist()
        totals = action_values.tolist()
        avg_values = {
            a: (totals[a] / visits[a] if visits[a] > 0 else 0.0)
            for a in valid_actions
        }

        # Rank by: was sampled at all, then mean value, then visit count.
        # max() keeps the first of equal candidates, so with no simulations
        # the first valid action is returned.
        best_action = max(
            valid_actions,
            key=lambda a: (visits[a] > 0, avg_values[a], visits[a]),
        )

        # Compare against None so a manager that happens to be falsy
        # (e.g. an empty container-like object) still receives data.
        if self.training_manager is not None:
            self._write_training_data(game, action_visits, action_values, valid_actions)

        elapsed_time = time.perf_counter() - start_time
        self.total_simulations += num_simulations
        self.total_time += elapsed_time

        stats = {
            'action_visits': {a: visits[a] for a in valid_actions},
            'action_avg_values': avg_values,
            'search_time': elapsed_time,
            'sims_per_second': num_simulations / elapsed_time if elapsed_time > 0 else 0,
            'batch_size': self.batch_size,
            'num_batches': num_batches,
            'device': str(self.device)
        }

        return best_action, stats

    def _batch_simulate(self, game: "Game2048", valid_actions: List[int],
                        batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draw ``batch_size`` uniformly random first moves and roll each out.

        Args:
            game: Starting game state.
            valid_actions: Action ids to sample the first move from.
            batch_size: Number of simulations in this batch.

        Returns:
            ``(actions, values)``: the sampled first moves and the score
            reached by each rollout, both of length ``batch_size``.
        """
        valid_actions_tensor = torch.tensor(valid_actions, device=self.device)
        action_indices = torch.randint(
            0, len(valid_actions), (batch_size,), device=self.device
        )
        batch_actions = valid_actions_tensor[action_indices]

        batch_values = self._batch_rollout(game, batch_actions)
        return batch_actions, batch_values

    def _batch_rollout(self, initial_game: "Game2048", first_actions: torch.Tensor) -> torch.Tensor:
        """
        Play one random game per entry of ``first_actions``.

        Each rollout copies ``initial_game``, applies its first action and
        then plays up to ``max_simulation_depth`` random valid moves.

        Args:
            initial_game: State every rollout starts from.
            first_actions: 1-D tensor holding the first move of each rollout.

        Returns:
            float32 tensor on ``self.device`` with the score of each rollout.
        """
        # Game logic runs on the host, so pull the actions over in one go.
        first_moves = first_actions.tolist()

        # Scores are gathered in a Python list and uploaded once at the end;
        # writing element by element into a device tensor would issue one
        # host-to-device copy per rollout.
        scores = []
        for first_move in first_moves:
            game_copy = initial_game.copy()

            if not game_copy.move(int(first_move)):
                # The first move was rejected: score the unchanged position.
                scores.append(game_copy.score)
                continue

            depth = 0
            # NOTE(review): is_over is read as an attribute here; if
            # Game2048 defines it as a method this test needs a call.
            while not game_copy.is_over and depth < self.max_simulation_depth:
                valid_moves = game_copy.get_valid_moves()
                if not valid_moves:
                    break

                action = valid_moves[np.random.randint(len(valid_moves))]
                if not game_copy.move(action):
                    break
                depth += 1

            scores.append(game_copy.score)

        return torch.tensor(scores, dtype=torch.float32, device=self.device)

    def _accumulate_stats(self, batch_actions: torch.Tensor, batch_values: torch.Tensor,
                          action_visits: torch.Tensor, action_values: torch.Tensor):
        """
        Add one batch of results to the running per-action totals in place.

        Args:
            batch_actions: First move of each simulation in the batch.
            batch_values: Rollout score of each simulation in the batch.
            action_visits: Running visit counts, indexed by action id.
            action_values: Running score sums, indexed by action id.
        """
        # The slot count comes from the accumulator itself rather than a
        # literal, so the method works for any number of actions.
        for action in range(action_visits.shape[0]):
            mask = (batch_actions == action)
            action_visits[action] += mask.sum()
            action_values[action] += batch_values[mask].sum()

    def _write_training_data(self, game: "Game2048", action_visits: torch.Tensor,
                             action_values: torch.Tensor, valid_actions: List[int]):
        """
        Forward the search result to the training manager.

        For every valid action that was sampled at least once, the mean
        rollout value is emitted ``min(visits, MAX_SAMPLES_PER_ACTION)``
        times via ``add_training_example``.

        Args:
            game: Searched game state; ``game.board`` is passed as the
                board of every example.
            action_visits: Visit counts, indexed by action id.
            action_values: Score sums, indexed by action id.
            valid_actions: Action ids to emit examples for.
        """
        if self.training_manager is None:
            return

        visit_counts = action_visits.tolist()
        value_sums = action_values.tolist()

        for action in valid_actions:
            visits = int(visit_counts[action])
            if visits <= 0:
                continue

            avg_value = float(value_sums[action] / visits)
            # Integer arithmetic avoids the float round trip
            # visits * (1000.0 / visits), which can truncate to 999.
            sample_count = min(visits, self.MAX_SAMPLES_PER_ACTION)

            # NOTE(review): the same game.board object is passed every
            # time; presumably the manager copies it - confirm.
            for _ in range(sample_count):
                self.training_manager.add_training_example(
                    board_state=game.board,
                    action=action,
                    value=avg_value
                )

    def get_statistics(self) -> Dict[str, object]:
        """
        Return cumulative timing statistics over all searches so far.

        The result always contains ``total_simulations``, ``total_time``,
        ``avg_time_per_sim``, ``sims_per_second``, ``device`` and
        ``batch_size``. Before the first simulation it also carries the
        legacy ``simulations`` key. On CUDA devices the allocated and
        reserved GPU memory (in MB) are added.
        """
        if self.total_simulations == 0:
            return {
                "simulations": 0,
                "total_simulations": 0,
                "total_time": self.total_time,
                "avg_time_per_sim": 0.0,
                "sims_per_second": 0.0,
                "device": str(self.device),
                "batch_size": self.batch_size
            }

        avg_time = self.total_time / self.total_simulations
        sims_per_sec = self.total_simulations / self.total_time if self.total_time > 0 else 0

        stats = {
            "total_simulations": self.total_simulations,
            "total_time": self.total_time,
            "avg_time_per_sim": avg_time,
            "sims_per_second": sims_per_sec,
            "device": str(self.device),
            "batch_size": self.batch_size
        }

        if self.device.type == "cuda":
            stats["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e6  # MB
            stats["gpu_memory_reserved"] = torch.cuda.memory_reserved() / 1e6  # MB

        return stats

    def set_device(self, device: str):
        """
        Switch to another device.

        When the resolved device differs from the current one, the batch
        size is reset to the default for the new device and the change is
        printed. Otherwise nothing happens.
        """
        new_device = self._select_device(device)
        if new_device == self.device:
            return

        self.device = new_device
        self.batch_size = self._select_batch_size(None)
        print(f"设备切换到: {self.device}, 批次大小: {self.batch_size:,}")

    def optimize_batch_size(self, game: "Game2048", test_simulations: int = 2000) -> int:
        """
        Time ``search`` with several batch sizes and keep the fastest.

        The training manager is detached and the cumulative counters are
        restored afterwards, so the benchmark neither emits training
        examples nor distorts ``get_statistics``.

        NOTE(review): with the default ``test_simulations`` most candidate
        sizes exceed the simulation count and therefore run as a single
        batch, so their timings differ mostly by noise.

        Args:
            game: Game state to benchmark on (a copy is searched).
            test_simulations: Simulations per timed run.

        Returns:
            The selected batch size, which is also stored in
            ``self.batch_size``.
        """
        print("🔧 自动优化批次大小...")

        if self.device.type == "cuda":
            test_sizes = [4096, 8192, 16384, 32768, 65536]
        else:
            test_sizes = [1024, 2048, 4096, 8192]

        best_size = self.batch_size
        best_speed = 0

        saved_manager = self.training_manager
        saved_totals = (self.total_simulations, self.total_time)
        self.training_manager = None
        try:
            for size in test_sizes:
                self.batch_size = size
                try:
                    if self.device.type == "cuda":
                        torch.cuda.empty_cache()
                        torch.cuda.synchronize()

                    start_time = time.perf_counter()
                    self.search(game.copy(), test_simulations)
                    elapsed_time = time.perf_counter() - start_time
                except Exception as e:
                    # Best effort: a failing size (e.g. out of memory) is
                    # reported and skipped.
                    print(f" 批次 {size:,}: 失败 ({e})")
                    continue

                speed = test_simulations / elapsed_time if elapsed_time > 0 else float("inf")
                print(f" 批次 {size:,}: {speed:.0f} 模拟/秒")

                if speed > best_speed:
                    best_speed = speed
                    best_size = size
        finally:
            self.training_manager = saved_manager
            self.total_simulations, self.total_time = saved_totals
            self.batch_size = best_size

        print(f"🎯 最优批次大小: {best_size:,} ({best_speed:.0f} 模拟/秒)")

        return best_size
|
||||
Reference in New Issue
Block a user