Files
deep2048/torch_mcts.py
2025-07-23 07:04:10 +08:00

363 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
PyTorch MCTS实现
统一的MCTS模块支持CPU和GPU加速基于PyTorch实现
"""
import torch
import time
import numpy as np
from typing import Tuple, Dict, List, Optional
from game import Game2048
from training_data import TrainingDataManager
class TorchMCTS:
    """
    Unified MCTS implementation built on PyTorch.

    Features:
    1. Supports both CPU and GPU acceleration.
    2. Automatically selects the best device and batch size.
    3. A genuine MCTS-style search (random rollouts with visit statistics).
    4. Efficient memory management.
    """
def __init__(self,
             c_param: float = 1.414,
             max_simulation_depth: int = 50,
             batch_size: Optional[int] = None,
             device: str = "auto",
             training_manager: Optional[TrainingDataManager] = None):
    """
    Initialize TorchMCTS.

    Args:
        c_param: UCT exploration constant (stored here; not referenced in
            this chunk — TODO confirm it is used elsewhere in the file).
        max_simulation_depth: maximum number of plies per random rollout.
        batch_size: simulations per batch; None picks a per-device default.
        device: compute device ("auto", "cpu", "cuda").
        training_manager: optional manager used to record training examples.
    """
    self.c_param = c_param
    self.max_simulation_depth = max_simulation_depth
    self.training_manager = training_manager
    # Resolve the compute device (falls back to CPU when CUDA is unavailable).
    self.device = self._select_device(device)
    # Pick a batch size appropriate for the resolved device.
    self.batch_size = self._select_batch_size(batch_size)
    # Cumulative statistics across all searches performed by this instance.
    self.total_simulations = 0
    self.total_time = 0.0
    print(f"TorchMCTS初始化:")
    print(f" 设备: {self.device}")
    print(f" 批次大小: {self.batch_size:,}")
    if self.device.type == "cuda":
        print(f" GPU: {torch.cuda.get_device_name()}")
def _select_device(self, device: str) -> torch.device:
    """Resolve a device string into a torch.device.

    "auto" picks CUDA when available, otherwise CPU. "cuda" falls back
    to CPU (with a warning) when CUDA is unavailable. Any other string
    is handed to torch.device unchanged.
    """
    cuda_ok = torch.cuda.is_available()
    if device == "auto":
        return torch.device("cuda" if cuda_ok else "cpu")
    if device == "cuda":
        if not cuda_ok:
            print("⚠️ CUDA不可用回退到CPU")
            return torch.device("cpu")
        return torch.device("cuda")
    return torch.device(device)
def _select_batch_size(self, batch_size: Optional[int]) -> int:
    """Return the explicit batch size, or a per-device default when None.

    GPU gets a large batch (32768) to exploit parallelism; CPU gets a
    moderate one (4096) to limit memory pressure. Requires self.device
    to be resolved already.
    """
    if batch_size is not None:
        return batch_size
    return 32768 if self.device.type == "cuda" else 4096
def search(self, game: Game2048, num_simulations: int) -> Tuple[int, Dict]:
    """Run a batched MCTS search from the given game state.

    Args:
        game: current game state.
        num_simulations: total number of rollouts to perform.

    Returns:
        (best action, statistics dict); (-1, {}) when no move is legal.
        The best action is the most-visited legal move.
    """
    t0 = time.time()
    legal = game.get_valid_moves()
    if not legal:
        return -1, {}
    # Per-action accumulators, kept on the target device.
    visit_counts = torch.zeros(4, device=self.device, dtype=torch.long)
    value_sums = torch.zeros(4, device=self.device, dtype=torch.float32)
    # Ceiling division: number of batches needed to cover all simulations.
    n_batches = (num_simulations + self.batch_size - 1) // self.batch_size
    # Synchronize before timing GPU work.
    if self.device.type == "cuda":
        torch.cuda.synchronize()
    remaining = num_simulations
    for _ in range(n_batches):
        this_batch = min(self.batch_size, remaining)
        remaining -= this_batch
        acts, vals = self._batch_simulate(game, legal, this_batch)
        self._accumulate_stats(acts, vals, visit_counts, value_sums)
    # Synchronize so the timing below covers all queued GPU work.
    if self.device.type == "cuda":
        torch.cuda.synchronize()
    # Pick the most-visited legal action.
    legal_t = torch.tensor(legal, device=self.device)
    best_action = legal[torch.argmax(visit_counts[legal_t]).item()]
    # Optionally persist the search result as training data.
    if self.training_manager:
        self._write_training_data(game, visit_counts, value_sums, legal)
    elapsed = time.time() - t0
    self.total_simulations += num_simulations
    self.total_time += elapsed
    stats = {
        'action_visits': {a: visit_counts[a].item() for a in legal},
        'action_avg_values': {
            a: (value_sums[a] / max(1, visit_counts[a])).item()
            for a in legal
        },
        'search_time': elapsed,
        'sims_per_second': num_simulations / elapsed if elapsed > 0 else 0,
        'batch_size': self.batch_size,
        'num_batches': n_batches,
        'device': str(self.device)
    }
    return best_action, stats
def _batch_simulate(self, game: Game2048, valid_actions: List[int],
                    batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw a uniformly random first action per simulation and roll out.

    Args:
        game: starting game state.
        valid_actions: legal first moves.
        batch_size: number of simulations in this batch.

    Returns:
        (first-action tensor, rollout-value tensor), both on self.device.
    """
    actions_t = torch.tensor(valid_actions, device=self.device)
    picks = torch.randint(0, len(valid_actions), (batch_size,), device=self.device)
    chosen = actions_t[picks]
    return chosen, self._batch_rollout(game, chosen)
