增加L0训练阶段的MCTS部分
This commit is contained in:
371
game.py
Normal file
371
game.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
2048游戏引擎
|
||||
|
||||
根据论文要求重新设计的2048游戏引擎,包括:
|
||||
1. 正确的累积分数计算
|
||||
2. 棋盘压缩和规范化
|
||||
3. 支持任意大小的矩形棋盘
|
||||
4. 与训练数据模块集成
|
||||
5. 高效的游戏状态管理
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
from typing import Tuple, List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
from training_data import BoardTransform, ScoreCalculator
|
||||
|
||||
|
||||
@dataclass
|
||||
class GameState:
|
||||
"""游戏状态数据结构"""
|
||||
board: np.ndarray # 棋盘状态(对数形式)
|
||||
score: int # 当前累积分数
|
||||
moves: int # 移动次数
|
||||
is_over: bool # 游戏是否结束
|
||||
canonical_hash: str # 规范化哈希值
|
||||
|
||||
|
||||
class Game2048:
|
||||
"""
|
||||
2048游戏引擎
|
||||
|
||||
特点:
|
||||
- 使用对数表示(空位=0, 2=1, 4=2, 8=3, ...)
|
||||
- 正确的累积分数计算
|
||||
- 支持任意大小的矩形棋盘
|
||||
- 棋盘压缩和规范化
|
||||
- 与训练数据模块集成
|
||||
"""
|
||||
|
||||
def __init__(self, height: int = 4, width: int = 4,
|
||||
spawn_prob_4: float = 0.1, seed: Optional[int] = None):
|
||||
"""
|
||||
初始化游戏
|
||||
|
||||
Args:
|
||||
height: 棋盘高度
|
||||
width: 棋盘宽度
|
||||
spawn_prob_4: 生成4的概率(否则生成2)
|
||||
seed: 随机种子
|
||||
"""
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.spawn_prob_4 = spawn_prob_4
|
||||
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
# 初始化棋盘(对数形式)
|
||||
self.board = np.zeros((height, width), dtype=np.int32)
|
||||
self.score = 0
|
||||
self.moves = 0
|
||||
self.is_over = False
|
||||
|
||||
# 工具类
|
||||
self.transform = BoardTransform()
|
||||
self.score_calc = ScoreCalculator()
|
||||
|
||||
# 生成初始数字
|
||||
self._spawn_tile()
|
||||
self._spawn_tile()
|
||||
|
||||
def reset(self) -> GameState:
|
||||
"""重置游戏到初始状态"""
|
||||
self.board = np.zeros((self.height, self.width), dtype=np.int32)
|
||||
self.score = 0
|
||||
self.moves = 0
|
||||
self.is_over = False
|
||||
|
||||
self._spawn_tile()
|
||||
self._spawn_tile()
|
||||
|
||||
return self.get_state()
|
||||
|
||||
def get_state(self) -> GameState:
|
||||
"""获取当前游戏状态"""
|
||||
canonical_hash = self.transform.compute_hash(self.board)
|
||||
|
||||
return GameState(
|
||||
board=self.board.copy(),
|
||||
score=self.score,
|
||||
moves=self.moves,
|
||||
is_over=self.is_over,
|
||||
canonical_hash=canonical_hash
|
||||
)
|
||||
|
||||
def set_state(self, state: GameState) -> None:
|
||||
"""设置游戏状态"""
|
||||
self.board = state.board.copy()
|
||||
self.score = state.score
|
||||
self.moves = state.moves
|
||||
self.is_over = state.is_over
|
||||
|
||||
def _spawn_tile(self) -> bool:
|
||||
"""
|
||||
在随机空位生成新数字
|
||||
|
||||
Returns:
|
||||
是否成功生成(False表示棋盘已满)
|
||||
"""
|
||||
empty_positions = list(zip(*np.where(self.board == 0)))
|
||||
|
||||
if not empty_positions:
|
||||
return False
|
||||
|
||||
# 随机选择空位
|
||||
pos = random.choice(empty_positions)
|
||||
|
||||
# 根据概率生成2或4(对数形式为1或2)
|
||||
if random.random() < self.spawn_prob_4:
|
||||
self.board[pos] = 2 # 4 = 2^2
|
||||
else:
|
||||
self.board[pos] = 1 # 2 = 2^1
|
||||
|
||||
return True
|
||||
|
||||
def get_empty_positions(self) -> List[Tuple[int, int]]:
|
||||
"""获取所有空位置"""
|
||||
return list(zip(*np.where(self.board == 0)))
|
||||
|
||||
def is_full(self) -> bool:
|
||||
"""检查棋盘是否已满"""
|
||||
return len(self.get_empty_positions()) == 0
|
||||
|
||||
def copy(self) -> 'Game2048':
|
||||
"""创建游戏副本"""
|
||||
new_game = Game2048(self.height, self.width, self.spawn_prob_4)
|
||||
new_game.board = self.board.copy()
|
||||
new_game.score = self.score
|
||||
new_game.moves = self.moves
|
||||
new_game.is_over = self.is_over
|
||||
return new_game
|
||||
|
||||
def _move_row_left(self, row: np.ndarray) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
将一行向左移动和合并
|
||||
|
||||
Args:
|
||||
row: 输入行
|
||||
|
||||
Returns:
|
||||
(新行, 本次移动获得的分数)
|
||||
"""
|
||||
# 移除零元素
|
||||
non_zero = row[row != 0]
|
||||
|
||||
if len(non_zero) == 0:
|
||||
return row, 0
|
||||
|
||||
# 合并相邻的相同元素
|
||||
merged = []
|
||||
score_gained = 0
|
||||
i = 0
|
||||
|
||||
while i < len(non_zero):
|
||||
if i < len(non_zero) - 1 and non_zero[i] == non_zero[i + 1]:
|
||||
# 合并
|
||||
new_value = non_zero[i] + 1
|
||||
merged.append(new_value)
|
||||
|
||||
# 计算分数增量(根据论文公式)
|
||||
tile_value = 2 ** new_value
|
||||
score_gained += tile_value
|
||||
|
||||
i += 2 # 跳过下一个元素
|
||||
else:
|
||||
merged.append(non_zero[i])
|
||||
i += 1
|
||||
|
||||
# 补充零元素
|
||||
result = np.zeros(len(row), dtype=np.int32)
|
||||
result[:len(merged)] = merged
|
||||
|
||||
return result, score_gained
|
||||
|
||||
def move(self, direction: int) -> bool:
|
||||
"""
|
||||
执行移动操作
|
||||
|
||||
Args:
|
||||
direction: 移动方向 (0:上, 1:下, 2:左, 3:右)
|
||||
|
||||
Returns:
|
||||
是否成功移动
|
||||
"""
|
||||
if self.is_over:
|
||||
return False
|
||||
|
||||
before = self.board.copy()
|
||||
total_score_gained = 0
|
||||
|
||||
# 根据方向旋转棋盘,统一处理为向左移动
|
||||
if direction == 0: # 上
|
||||
rotated = np.rot90(self.board, k=1)
|
||||
elif direction == 1: # 下
|
||||
rotated = np.rot90(self.board, k=-1)
|
||||
elif direction == 2: # 左
|
||||
rotated = self.board
|
||||
else: # 右
|
||||
rotated = np.rot90(self.board, k=2)
|
||||
|
||||
# 对每一行执行向左移动
|
||||
new_board = np.zeros_like(rotated)
|
||||
for i in range(rotated.shape[0]):
|
||||
new_row, score_gained = self._move_row_left(rotated[i])
|
||||
new_board[i] = new_row
|
||||
total_score_gained += score_gained
|
||||
|
||||
# 旋转回原方向
|
||||
if direction == 0: # 上
|
||||
self.board = np.rot90(new_board, k=-1)
|
||||
elif direction == 1: # 下
|
||||
self.board = np.rot90(new_board, k=1)
|
||||
elif direction == 2: # 左
|
||||
self.board = new_board
|
||||
else: # 右
|
||||
self.board = np.rot90(new_board, k=-2)
|
||||
|
||||
# 检查是否有变化
|
||||
if np.array_equal(before, self.board):
|
||||
return False
|
||||
|
||||
# 更新分数和移动次数
|
||||
self.score += total_score_gained
|
||||
self.moves += 1
|
||||
|
||||
# 生成新数字
|
||||
if not self._spawn_tile():
|
||||
# 如果无法生成新数字,检查游戏是否结束
|
||||
self._check_game_over()
|
||||
|
||||
return True
|
||||
|
||||
def _check_game_over(self) -> None:
|
||||
"""检查游戏是否结束"""
|
||||
# 如果有空位,游戏未结束
|
||||
if not self.is_full():
|
||||
return
|
||||
|
||||
# 检查是否还能移动
|
||||
for direction in range(4):
|
||||
test_game = self.copy()
|
||||
if test_game._can_move(direction):
|
||||
return
|
||||
|
||||
# 无法移动,游戏结束
|
||||
self.is_over = True
|
||||
|
||||
def _can_move(self, direction: int) -> bool:
|
||||
"""检查指定方向是否可以移动(不实际执行移动)"""
|
||||
# 优化:直接检查而不创建副本
|
||||
if direction == 2: # 左
|
||||
board = self.board
|
||||
elif direction == 3: # 右
|
||||
board = np.fliplr(self.board)
|
||||
elif direction == 0: # 上
|
||||
board = self.board.T
|
||||
else: # 下
|
||||
board = np.flipud(self.board.T)
|
||||
|
||||
# 快速检查:对每一行,看是否有空位可以移动或相邻相同数字可以合并
|
||||
for row in board:
|
||||
# 检查是否有空位可以移动
|
||||
non_zero = row[row != 0]
|
||||
if len(non_zero) < len(row) and len(non_zero) > 0:
|
||||
return True
|
||||
|
||||
# 检查是否有相邻相同数字可以合并
|
||||
for j in range(len(non_zero) - 1):
|
||||
if non_zero[j] == non_zero[j + 1] and non_zero[j] != 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_valid_moves(self) -> List[int]:
|
||||
"""获取所有有效的移动方向"""
|
||||
if self.is_over:
|
||||
return []
|
||||
|
||||
# 缓存有效移动以避免重复计算
|
||||
if not hasattr(self, '_cached_valid_moves') or self._cache_board_hash != hash(self.board.tobytes()):
|
||||
valid_moves = []
|
||||
for direction in range(4):
|
||||
if self._can_move(direction):
|
||||
valid_moves.append(direction)
|
||||
|
||||
self._cached_valid_moves = valid_moves
|
||||
self._cache_board_hash = hash(self.board.tobytes())
|
||||
|
||||
return self._cached_valid_moves
|
||||
|
||||
def get_board_display(self) -> np.ndarray:
|
||||
"""获取用于显示的棋盘(原始数值形式)"""
|
||||
return self.transform.inverse_log_transform(self.board)
|
||||
|
||||
def calculate_total_score(self) -> int:
|
||||
"""计算棋盘的总累积分数"""
|
||||
return self.score_calc.calculate_board_score(self.board)
|
||||
|
||||
def get_max_tile(self) -> int:
|
||||
"""获取棋盘上的最大数字"""
|
||||
max_log = np.max(self.board)
|
||||
return 2 ** max_log if max_log > 0 else 0
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
display_board = self.get_board_display()
|
||||
result = f"Score: {self.score}, Moves: {self.moves}, Max: {self.get_max_tile()}\n"
|
||||
result += "+" + "-" * (self.width * 6 - 1) + "+\n"
|
||||
|
||||
for row in display_board:
|
||||
result += "|"
|
||||
for cell in row:
|
||||
if cell == 0:
|
||||
result += f"{'':^5}|"
|
||||
else:
|
||||
result += f"{cell:^5}|"
|
||||
result += "\n"
|
||||
|
||||
result += "+" + "-" * (self.width * 6 - 1) + "+"
|
||||
return result
|
||||
|
||||
|
||||
def demo_game():
|
||||
"""演示游戏功能"""
|
||||
print("2048游戏引擎演示")
|
||||
print("=" * 50)
|
||||
|
||||
# 创建3x3的小棋盘用于演示
|
||||
game = Game2048(height=3, width=3, seed=42)
|
||||
|
||||
print("初始状态:")
|
||||
print(game)
|
||||
print(f"规范哈希: {game.get_state().canonical_hash}")
|
||||
|
||||
# 执行一些移动
|
||||
moves = [2, 0, 1, 3] # 左、上、下、右
|
||||
move_names = ["左", "上", "下", "右"]
|
||||
|
||||
for i, (move, name) in enumerate(zip(moves, move_names)):
|
||||
print(f"\n第{i+1}步: 向{name}移动")
|
||||
|
||||
if game.move(move):
|
||||
print("移动成功!")
|
||||
print(game)
|
||||
print(f"有效移动: {[move_names[m] for m in game.get_valid_moves()]}")
|
||||
else:
|
||||
print("无法移动!")
|
||||
|
||||
if game.is_over:
|
||||
print("游戏结束!")
|
||||
break
|
||||
|
||||
print(f"\n最终分数: {game.score}")
|
||||
print(f"累积分数: {game.calculate_total_score()}")
|
||||
print(f"最大数字: {game.get_max_tile()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_game()
|
||||
Reference in New Issue
Block a user