"""
PyTorch MCTS implementation.

Unified MCTS module supporting both CPU and GPU acceleration, built on PyTorch.
"""
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from game import Game2048
from training_data import TrainingDataManager
class TorchMCTS:
    """Unified PyTorch-based MCTS implementation.

    Features:
        1. Supports both CPU and GPU acceleration.
        2. Automatically selects the optimal device and batch size.
        3. Flat Monte Carlo search: root actions are sampled uniformly and
           evaluated by random rollouts (no tree expansion is performed).
        4. Efficient memory use via tensorised per-action statistics.
    """

    def __init__(self,
                 c_param: float = 1.414,
                 max_simulation_depth: int = 50,
                 batch_size: Optional[int] = None,
                 device: str = "auto",
                 training_manager: Optional[TrainingDataManager] = None):
        """Initialise TorchMCTS.

        Args:
            c_param: UCT exploration constant. NOTE(review): currently unused
                by the search itself (root actions are sampled uniformly);
                kept for API compatibility.
            max_simulation_depth: Maximum number of moves per rollout.
            batch_size: Simulations per batch (None selects automatically
                based on the device).
            device: Compute device ("auto", "cpu", "cuda").
            training_manager: Optional recorder for training examples.
        """
        self.c_param = c_param
        self.max_simulation_depth = max_simulation_depth
        self.training_manager = training_manager

        # Device selection (falls back to CPU when CUDA is unavailable).
        self.device = self._select_device(device)

        # Batch-size selection (device-dependent default when not given).
        self.batch_size = self._select_batch_size(batch_size)

        # Cumulative statistics across all searches performed by this object.
        self.total_simulations = 0
        self.total_time = 0.0

        print(f"TorchMCTS初始化:")
        print(f" 设备: {self.device}")
        print(f" 批次大小: {self.batch_size:,}")
        if self.device.type == "cuda":
            print(f" GPU: {torch.cuda.get_device_name()}")

    def _select_device(self, device: str) -> torch.device:
        """Resolve a device string to a torch.device, with CUDA fallback.

        "auto" picks CUDA when available; an explicit "cuda" request falls
        back to CPU (with a warning) when CUDA is missing; anything else is
        passed straight to torch.device().
        """
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device == "cuda":
            if torch.cuda.is_available():
                return torch.device("cuda")
            print("⚠️ CUDA不可用,回退到CPU")
            return torch.device("cpu")
        return torch.device(device)

    def _select_batch_size(self, batch_size: Optional[int]) -> int:
        """Pick a batch size: an explicit value wins, otherwise a device default."""
        if batch_size is not None:
            return batch_size
        # GPU: large batches exploit parallelism; CPU: moderate batches
        # avoid memory pressure.
        return 32768 if self.device.type == "cuda" else 4096

    def search(self, game: Game2048, num_simulations: int) -> Tuple[int, Dict]:
        """Run a flat MCTS search from the given game state.

        Args:
            game: Current game state (not mutated; rollouts use copies).
            num_simulations: Total number of simulations to run.

        Returns:
            (best_action, stats): best_action is -1 (with empty stats) when
            no move is valid; otherwise the most-visited valid action, plus
            per-action visits/average values and timing information.
        """
        start_time = time.time()

        valid_actions = game.get_valid_moves()
        if not valid_actions:
            return -1, {}

        # Per-action statistics, accumulated on the selected device.
        action_visits = torch.zeros(4, device=self.device, dtype=torch.long)
        action_values = torch.zeros(4, device=self.device, dtype=torch.float32)

        # Ceiling division: the last batch may be smaller than batch_size.
        num_batches = (num_simulations + self.batch_size - 1) // self.batch_size

        # Synchronise so the wall-clock timing reflects only this search.
        if self.device.type == "cuda":
            torch.cuda.synchronize()

        for batch_idx in range(num_batches):
            current_batch_size = min(
                self.batch_size, num_simulations - batch_idx * self.batch_size
            )

            batch_actions, batch_values = self._batch_simulate(
                game, valid_actions, current_batch_size
            )

            self._accumulate_stats(
                batch_actions, batch_values, action_visits, action_values
            )

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        # Best action = the valid action with the highest visit count.
        valid_action_tensor = torch.tensor(valid_actions, device=self.device)
        valid_visits = action_visits[valid_action_tensor]
        best_action = valid_actions[torch.argmax(valid_visits).item()]

        # Bulk-write training examples if a manager was supplied.
        if self.training_manager:
            self._write_training_data(game, action_visits, action_values, valid_actions)

        elapsed_time = time.time() - start_time
        self.total_simulations += num_simulations
        self.total_time += elapsed_time

        stats = {
            'action_visits': {i: action_visits[i].item() for i in valid_actions},
            'action_avg_values': {
                # clamp(min=1) guards the division for never-visited actions.
                i: (action_values[i] / action_visits[i].clamp(min=1)).item()
                for i in valid_actions
            },
            'search_time': elapsed_time,
            'sims_per_second': num_simulations / elapsed_time if elapsed_time > 0 else 0,
            'batch_size': self.batch_size,
            'num_batches': num_batches,
            'device': str(self.device)
        }

        return best_action, stats

    def _batch_simulate(self, game: Game2048, valid_actions: List[int],
                        batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample first actions uniformly and roll each one out.

        Args:
            game: Initial game state.
            valid_actions: Valid actions at the root.
            batch_size: Number of simulations in this batch.

        Returns:
            (actions tensor, rollout-value tensor), both on self.device.
        """
        # Draw first actions uniformly at random on the target device.
        valid_actions_tensor = torch.tensor(valid_actions, device=self.device)
        action_indices = torch.randint(0, len(valid_actions), (batch_size,), device=self.device)
        batch_actions = valid_actions_tensor[action_indices]

        batch_values = self._batch_rollout(game, batch_actions)

        return batch_actions, batch_values

    def _batch_rollout(self, initial_game: Game2048, first_actions: torch.Tensor) -> torch.Tensor:
        """Play each first action, then a random rollout; return final scores.

        Args:
            initial_game: Initial game state (copied per simulation).
            first_actions: First move per simulation (1-D tensor).

        Returns:
            Float tensor of final scores, one per simulation, on self.device.
        """
        # Game logic runs on the CPU; pull the sampled actions over once.
        first_actions_cpu = first_actions.cpu().numpy()

        # Collect scores in a Python list and transfer once at the end —
        # the original wrote element-by-element into a device tensor, which
        # costs one host->device copy per simulation on CUDA.
        scores: List[float] = []

        for raw_action in first_actions_cpu:
            game_copy = initial_game.copy()

            # If the first move is invalid, score the unchanged position.
            if not game_copy.move(int(raw_action)):
                scores.append(game_copy.score)
                continue

            # Random rollout up to the configured depth.
            depth = 0
            while not game_copy.is_over and depth < self.max_simulation_depth:
                valid_moves = game_copy.get_valid_moves()
                if not valid_moves:
                    break
                action = np.random.choice(valid_moves)
                if not game_copy.move(action):
                    break
                depth += 1

            scores.append(game_copy.score)

        return torch.tensor(scores, device=self.device, dtype=torch.float32)

    def _accumulate_stats(self, batch_actions: torch.Tensor, batch_values: torch.Tensor,
                          action_visits: torch.Tensor, action_values: torch.Tensor):
        """Fold a batch of (action, value) samples into the running tensors.

        Args:
            batch_actions: Long tensor of actions in [0, 4).
            batch_values: Rollout value per sample.
            action_visits: Running visit counts (mutated in place).
            action_values: Running value sums (mutated in place).
        """
        # Vectorised accumulation replaces the original per-action mask loop.
        action_visits += torch.bincount(batch_actions, minlength=4)
        action_values.index_add_(0, batch_actions, batch_values)

    def _write_training_data(self, game: Game2048, action_visits: torch.Tensor,
                             action_values: torch.Tensor, valid_actions: List[int]):
        """Record per-action average values as training examples."""
        if not self.training_manager:
            return

        # Move the statistics to the CPU once, before the Python write loop.
        visits_cpu = action_visits.cpu().numpy()
        values_cpu = action_values.cpu().numpy()

        for action in valid_actions:
            visits = int(visits_cpu[action])
            if visits > 0:
                avg_value = float(values_cpu[action] / visits)

                # Sub-sample proportionally to visits, capping the number of
                # examples per action at ~1000 per call.
                sample_ratio = min(1.0, 1000.0 / visits)
                sample_count = max(1, int(visits * sample_ratio))

                for _ in range(sample_count):
                    self.training_manager.add_training_example(
                        board_state=game.board,
                        action=action,
                        value=avg_value
                    )

    def get_statistics(self) -> Dict[str, float]:
        """Return cumulative search statistics.

        Always returns the same key set; the original returned a smaller
        dict with a mismatched "simulations" key before any search had run.
        """
        if self.total_simulations > 0:
            avg_time = self.total_time / self.total_simulations
            sims_per_sec = self.total_simulations / self.total_time if self.total_time > 0 else 0
        else:
            avg_time = 0.0
            sims_per_sec = 0.0

        stats = {
            "total_simulations": self.total_simulations,
            "total_time": self.total_time,
            "avg_time_per_sim": avg_time,
            "sims_per_second": sims_per_sec,
            "device": str(self.device),
            "batch_size": self.batch_size
        }

        if self.device.type == "cuda":
            stats["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e6  # MB
            stats["gpu_memory_reserved"] = torch.cuda.memory_reserved() / 1e6  # MB

        return stats

    def set_device(self, device: str):
        """Switch devices at runtime, re-deriving the default batch size."""
        new_device = self._select_device(device)
        if new_device != self.device:
            self.device = new_device
            self.batch_size = self._select_batch_size(None)  # re-derive default
            print(f"设备切换到: {self.device}, 批次大小: {self.batch_size:,}")

    def optimize_batch_size(self, game: Game2048, test_simulations: int = 2000) -> int:
        """Empirically pick the fastest batch size for the current device.

        Args:
            game: Game state used for the timing trials (copied per trial).
            test_simulations: Simulations per trial.

        Returns:
            The fastest batch size found; also installed on self.
        """
        print("🔧 自动优化批次大小...")

        if self.device.type == "cuda":
            test_sizes = [4096, 8192, 16384, 32768, 65536]
        else:
            test_sizes = [1024, 2048, 4096, 8192]

        best_size = self.batch_size
        best_speed = 0

        # Saved once up front; the original re-read self.batch_size inside the
        # loop and skipped the restore on failure, so a failed trial leaked its
        # test size into subsequent iterations.
        original_size = self.batch_size

        for size in test_sizes:
            self.batch_size = size
            try:
                if self.device.type == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()

                start_time = time.time()
                action, stats = self.search(game.copy(), test_simulations)
                elapsed_time = time.time() - start_time

                speed = test_simulations / elapsed_time
                print(f" 批次 {size:,}: {speed:.0f} 模拟/秒")

                if speed > best_speed:
                    best_speed = speed
                    best_size = size
            except Exception as e:
                print(f" 批次 {size:,}: 失败 ({e})")
            finally:
                # Restore on every path, including failures.
                self.batch_size = original_size

        # Install the winner.
        self.batch_size = best_size
        print(f"🎯 最优批次大小: {best_size:,} ({best_speed:.0f} 模拟/秒)")

        return best_size