def _batch_rollout(self, initial_game: Game2048, first_actions: torch.Tensor) -> torch.Tensor:
    """Run a random rollout for each first action and collect final scores.

    The game simulation runs on the CPU (the game logic is plain Python),
    so scores are accumulated in a Python list and moved to self.device in
    a single bulk transfer — instead of one per-element device write per
    simulation, which on CUDA would trigger a transfer each iteration.

    Args:
        initial_game: starting game state (never mutated; copies are used).
        first_actions: tensor of first moves, one entry per simulation.

    Returns:
        Float32 tensor of final scores on self.device, one per simulation.
    """
    first_actions_cpu = first_actions.cpu().numpy()
    scores = []
    for first in first_actions_cpu:
        sim = initial_game.copy()
        # An illegal first move ends this rollout at the current score.
        if not sim.move(int(first)):
            scores.append(sim.score)
            continue
        # Random playout, capped at max_simulation_depth plies.
        depth = 0
        while not sim.is_over and depth < self.max_simulation_depth:
            moves = sim.get_valid_moves()
            if not moves:
                break
            if not sim.move(np.random.choice(moves)):
                break
            depth += 1
        scores.append(sim.score)
    # Single transfer; float32 matches the previous torch.zeros default.
    return torch.tensor(scores, dtype=torch.float32, device=self.device)
def _accumulate_stats(self, batch_actions: torch.Tensor, batch_values: torch.Tensor,
                      action_visits: torch.Tensor, action_values: torch.Tensor):
    """Fold a batch of (action, value) pairs into the running totals.

    Mutates action_visits and action_values in place.

    Args:
        batch_actions: chosen action per simulation (values in 0..3).
        batch_values: rollout value per simulation.
        action_visits: per-action visit counters (updated in place).
        action_values: per-action value sums (updated in place).
    """
    for a in range(4):
        picked = batch_actions.eq(a)
        action_visits[a] += torch.count_nonzero(picked)
        action_values[a] += batch_values[picked].sum()
