""" PyTorch MCTS实现 统一的MCTS模块,支持CPU和GPU加速,基于PyTorch实现 """ import torch import time import numpy as np from typing import Tuple, Dict, List, Optional from game import Game2048 from training_data import TrainingDataManager class TorchMCTS: """ 基于PyTorch的统一MCTS实现 特点: 1. 支持CPU和GPU加速 2. 自动选择最优设备和批次大小 3. 真正的MCTS算法实现 4. 高效的内存管理 """ def __init__(self, c_param: float = 1.414, max_simulation_depth: int = 50, batch_size: int = None, device: str = "auto", training_manager: Optional[TrainingDataManager] = None): """ 初始化TorchMCTS Args: c_param: UCT探索常数 max_simulation_depth: 最大模拟深度 batch_size: 批处理大小(None为自动选择) device: 计算设备("auto", "cpu", "cuda") training_manager: 训练数据管理器 """ self.c_param = c_param self.max_simulation_depth = max_simulation_depth self.training_manager = training_manager # 设备选择 self.device = self._select_device(device) # 批次大小选择 self.batch_size = self._select_batch_size(batch_size) # 统计信息 self.total_simulations = 0 self.total_time = 0.0 print(f"TorchMCTS初始化:") print(f" 设备: {self.device}") print(f" 批次大小: {self.batch_size:,}") if self.device.type == "cuda": print(f" GPU: {torch.cuda.get_device_name()}") def _select_device(self, device: str) -> torch.device: """选择计算设备""" if device == "auto": if torch.cuda.is_available(): return torch.device("cuda") else: return torch.device("cpu") elif device == "cuda": if torch.cuda.is_available(): return torch.device("cuda") else: print("⚠️ CUDA不可用,回退到CPU") return torch.device("cpu") else: return torch.device(device) def _select_batch_size(self, batch_size: Optional[int]) -> int: """选择批次大小""" if batch_size is not None: return batch_size # 根据设备自动选择批次大小 if self.device.type == "cuda": # GPU: 使用大批次以充分利用并行能力 return 32768 else: # CPU: 使用适中批次避免内存压力 return 4096 def search(self, game: Game2048, num_simulations: int) -> Tuple[int, Dict]: """ 执行MCTS搜索 Args: game: 游戏状态 num_simulations: 模拟次数 Returns: (最佳动作, 搜索统计) """ start_time = time.time() # 获取有效动作 valid_actions = game.get_valid_moves() if not valid_actions: return -1, {} # 在指定设备上初始化统计张量 action_visits = torch.zeros(4, device=self.device, dtype=torch.long) action_values = torch.zeros(4, device=self.device, dtype=torch.float32) # 批量处理 num_batches = (num_simulations + self.batch_size - 1) // self.batch_size # 同步开始(GPU) if self.device.type == "cuda": torch.cuda.synchronize() for batch_idx in range(num_batches): current_batch_size = min(self.batch_size, num_simulations - batch_idx * self.batch_size) # 执行批量模拟 batch_actions, batch_values = self._batch_simulate( game, valid_actions, current_batch_size ) # 累积统计 self._accumulate_stats(batch_actions, batch_values, action_visits, action_values) # 同步结束(GPU) if self.device.type == "cuda": torch.cuda.synchronize() # 选择最佳动作(访问次数最多的) valid_action_tensor = torch.tensor(valid_actions, device=self.device) valid_visits = action_visits[valid_action_tensor] best_idx = torch.argmax(valid_visits) best_action = valid_actions[best_idx.item()] # 批量写入训练数据 if self.training_manager: self._write_training_data(game, action_visits, action_values, valid_actions) # 计算统计信息 elapsed_time = time.time() - start_time self.total_simulations += num_simulations self.total_time += elapsed_time stats = { 'action_visits': {i: action_visits[i].item() for i in valid_actions}, 'action_avg_values': { i: (action_values[i] / max(1, action_visits[i])).item() for i in valid_actions }, 'search_time': elapsed_time, 'sims_per_second': num_simulations / elapsed_time if elapsed_time > 0 else 0, 'batch_size': self.batch_size, 'num_batches': num_batches, 'device': str(self.device) } return best_action, stats def _batch_simulate(self, game: Game2048, valid_actions: List[int], batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: """ 批量模拟 Args: game: 初始游戏状态 valid_actions: 有效动作列表 batch_size: 批次大小 Returns: (动作张量, 价值张量) """ # 在指定设备上生成随机动作 valid_actions_tensor = torch.tensor(valid_actions, device=self.device) action_indices = torch.randint(0, len(valid_actions), (batch_size,), device=self.device) batch_actions = valid_actions_tensor[action_indices] # 执行批量rollout batch_values = self._batch_rollout(game, batch_actions) return batch_actions, batch_values def _batch_rollout(self, initial_game: Game2048, first_actions: torch.Tensor) -> torch.Tensor: """ 批量rollout Args: initial_game: 初始游戏状态 first_actions: 第一步动作 Returns: 最终分数张量 """ batch_size = first_actions.shape[0] final_scores = torch.zeros(batch_size, device=self.device) # 转移到CPU进行游戏模拟(游戏逻辑在CPU上更高效) first_actions_cpu = first_actions.cpu().numpy() # 批量执行rollout for i in range(batch_size): # 复制游戏状态 game_copy = initial_game.copy() # 执行第一步动作 first_action = int(first_actions_cpu[i]) if not game_copy.move(first_action): # 如果第一步动作无效,使用当前分数 final_scores[i] = game_copy.score continue # 执行随机rollout depth = 0 while not game_copy.is_over and depth < self.max_simulation_depth: valid_moves = game_copy.get_valid_moves() if not valid_moves: break # 随机选择动作 action = np.random.choice(valid_moves) if not game_copy.move(action): break depth += 1 # 记录最终分数 final_scores[i] = game_copy.score return final_scores def _accumulate_stats(self, batch_actions: torch.Tensor, batch_values: torch.Tensor, action_visits: torch.Tensor, action_values: torch.Tensor): """ 累积统计信息 Args: batch_actions: 批次动作 batch_values: 批次价值 action_visits: 动作访问计数 action_values: 动作价值累积 """ # 使用PyTorch的高效操作进行统计累积 for action in range(4): mask = (batch_actions == action) action_visits[action] += mask.sum() action_values[action] += batch_values[mask].sum() def _write_training_data(self, game: Game2048, action_visits: torch.Tensor, action_values: torch.Tensor, valid_actions: List[int]): """写入训练数据""" if not self.training_manager: return # 转换到CPU进行训练数据写入 visits_cpu = action_visits.cpu().numpy() values_cpu = action_values.cpu().numpy() for action in valid_actions: visits = int(visits_cpu[action]) if visits > 0: avg_value = float(values_cpu[action] / visits) # 根据访问次数按比例添加样本 sample_ratio = min(1.0, 1000.0 / visits) sample_count = max(1, int(visits * sample_ratio)) for _ in range(sample_count): self.training_manager.add_training_example( board_state=game.board, action=action, value=avg_value ) def get_statistics(self) -> Dict[str, float]: """获取搜索统计信息""" if self.total_simulations == 0: return {"simulations": 0, "avg_time_per_sim": 0.0, "sims_per_second": 0.0} avg_time = self.total_time / self.total_simulations sims_per_sec = self.total_simulations / self.total_time if self.total_time > 0 else 0 stats = { "total_simulations": self.total_simulations, "total_time": self.total_time, "avg_time_per_sim": avg_time, "sims_per_second": sims_per_sec, "device": str(self.device), "batch_size": self.batch_size } if self.device.type == "cuda": stats["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e6 # MB stats["gpu_memory_reserved"] = torch.cuda.memory_reserved() / 1e6 # MB return stats def set_device(self, device: str): """动态切换设备""" new_device = self._select_device(device) if new_device != self.device: self.device = new_device self.batch_size = self._select_batch_size(None) # 重新选择批次大小 print(f"设备切换到: {self.device}, 批次大小: {self.batch_size:,}") def optimize_batch_size(self, game: Game2048, test_simulations: int = 2000) -> int: """ 自动优化批次大小 Args: game: 测试游戏状态 test_simulations: 测试模拟次数 Returns: 最优批次大小 """ print("🔧 自动优化批次大小...") if self.device.type == "cuda": test_sizes = [4096, 8192, 16384, 32768, 65536] else: test_sizes = [1024, 2048, 4096, 8192] best_size = self.batch_size best_speed = 0 for size in test_sizes: try: # 临时设置批次大小 old_size = self.batch_size self.batch_size = size if self.device.type == "cuda": torch.cuda.empty_cache() torch.cuda.synchronize() start_time = time.time() action, stats = self.search(game.copy(), test_simulations) elapsed_time = time.time() - start_time speed = test_simulations / elapsed_time print(f" 批次 {size:,}: {speed:.0f} 模拟/秒") if speed > best_speed: best_speed = speed best_size = size # 恢复原批次大小 self.batch_size = old_size except Exception as e: print(f" 批次 {size:,}: 失败 ({e})") # 设置最优批次大小 self.batch_size = best_size print(f"🎯 最优批次大小: {best_size:,} ({best_speed:.0f} 模拟/秒)") return best_size