def _write_training_data(self, game: Game2048, action_visits: torch.Tensor,
                         action_values: torch.Tensor, valid_actions: List[int]):
    """Emit training examples for every visited legal action.

    The sampling ratio caps the number of examples per action at roughly
    1000, so heavily-visited actions do not flood the dataset.
    """
    if not self.training_manager:
        return
    # Move the statistics to the host once, before the per-action loop.
    visits_np = action_visits.cpu().numpy()
    values_np = action_values.cpu().numpy()
    for act in valid_actions:
        n_visits = int(visits_np[act])
        if n_visits <= 0:
            continue
        mean_value = float(values_np[act] / n_visits)
        # Proportional sampling: full copy below 1000 visits, capped above.
        ratio = min(1.0, 1000.0 / n_visits)
        n_samples = max(1, int(n_visits * ratio))
        for _ in range(n_samples):
            self.training_manager.add_training_example(
                board_state=game.board,
                action=act,
                value=mean_value
            )
def get_statistics(self) -> Dict[str, float]:
    """Return cumulative search statistics for this instance.

    Before any search has run, a minimal zeroed dict is returned; note
    its keys differ from the populated form.
    """
    if self.total_simulations == 0:
        return {"simulations": 0, "avg_time_per_sim": 0.0, "sims_per_second": 0.0}
    per_sim = self.total_time / self.total_simulations
    rate = self.total_simulations / self.total_time if self.total_time > 0 else 0
    result = {
        "total_simulations": self.total_simulations,
        "total_time": self.total_time,
        "avg_time_per_sim": per_sim,
        "sims_per_second": rate,
        "device": str(self.device),
        "batch_size": self.batch_size
    }
    if self.device.type == "cuda":
        # Report current GPU memory usage in megabytes.
        result["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e6  # MB
        result["gpu_memory_reserved"] = torch.cuda.memory_reserved() / 1e6  # MB
    return result
def set_device(self, device: str):
    """Switch the compute device at runtime.

    When the device actually changes, the batch size is re-derived for
    the new device (any previously set explicit batch size is replaced).
    """
    resolved = self._select_device(device)
    if resolved == self.device:
        return
    self.device = resolved
    self.batch_size = self._select_batch_size(None)
    print(f"设备切换到: {self.device}, 批次大小: {self.batch_size:,}")
def optimize_batch_size(self, game: Game2048, test_simulations: int = 2000) -> int:
    """Benchmark several batch sizes and adopt the fastest one.

    Fix over the previous version: restoring self.batch_size now happens
    in a ``finally`` clause, so a failed run can no longer leave the
    failing size installed for subsequent iterations.

    Args:
        game: game state to benchmark against (copied for every run).
        test_simulations: number of simulations per timed run.

    Returns:
        The fastest batch size found (also set on self.batch_size).
    """
    print("🔧 自动优化批次大小...")
    if self.device.type == "cuda":
        test_sizes = [4096, 8192, 16384, 32768, 65536]
    else:
        test_sizes = [1024, 2048, 4096, 8192]
    best_size = self.batch_size
    best_speed = 0
    for size in test_sizes:
        # Capture the current size outside the try so even a failed run
        # restores it.
        old_size = self.batch_size
        try:
            self.batch_size = size
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            start_time = time.time()
            _, _stats = self.search(game.copy(), test_simulations)
            elapsed_time = time.time() - start_time
            speed = test_simulations / elapsed_time
            print(f" 批次 {size:,}: {speed:.0f} 模拟/秒")
            if speed > best_speed:
                best_speed = speed
                best_size = size
        except Exception as e:
            print(f" 批次 {size:,}: 失败 ({e})")
        finally:
            # Always restore before trying the next candidate size.
            self.batch_size = old_size
    # Install the winner (falls back to the original size if all failed).
    self.batch_size = best_size
    print(f"🎯 最优批次大小: {best_size:,} ({best_speed:.0f} 模拟/秒)")
    return best_size