Compare commits

...

2 Commits

Author SHA1 Message Date
5f542fb577 style: 调整代码风格问题 2026-05-13 18:25:05 +08:00
5d95027894 refactor: 重写沙比 AI 写的训练代码 2026-05-13 18:20:31 +08:00
8 changed files with 998 additions and 1022 deletions

View File

@ -1,30 +0,0 @@
# Ginka 地图生成器 - Copilot 指引
## 项目概述
本项目是一个基于深度学习的二维网格状地图生成模型,用于生成魔塔(Magic Tower)类网页游戏地图。
- **模型结构**VQ-VAE 风格编码器 + MaskGIT 解码器
- VQ-VAE 编码器将完整地图压缩为离散隐变量 z从 codebook 查得)
- MaskGIT 以 z 为条件,通过迭代掩码预测生成地图
- 推理时直接随机采样 z无需用户输入
- **地图规格**13×13 格子7 类图块
- **目录结构**
- `ginka/` — 模型定义与训练脚本Python
- `data/` — 数据预处理TypeScript因游戏是网页游戏
- `docs/` — 设计文档
- `shared/` — 可视化等共享工具
## 重要约束
### 训练
- **不要在当前设备上运行训练**,训练在其他设备上进行
- 可以运行小规模验证、推理或单步测试,但不要触发完整训练流程
### 代码风格
- **Python**:不使用三引号注释(`"""..."""`),一律改用 `#` 注释;不出现连续空格;遵循 Prettier 风格(缩进 4 空格,行宽 88
- **TypeScript**:遵循 Prettier 默认风格
### 验证与可视化
- 编写验证代码时,优先输出可视化结果(图片文件),使用 `shared/image.py` 中的工具
- 验证阶段应对不同条件(不同 z 采样)分别生成图片,便于直观对比模型效果

View File

@ -4,715 +4,187 @@ import torch
import numpy as np
from torch.utils.data import Dataset
def _compute_map_labels(map_2d) -> dict:
"""
2D 地图列表 numpy 数组推算结构标签
JSON 数据缺少 roomCount / highDegBranchCount / outerWall 字段时调用
"""
arr = np.array(map_2d, dtype=np.int64) # [H, W]
H, W = arr.shape
WALL, ENTRY = 1, 5
# outerWall最外圈中 wall+entry 占比 > 90%
border = np.concatenate([arr[0, :], arr[-1, :], arr[1:-1, 0], arr[1:-1, -1]])
total_b = border.size
outer_wall = int(total_b > 0 and np.sum((border == WALL) | (border == ENTRY)) / total_b > 0.9)
# roomCountBFS 统计 floor(0)+resource(3) 连通区域,
# 需满足:总格子 >= 4外接矩形宽 >= 2 且高 >= 2
FLOOR_SET = (0, 3)
visited = np.zeros((H, W), dtype=bool)
room_count = 0
for sy in range(H):
for sx in range(W):
if arr[sy, sx] not in FLOOR_SET or visited[sy, sx]:
continue
queue = [(sy, sx)]
visited[sy, sx] = True
tiles_y, tiles_x = [sy], [sx]
head = 0
while head < len(queue):
y, x = queue[head]; head += 1
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
ny, nx = y + dy, x + dx
if 0 <= ny < H and 0 <= nx < W and not visited[ny, nx] and arr[ny, nx] in FLOOR_SET:
visited[ny, nx] = True
queue.append((ny, nx))
tiles_y.append(ny); tiles_x.append(nx)
if (len(tiles_y) >= 4
and max(tiles_y) - min(tiles_y) >= 1
and max(tiles_x) - min(tiles_x) >= 1):
room_count += 1
# highDegBranchCount非 wall 格子中4 邻域非 wall 邻居 >= 3 的数量
non_wall = (arr != WALL).astype(np.int32)
padded = np.pad(non_wall, 1, mode='constant', constant_values=0)
nbr_sum = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
padded[1:-1, :-2] + padded[1:-1, 2:])
high_deg = int(np.sum((non_wall == 1) & (nbr_sum >= 3)))
return {'outerWall': outer_wall, 'roomCount': room_count, 'highDegBranchCount': high_deg}
def load_data(path: str):
    # Read the dataset JSON file and return its entries as a list.
    # Entries from older dataset versions that lack the structural label
    # fields get them backfilled via _compute_map_labels.
    with open(path, 'r', encoding="utf-8") as fp:
        payload = json.load(fp)
    entries = list(payload["data"].values())
    for entry in entries:
        if 'roomCount' not in entry:
            entry.update(_compute_map_labels(entry['map']))
        # The symmetry field is recomputed after augmentation in
        # __getitem__, so it is intentionally not read from the JSON here.
    return entries
def _compute_symmetry(target_np: np.ndarray) -> tuple:
def compute_symmetry(target_np: np.ndarray) -> tuple:
"""从 numpy 地图矩阵中直接计算三种对称性O(H*W)"""
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
sym_v = bool(np.all(target_np == target_np[::-1, :]))
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
return int(sym_h), int(sym_v), int(sym_c)
class GinkaVQDataset(Dataset):
"""
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集
每次 __getitem__ 按权重随机选取以下四种子集之一
A (standard): 标准 MaskGIT 随机掩码随机遮盖部分 tile
B (wall-only): 仅保留 wall(1) + floor(0)其余全部替换为 MASK(6)
C (wall-random): B 基础上再随机 mask 部分 wall tile
D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(5)其余全部替换为 MASK(6)
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图 VQ-VAE 编码
masked_map: LongTensor [H*W] MaskGIT 输入 mask 的位置 = 6
target_map: LongTensor [H*W] CE loss ground truth等同 raw_map
subset: str 子集标识供调试/统计用
"""
FLOOR = 0
WALL = 1
class GinkaSeperatedDataset(Dataset):
FLOOR = 0
WALL = 1
DOOR = 2
RESOURCE = 3
MONSTER = 4
ENTRANCE = 5
MASK_ID = 6
MASK_ID = 6
def __init__(
self,
data_path: str,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_ratio: float = 0.3,
room_thresholds: tuple = None,
branch_thresholds: tuple = None,
subset_weights: tuple = (0.5, 0.3, 0.2),
subset2_wall_prob: float = 0.7
):
"""
Args:
data_path: JSON 数据文件路径
subset_weights: 子集 (A, B, C, D) 的采样权重自动归一化
wall_mask_ratio: Subset C 中额外随机 mask wall tile 比例上限
每次从 [0, wall_mask_ratio] 均匀采样实际比例
room_thresholds: (th1, th2) 房间数量等频分箱阈值 None 时自动从当前数据计算训练集
branch_thresholds: (th1, th2) 分支数量等频分箱阈值 None 时自动从当前数据计算训练集
"""
self.data = load_data(data_path)
self.wall_mask_ratio = wall_mask_ratio
self.subset2_wall_prob = subset2_wall_prob
total = sum(subset_weights)
self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))]
# 累积权重,用于快速随机子集选择
total_w = sum(subset_weights)
normalized = [x / total_w for x in subset_weights]
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
n = len(self.data)
rs = sorted(item['roomCount'] for item in self.data)
bs = sorted(item['highDegBranchCount'] for item in self.data)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
if th1_r == th2_r: th2_r = th1_r + 1
if th1_b == th2_b: th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
room_counts = [item['roomCount'] for item in self.data]
branch_counts = [item['highDegBranchCount'] for item in self.data]
if room_thresholds is None:
n = len(room_counts)
rs = sorted(room_counts)
bs = sorted(branch_counts)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
# 防止 Medium 等级为空
if th1_r == th2_r:
th2_r = th1_r + 1
if th1_b == th2_b:
th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
else:
self.room_th = room_thresholds
self.branch_th = branch_thresholds
def to_level(v: int, th: tuple) -> int:
return 0 if v < th[0] else (1 if v < th[1] else 2)
# 回填等级字段
for item in self.data:
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th)
item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th)
def to_level(self, v, th):
return 0 if v < th[0] else (1 if v < th[1] else 2)
def __len__(self):
return len(self.data)
# ------------------------------------------------------------------
# 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题)
# ------------------------------------------------------------------
@staticmethod
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
"""用 Beta(2,2) 分布采样掩码比例,集中在中间值。"""
r = np.random.beta(2, 2)
return min_r + (max_r - min_r) * r
    def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
        # Replace every tile whose ID is in `tiles` with floor(0).
        # NOTE: mutates `m` in place; the return value is the same array.
        for t in tiles:
            m[m == t] = self.FLOOR
        return m
@staticmethod
def _random_mask(h: int, w: int) -> np.ndarray:
"""纯随机掩码,返回 [H*W] bool。"""
ratio = GinkaVQDataset._sample_mask_ratio()
total = h * w
idx = np.random.choice(total, int(total * ratio), replace=False)
mask = np.zeros(total, dtype=bool)
mask[idx] = True
return mask
@staticmethod
def _block_mask(h: int, w: int) -> np.ndarray:
"""矩形分块随机掩码,返回 [H*W] bool。"""
ratio = GinkaVQDataset._sample_mask_ratio()
max_block = max(2, min(h, w) // 2)
target = int(h * w * ratio)
mask = np.zeros((h, w), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, max_block + 1)
bw = np.random.randint(2, max_block + 1)
x = np.random.randint(0, max(1, h - bh + 1))
y = np.random.randint(0, max(1, w - bw + 1))
mask[x:x + bh, y:y + bw] = True
return mask.reshape(-1)
def _std_mask(self, h: int, w: int) -> np.ndarray:
"""标准 MaskGIT 掩码:随机选择纯随机或分块策略。"""
def std_mask(self) -> np.ndarray:
# Beta(2,2) 采样掩码比例50% 随机掩码 / 50% 分块掩码,返回 bool[13, 13]
ratio = float(np.random.beta(2, 2)) * 0.95 + 0.05
if random.random() < 0.5:
return self._random_mask(h, w)
else:
return self._block_mask(h, w)
idx = np.random.choice(169, int(169 * ratio), replace=False)
mask = np.zeros(169, dtype=bool)
mask[idx] = True
return mask.reshape(13, 13)
target = int(169 * ratio)
mask = np.zeros((13, 13), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, 7)
bw = np.random.randint(2, 7)
x = np.random.randint(0, 14 - bh)
y = np.random.randint(0, 14 - bw)
mask[x:x + bh, y:y + bw] = True
return mask
    def create_degreaded(self, raw: np.ndarray):
        # Build the per-stage (target, input) map pairs for cascaded
        # three-stage generation. All arrays are copies; `raw` is untouched.
        # Stage 1: generate walls and the entrance
        target1 = raw.copy()
        self.degrade_tile(target1, [self.DOOR, self.RESOURCE, self.MONSTER])
        inp1 = target1.copy()
        # Stage 2: generate monsters and doors; entrances may also be
        # generated so the structure can be adapted
        target2 = raw.copy()
        self.degrade_tile(target2, [self.RESOURCE, self.WALL])
        inp2 = raw.copy()
        self.degrade_tile(inp2, [self.RESOURCE])
        # Stage 3: generate resources
        target3 = raw.copy()
        self.degrade_tile(target3, [self.WALL, self.DOOR, self.MONSTER, self.ENTRANCE])
        inp3 = raw.copy()
        return target1, inp1, target2, inp2, target3, inp3
# ------------------------------------------------------------------
    def apply_subset1(self, raw: np.ndarray):
        # Subset 1: standard random masking via std_mask.
        # Returns (input, target, encoder-input) triples for the three stages.
        target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
        enc1 = target1.copy()
        enc2 = inp2.copy()
        enc3 = raw.copy()
        # stage 1: std_mask over the whole map
        inp1[self.std_mask()] = self.MASK_ID
        # stage 2: std_mask restricted to floor + functional tiles
        need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE])
        inp2[need_mask & self.std_mask()] = self.MASK_ID
        # stage 3: std_mask restricted to floor + resource tiles
        need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE])
        inp3[need_mask & self.std_mask()] = self.MASK_ID
        return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
    def apply_subset2(self, raw: np.ndarray):
        # Subset 2: mask all content; the stage-1 input (walls) is only
        # masked with probability subset2_wall_prob, and the entrance is
        # never masked here.
        target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
        enc1 = target1.copy()
        enc2 = inp2.copy()
        enc3 = raw.copy()
        if np.random.random() < self.subset2_wall_prob:
            inp1[self.std_mask()] = self.MASK_ID
        need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE])
        inp2[need_mask] = self.MASK_ID
        need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE])
        inp3[need_mask] = self.MASK_ID
        return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
    def apply_subset3(self, raw: np.ndarray):
        # Subset 3: same as subset 2, but additionally mask the entrance
        # tiles in the stage-1 input.
        out = self.apply_subset2(raw)
        out[0][out[0] == self.ENTRANCE] = self.MASK_ID
        return out
def __getitem__(self, idx):
item = self.data[idx]
map_np = np.array(item['map'], dtype=np.int64)
def _augment(self, arr: np.ndarray) -> np.ndarray:
"""随机旋转 / 翻转数据增强,返回新 array。"""
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
map_np = np.rot90(map_np, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
map_np = np.fliplr(map_np).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
return arr
map_np = np.flipud(map_np).copy()
def _choose_subset(self) -> str:
r = random.random()
if r < self.subset_cumw[0]:
return 'A'
out = self.apply_subset1(map_np)
elif r < self.subset_cumw[1]:
return 'B'
elif r < self.subset_cumw[2]:
return 'C'
out = self.apply_subset2(map_np)
else:
return 'D'
out = self.apply_subset3(map_np)
def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray:
"""
根据子集类型生成 masked_map
Args:
raw: [H, W] int64 完整原始地图
subset: 'A' | 'B' | 'C' | 'D'
Returns:
[H*W] int64被遮盖位置値为 MASK_ID(6)
"""
H, W = raw.shape
if subset == 'A':
# 标准随机 mask纯随机或分块策略
mask = self._std_mask(H, W) # [H*W] bool
flat = raw.reshape(-1).copy()
flat[mask] = self.MASK_ID
return flat
elif subset == 'B':
# 仅保留 wall(1)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy()
keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID
return flat
elif subset == 'C':
# Subset B + 随机 mask 部分 wall
flat = raw.reshape(-1).copy()
keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat
else: # D
# 仅保留 wall(1) 和 entrance(5)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy()
keep = (flat == self.WALL) | (flat == self.ENTRANCE)
flat[~keep] = self.MASK_ID
# 随机 mask 部分 wall模拟真实场景与子集 C 一致)
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat
def __getitem__(self, idx):
item = self.data[idx]
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
subset = self._choose_subset()
masked_np = self._apply_subset(raw_np, subset) # [H*W]
raw_flat = raw_np.reshape(-1) # [H*W]
# 对称性:在增强后重新计算
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
# 其余结构标签:增强不改变拓扑结构,直接读取
cond_room = item['roomCountLevel'] # 0/1/2
cond_branch = item['branchLevel'] # 0/1/2
cond_outer = item['outerWall'] # 0/1
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
raw_t = torch.LongTensor(raw_flat)
return {
"raw_map": raw_t, # VQ-VAE 编码器输入
"slice1": make_slice(raw_t, {0, 1}), # 通道 1floor+wall
"slice2": make_slice(raw_t, {0, 1, 2, 4, 5}), # 通道 2floor+wall+门+怪+入口
"slice3": raw_t.clone(), # 通道 3完整地图
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
"subset": subset, # 调试/统计用
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
}
# ---------------------------------------------------------------------------
# make_slice按保留集合切割地图其余位置替换为 floor(0)
# ---------------------------------------------------------------------------
def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor:
    # Keep only the tile types listed in keep_set; every other position is
    # replaced by floor(0).
    #   map_flat: LongTensor [H*W] of tile IDs (the full map)
    #   keep_set: set of tile IDs to preserve
    # Returns a new LongTensor [H*W]; the input tensor is left untouched.
    sliced = map_flat.clone()
    drop = torch.ones_like(sliced, dtype=torch.bool)
    for tile_id in keep_set:
        drop &= sliced != tile_id
    sliced[drop] = 0
    return sliced
# ---------------------------------------------------------------------------
# GinkaSplitDataset三通道分拆预训练专用数据集
# ---------------------------------------------------------------------------
class GinkaSplitDataset(Dataset):
    # Dataset dedicated to the three-channel split pre-training scheme
    # (plan B). Each sample provides the full map plus its three cumulative
    # slices; no MaskGIT masking is applied here.
    #   slice1 = floor(0) + wall(1)
    #   slice2 = floor(0) + wall(1) + door(2) + mob(4) + entrance(5)
    #   slice3 = the full map (all tiles)
    # __getitem__ returns a dict of LongTensor [H*W]:
    #   raw_map / slice1 / slice2 / slice3

    def __init__(self, data_path: str):
        self.data = load_data(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        grid = np.array(self.data[idx]['map'], dtype=np.int64)  # [H, W]
        # random rotation / flip augmentation
        if np.random.rand() > 0.5:
            k = np.random.randint(1, 4)
            grid = np.rot90(grid, k).copy()
        if np.random.rand() > 0.5:
            grid = np.fliplr(grid).copy()
        if np.random.rand() > 0.5:
            grid = np.flipud(grid).copy()
        flat = torch.LongTensor(grid.reshape(-1))  # [H*W]
        return {
            "raw_map": flat,
            "slice1": make_slice(flat, {0, 1}),
            "slice2": make_slice(flat, {0, 1, 2, 4, 5}),
            "slice3": flat.clone(),
        }
# ---------------------------------------------------------------------------
# GinkaStageDataset三阶段级联训练专用数据集
# ---------------------------------------------------------------------------
class GinkaStageDataset(Dataset):
"""
三阶段级联生成训练专用 Dataset
每个阶段只预测特定类别的 tile后续阶段以前序阶段输出作为上下文
训练时统一使用 GT 作为前序上下文teacher forcing避免误差级联
阶段划分
stage=1 结构骨架预测 floor(0) + wall(1)
stage=2 功能元素预测 door(2) + monster(4) + entrance(5) floor/wall 为上下文
stage=3 资源放置预测 resource(3)以完整骨架为上下文
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图 VQ-VAE 编码
vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片
stage_input: LongTensor [H*W] MaskGIT 输入含上下文 + MASK 位置
target_map: LongTensor [H*W] CE loss ground truth
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
subset: str 子集标识 A/B/C/D
struct_cond: LongTensor [4] [sym, room, branch, outer]
"""
FLOOR = 0
WALL = 1
DOOR = 2
RESOURCE = 3
MONSTER = 4
ENTRANCE = 5
MASK_ID = 6
STAGE1_TARGETS = frozenset({0, 1})
STAGE2_TARGETS = frozenset({2, 4, 5})
STAGE3_TARGETS = frozenset({3})
# VQ 切片集合:各阶段编码器只"看"与自身相关的 tile
_VQ_KEEP = {
1: frozenset({0, 1}),
2: frozenset({0, 1, 2, 4, 5}),
3: None, # 完整地图
}
    def __init__(
        self,
        data_path: str,
        stage: int,
        subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
        wall_mask_ratio: float = 0.3,
        room_thresholds: tuple = None,
        branch_thresholds: tuple = None,
    ):
        """
        Args:
            data_path: path to the JSON dataset file
            stage: generation stage, 1/2/3
            subset_weights: sampling weights for subsets (A, B, C, D),
                normalized automatically
            wall_mask_ratio: upper bound on the extra wall-mask ratio used
                by subset C
            room_thresholds: equal-frequency binning thresholds; computed
                automatically from the data when None
            branch_thresholds: equal-frequency binning thresholds; computed
                automatically from the data when None
        """
        assert stage in (1, 2, 3), f"stage 必须是 1/2/3收到 {stage}"
        self.stage = stage
        self.data = load_data(data_path)
        self.wall_mask_ratio = wall_mask_ratio
        # cumulative weights for fast random subset selection
        total_w = sum(subset_weights)
        normalized = [x / total_w for x in subset_weights]
        self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
        room_counts = [item['roomCount'] for item in self.data]
        branch_counts = [item['highDegBranchCount'] for item in self.data]
        if room_thresholds is None:
            # equal-frequency binning of room/branch counts into 3 levels
            n = len(room_counts)
            rs = sorted(room_counts)
            bs = sorted(branch_counts)
            th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
            th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
            # keep the Medium level non-empty
            if th1_r == th2_r: th2_r = th1_r + 1
            if th1_b == th2_b: th2_b = th1_b + 1
            self.room_th = (th1_r, th2_r)
            self.branch_th = (th1_b, th2_b)
        else:
            self.room_th = room_thresholds
            self.branch_th = branch_thresholds
        def to_level(v, th):
            # map a raw count to level 0/1/2 using the two thresholds
            return 0 if v < th[0] else (1 if v < th[1] else 2)
        # backfill the level fields on every entry
        for item in self.data:
            item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
            item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
    def __len__(self):
        # number of samples in the dataset
        return len(self.data)
# ------------------------------------------------------------------
# 掩码辅助(与 GinkaVQDataset 相同逻辑)
# ------------------------------------------------------------------
    @staticmethod
    def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
        # Sample a mask ratio from Beta(2,2), concentrated around the middle.
        r = np.random.beta(2, 2)
        return min_r + (max_r - min_r) * r
    @staticmethod
    def _random_mask(h: int, w: int) -> np.ndarray:
        # Purely random mask; returns a flat [H*W] bool array.
        ratio = GinkaStageDataset._sample_mask_ratio()
        total = h * w
        idx = np.random.choice(total, int(total * ratio), replace=False)
        mask = np.zeros(total, dtype=bool)
        mask[idx] = True
        return mask
    @staticmethod
    def _block_mask(h: int, w: int) -> np.ndarray:
        # Rectangular-block random mask; returns a flat [H*W] bool array.
        ratio = GinkaStageDataset._sample_mask_ratio()
        max_block = max(2, min(h, w) // 2)
        target = int(h * w * ratio)
        mask = np.zeros((h, w), dtype=bool)
        # stamp random rectangles until enough cells are covered
        while mask.sum() < target:
            bh = np.random.randint(2, max_block + 1)
            bw = np.random.randint(2, max_block + 1)
            x = np.random.randint(0, max(1, h - bh + 1))
            y = np.random.randint(0, max(1, w - bw + 1))
            mask[x:x + bh, y:y + bw] = True
        return mask.reshape(-1)
    def _std_mask(self, h: int, w: int) -> np.ndarray:
        # Standard MaskGIT mask: 50/50 choice of random vs block strategy.
        return self._random_mask(h, w) if random.random() < 0.5 else self._block_mask(h, w)
# ------------------------------------------------------------------
# 子集选择
# ------------------------------------------------------------------
    def _choose_subset(self) -> str:
        # Sample a subset ID from the cumulative weight table.
        r = random.random()
        if r < self.subset_cumw[0]: return 'A'
        if r < self.subset_cumw[1]: return 'B'
        if r < self.subset_cumw[2]: return 'C'
        return 'D'
# ------------------------------------------------------------------
# 阶段一结构骨架floor + wall
# ------------------------------------------------------------------
    def _make_stage1(self, raw_flat: np.ndarray, subset: str):
        """
        Stage 1: predict floor/wall. Every non-floor/wall tile is remapped
        to floor in the target; the subset decides how much wall context is
        provided to the model. Returns (input, target, loss_mask).
        """
        H = W = 13
        # target: non floor/wall -> floor
        target = raw_flat.copy()
        target[~np.isin(target, [self.FLOOR, self.WALL])] = self.FLOOR
        inp = target.copy()
        if subset == 'A':
            # standard random mask: hide a random part of floor/wall
            mask = self._std_mask(H, W)
            inp[mask] = self.MASK_ID
        elif subset == 'B':
            # keep every wall, MASK all floor
            inp[inp == self.FLOOR] = self.MASK_ID
        elif subset == 'C':
            # randomly keep only part of the walls, MASK the rest
            # (including all floor)
            inp[inp == self.FLOOR] = self.MASK_ID
            wall_idx = np.where(inp == self.WALL)[0]
            if len(wall_idx) > 0:
                ratio = random.random() * self.wall_mask_ratio
                n = max(1, int(len(wall_idx) * ratio))
                chosen = np.random.choice(wall_idx, n, replace=False)
                inp[chosen] = self.MASK_ID
        else:  # D: same as B (stage 1 has no entrance dimension)
            inp[inp == self.FLOOR] = self.MASK_ID
        loss_mask = (inp == self.MASK_ID)
        return inp, target, loss_mask
# ------------------------------------------------------------------
# 阶段二功能元素door + monster + entrance
# ------------------------------------------------------------------
    def _make_stage2(self, raw_flat: np.ndarray, subset: str):
        """
        Stage 2: with floor/wall as context, predict door/monster/entrance.
        Resource tiles count as floor in both input and target (stage 2 is
        not responsible for resources). The subset controls the completeness
        of the wall context and how the functional tiles are masked.
        """
        # target: resource -> floor
        target = raw_flat.copy()
        target[target == self.RESOURCE] = self.FLOOR
        # base input: resource -> floor; functional tiles kept for now and
        # masked below according to the subset
        inp = raw_flat.copy()
        inp[inp == self.RESOURCE] = self.FLOOR
        if subset == 'A':
            # hide a random part of door/monster/entrance (partial-context
            # completion)
            func_idx = np.where(np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE]))[0]
            if len(func_idx) > 0:
                ratio = random.random() * 0.8 + 0.2  # 20%~100%
                n = max(1, int(len(func_idx) * ratio))
                chosen = np.random.choice(func_idx, n, replace=False)
                inp[chosen] = self.MASK_ID
        else:
            # B/C/D: MASK every door/monster/entrance
            inp[np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE])] = self.MASK_ID
        if subset == 'C':
            # additionally mask part of the walls (degrades the quality of
            # the wall context)
            wall_idx = np.where(inp == self.WALL)[0]
            if len(wall_idx) > 0:
                ratio = random.random() * self.wall_mask_ratio
                n = max(1, int(len(wall_idx) * ratio))
                chosen = np.random.choice(wall_idx, n, replace=False)
                inp[chosen] = self.MASK_ID
        # loss_mask: stage 2 only scores the original door/monster/entrance
        # positions; extra masked wall positions are excluded (their target
        # value is already known to be wall)
        loss_mask = np.isin(raw_flat, [self.DOOR, self.MONSTER, self.ENTRANCE])
        return inp, target, loss_mask
# ------------------------------------------------------------------
# 阶段三资源放置resource
# ------------------------------------------------------------------
    def _make_stage3(self, raw_flat: np.ndarray, subset: str):
        """
        Stage 3: with the full skeleton as context, predict resource
        positions. All resources are replaced by MASK in the input; subset A
        randomly keeps part of them as context (partial completion), the
        other subsets always MASK every resource.
        """
        target = raw_flat.copy()
        inp = raw_flat.copy()
        if subset == 'A':
            # hide a random part of the resources (partial-context completion)
            res_idx = np.where(inp == self.RESOURCE)[0]
            if len(res_idx) > 0:
                ratio = random.random() * 0.8 + 0.2  # 20%~100%
                n = max(1, int(len(res_idx) * ratio))
                chosen = np.random.choice(res_idx, n, replace=False)
                inp[chosen] = self.MASK_ID
            else:
                pass  # no resources on this map, nothing to do
        else:
            # B/C/D: MASK every resource
            inp[inp == self.RESOURCE] = self.MASK_ID
        loss_mask = (inp == self.MASK_ID)
        return inp, target, loss_mask
# ------------------------------------------------------------------
# __getitem__
# ------------------------------------------------------------------
    def _augment(self, arr: np.ndarray) -> np.ndarray:
        # Random rotation / flip data augmentation; returns a new array.
        if np.random.rand() > 0.5:
            k = np.random.randint(1, 4)
            arr = np.rot90(arr, k).copy()
        if np.random.rand() > 0.5:
            arr = np.fliplr(arr).copy()
        if np.random.rand() > 0.5:
            arr = np.flipud(arr).copy()
        return arr
def __getitem__(self, idx):
item = self.data[idx]
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
raw_flat = raw_np.reshape(-1) # [H*W]
subset = self._choose_subset()
if self.stage == 1:
stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset)
elif self.stage == 2:
stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset)
else:
stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset)
# 若 loss_mask 全为 False如地图中无 resource 时的 stage3
# 退回为全图损失,避免 NaN
if not loss_mask_np.any():
loss_mask_np = np.ones_like(loss_mask_np)
# VQ 切片:当前阶段编码器的输入(仅保留相关 tile
raw_t = torch.LongTensor(raw_flat)
vq_keep = self._VQ_KEEP[self.stage]
if vq_keep is None:
vq_slice = raw_t.clone()
else:
vq_slice = make_slice(raw_t, vq_keep)
# 结构标签
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
cond_room = item['roomCountLevel']
sym_h, sym_v, sym_c = compute_symmetry(map_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
cond_room = item['roomCountLevel']
cond_branch = item['branchLevel']
cond_outer = item['outerWall']
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
cond_outer = item['outerWall']
struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
return {
"raw_map": raw_t,
"vq_slice": vq_slice,
"stage_input": torch.LongTensor(stage_input_np),
"target_map": torch.LongTensor(target_np),
"loss_mask": torch.BoolTensor(loss_mask_np),
"subset": subset,
"struct_cond": struct_cond,
"input_stage1": torch.LongTensor(out[0]),
"target_stage1": torch.LongTensor(out[1]),
"encoder_stage1": torch.LongTensor(out[2]),
"input_stage2": torch.LongTensor(out[3]),
"target_stage2": torch.LongTensor(out[4]),
"encoder_stage2": torch.LongTensor(out[5]),
"input_stage3": torch.LongTensor(out[6]),
"target_stage3": torch.LongTensor(out[7]),
"encoder_stage3": torch.LongTensor(out[8]),
"struct_inject": struct_inject
}
if __name__ == "__main__":
import os
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
ds = GinkaVQDataset(data_path)
print(f"数据集大小: {len(ds)}")
subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
for i in range(200):
sample = ds[i % len(ds)]
subset_count[sample['subset']] += 1
raw = sample['raw_map']
masked = sample['masked_map']
target = sample['target_map']
print(f"raw_map shape={raw.shape}, dtype={raw.dtype}")
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
print(f"target_map shape={target.shape}, dtype={target.dtype}")
print(f"被 mask 的位置数: {(masked == 6).sum().item()} / {masked.numel()}")
print(f"\n200 次采样子集分布: {subset_count}")

View File

@ -4,53 +4,28 @@ import torch.nn as nn
from ..utils import print_memory
from .maskGIT import Transformer
# 结构标签词表大小(最后一个索引为无条件占位符 null
SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-67 = null
ROOM_VOCAB = 4 # roomCountLevel 0-23 = null
BRANCH_VOCAB = 4 # branchLevel 0-23 = null
OUTER_VOCAB = 3 # outerWall 0-12 = null
# 结构标签词表大小
SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-7
ROOM_VOCAB = 3 # roomCountLevel 0-2
BRANCH_VOCAB = 3 # branchLevel 0-2
OUTER_VOCAB = 2 # outerWall 0-1
class GinkaMaskGIT(nn.Module):
"""
改造后的 MaskGIT 地图生成模型
以掩码地图序列和 VQ-VAE 输出的离散隐变量 z 为输入
通过 Transformer encoder-decoder 结构预测被遮盖位置的 tile 类别
z 通过 cross-attention 注入到 Transformer decoder
作为风格/多样性控制信号而非结构重建指导
"""
def __init__(
self,
num_classes: int = 16,
d_model: int = 192,
d_z: int = 64,
dim_ff: int = 512,
nhead: int = 8,
num_layers: int = 4,
map_size: int = 13 * 13,
z_dropout: float = 0.1,
struct_dropout: float = 0.15,
self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512,
nhead: int = 8, num_layers: int = 4, map_size: int = 13 * 13, d_z: int = 64
):
"""
Args:
num_classes: tile 类别数 MASK token=15
d_model: Transformer 内部维度
d_z: VQ-VAE 码字嵌入维度需与 GinkaVQVAE.d_z 一致
dim_ff: 前馈网络隐层维度
nhead: 注意力头数
num_layers: Transformer 层数
map_size: 地图 token 总数H * W
z_dropout: 训练时随机替换 z 为随机码字的概率提升鲁棒性
struct_dropout: 训练时以此概率将结构标签替换为 null无条件占位
实现 classifier-free guidance 兼容训练
"""
super().__init__()
self.z_dropout = z_dropout
self.struct_dropout_prob = struct_dropout
# Tile 嵌入 + 位置编码
self.tile_embedding = nn.Embedding(num_classes, d_model)
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
@ -67,10 +42,10 @@ class GinkaMaskGIT(nn.Module):
# 结构标签嵌入(编码到 d_z 维度)
# 注意:结构标签与 VQ 码字语义不同,使用独立投影层避免混用
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
self.struct_proj = nn.Sequential(
nn.Linear(d_z, d_model * 2),
@ -92,108 +67,75 @@ class GinkaMaskGIT(nn.Module):
self,
map: torch.Tensor,
z: torch.Tensor,
struct_cond: torch.Tensor | None = None,
dropout_struct: bool = False,
struct: torch.Tensor
) -> torch.Tensor:
"""
Args:
map: [B, H*W] 掩码后的地图 token 序列MASK token = 15
z: [B, L_total, d_z] VQ-VAE 量化后的离散隐变量
方案 B L_total = L1+L2+L3三路 z 拼接
struct_cond: [B, 4] 结构标签 LongTensor顺序为
[cond_sym, cond_room, cond_branch, cond_outer]
None 时等价于全 null无条件模式
dropout_struct: bool 强制将所有结构标签替换为 null推理时无条件生成
# map: [B, H * W]
# z: [B, L * 3, d_z]
# struch: [B, 4]
Returns:
logits: [B, H*W, num_classes]
"""
B = z.shape[0]
# z dropout训练时以一定概率将 z 替换为随机均匀噪声,
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
if self.training and self.z_dropout > 0:
mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout
rand_z = torch.randn_like(z)
z = torch.where(mask, rand_z, z)
# 结构标签嵌入
# struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引
if struct_cond is None or dropout_struct:
sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device)
room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device)
branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device)
outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device)
else:
sc = struct_cond.to(z.device)
sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3]
# 训练时对各标签独立做 struct dropout
if self.training and self.struct_dropout_prob > 0:
def _drop(idx, null_val):
drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob
return torch.where(drop_mask, torch.full_like(idx, null_val), idx)
sym_idx = _drop(sym_idx, SYM_VOCAB - 1)
room_idx = _drop(room_idx, ROOM_VOCAB - 1)
branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1)
outer_idx = _drop(outer_idx, OUTER_VOCAB - 1)
sym_idx = struct[:, 0]
room_idx = struct[:, 1]
branch_idx = struct[:, 2]
outer_idx = struct[:, 3]
# 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z]
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
# VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接
z_mem_vq = self.z_proj(z) # [B, L, d_model]
z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model]
z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L+4, d_model]
z_mem_vq = self.z_proj(z) # [B, L, d_model]
z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model]
z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L * 3 + 4, d_model]
# tile embedding + 位置编码
x = self.tile_embedding(map) # [B, H*W, d_model]
x = x + self.pos_embedding # [B, H*W, d_model]
x = self.tile_embedding(map) # [B, H * W, d_model]
x = x + self.pos_embedding # [B, H * W, d_model]
# Transformerencoder 做 map 自注意力decoder cross-attend z+struct
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
x = self.transformer(x, memory=z_mem) # [B, H * W, d_model]
logits = self.output_fc(x) # [B, H*W, num_classes]
logits = self.output_fc(x) # [B, H * W, num_classes]
return logits
if __name__ == "__main__":
device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [4, 169]
z_input = torch.randn(4, 2, 64).to(device) # [4, 2, 64]
struct_input = torch.tensor([
[3, 1, 0, 1],
[0, 2, 1, 0],
[5, 1, 2, 1],
[1, 0, 1, 0],
], dtype=torch.long).to(device) # [4, 4]
model = GinkaMaskGIT(
num_classes=16,
num_classes=7,
d_model=192,
d_z=64,
dim_ff=512,
dim_ff=2048,
nhead=8,
num_layers=4,
num_layers=6,
map_size=13 * 13,
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
for name, module in model.named_children():
n = sum(p.numel() for p in module.parameters())
print(f" {name}: {n:,}")
print_memory(device, "初始化后")
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
struct_input = torch.tensor([[3, 1, 0, 1],
[0, 2, 1, 0],
[7, 3, 3, 2],
[1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4]
model.train()
logits = model(map_input, z_input, struct_cond=struct_input)
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
# 无条件模式测试
logits_uncond = model(map_input, z_input, struct_cond=None)
print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16]
start = time.perf_counter()
logits = model(map_input, z_input, struct_input)
end = time.perf_counter()
print_memory(device, "前向传播后")
print(f"推理耗时: {end - start:.4f}s")
print(f"输出形状: logits={logits.shape}")
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}")
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

666
ginka/train_seperated.py Normal file
View File

@ -0,0 +1,666 @@
import argparse
import math
import os
import sys
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from .vqvae.quantize import VectorQuantizer
from .vqvae.model import GinkaVQVAE
from .maskGIT.model import GinkaMaskGIT
from .dataset import GinkaSeperatedDataset
from shared.image import matrix_to_image_cv
# Three-stage cascaded map-generation training script.
#
# Overall architecture:
#   VQ-VAE: three independent encoders (vq1/vq2/vq3) encode the three stage
#   map contexts into continuous latents, which a shared VectorQuantizer
#   turns into discrete z_q codes.
#   Three independent MaskGITs (mg1/mg2/mg3) iteratively decode the tile
#   sequence stage by stage, conditioned on z_q and struct_inject.
#
# Stage targets:
#   stage1 -> floor / wall (map skeleton)
#   stage2 -> door / monster / entrance (functional entities)
#   stage3 -> resource (pickups)
# Tile IDs:
#   0 floor, 1 wall, 2 door, 3 resource, 4 monster, 5 entrance, 6 MASK_TOKEN
# Shared VQ-VAE hyper-parameters.
# The three encoders (vq1/vq2/vq3) use identical settings and each encodes
# one stage context independently.
VQ_L = 2 # codeword sequence length per encoder (quantized total is L*3)
VQ_K = 8 # codebook size (number of discrete code entries)
VQ_D_Z = 64 # codeword dimension
VQ_BETA = 0.5 # commit loss weight (keeps encoder outputs near the codebook)
VQ_GAMMA = 0.0 # entropy loss weight (currently unused)
VQ_LAYERS = 3 # VQ-VAE transformer layer count
VQ_DIM_FF = 512 # VQ-VAE feed-forward hidden dimension
VQ_D_MODEL = 64 # VQ-VAE transformer model dimension
VQ_NHEAD = 8 # VQ-VAE attention head count
# Stage-1 MaskGIT hyper-parameters
STAGE1_MG_DMODEL = 192
STAGE1_MG_NHEAD = 8
STAGE1_MG_NUM_LAYERS = 6
STAGE1_MG_DIM_FF = 1024
# Stage-2 MaskGIT hyper-parameters
STAGE2_MG_DMODEL = 192
STAGE2_MG_NHEAD = 8
STAGE2_MG_NUM_LAYERS = 6
STAGE2_MG_DIM_FF = 1024
# Stage-3 MaskGIT hyper-parameters
STAGE3_MG_DMODEL = 192
STAGE3_MG_NHEAD = 8
STAGE3_MG_NUM_LAYERS = 6
STAGE3_MG_DIM_FF = 1024
# Per-stage focal-loss weights (tune each stage's share of the total loss)
STAGE1_FOCAL_WEIGHT = 1.0
STAGE2_FOCAL_WEIGHT = 1.0
STAGE3_FOCAL_WEIGHT = 1.0
# Per-stage VQ commit-loss weights (currently unused; VQ_BETA is global)
STAGE1_VQ_WEIGHT = 0.5
STAGE2_VQ_WEIGHT = 0.5
STAGE3_VQ_WEIGHT = 0.5
# Global parameters
NUM_CLASSES = 7 # number of tile classes
MASK_TOKEN = 6 # mask tile ID
MAP_W = 13 # map width
MAP_H = 13 # map height
MAP_SIZE = MAP_W * MAP_H # total tile count
GENERATE_STEP = 18 # MaskGIT sampling step count
SUBSET2_WALL_PROB = 0.7 # probability of wall masking for subset 2
SUBSET_WEIGHTS = (0.5, 0.3, 0.2) # sampling probability of each subset
MG_Z_DROPOUT = 0.1 # latent-z dropout probability
MG_STRUCT_DROPOUT = 0.1 # struct-condition dropout probability
# Loss parameters
FOCAL_GAMMA = 2.0 # focal loss focusing parameter
# Fix: the duplicate "VQ_BETA = 0.5" that was re-declared here has been
# removed; the single definition in the VQ-VAE section above is authoritative.
# Training hyper-parameters
BATCH_SIZE = 64 # samples per batch
LR = 1e-4 # AdamW initial learning rate
MIN_LR = 1e-6 # cosine-annealing floor learning rate
WEIGHT_DECAY = 1e-4 # L2 regularization coefficient
EPOCHS = 400 # total training epochs
CHECKPOINT = 20 # save a checkpoint and validate every N epochs
device = torch.device(
    "cuda:1" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
disable_tqdm = not sys.stdout.isatty()
def _str2bool(v: str):
    # Convert a CLI flag value into a boolean; accepts common spellings.
    # Raises argparse.ArgumentTypeError for anything unrecognized.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('true', '1', 'yes'):
        return True
    if lowered in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError(f"布尔值应为 True/False收到: {v!r}")
def parse_arguments():
    # Build and parse the command-line options for the cascaded training run.
    parser = argparse.ArgumentParser(description="三阶段级联训练")
    specs = [
        ("--resume", _str2bool, False, None),
        ("--state", str, "", "续训时检查点路径"),
        ("--train", str, "ginka-dataset.json", None),
        ("--validate", str, "ginka-eval.json", None),
        ("--load_optim", _str2bool, True, None)
    ]
    # help=None is argparse's default, so unhinted options behave identically.
    for flag, converter, default, hint in specs:
        parser.add_argument(flag, type=converter, default=default, help=hint)
    return parser.parse_args()
def build_model(device: torch.device):
    # Construct the full three-stage model set and its training state.
    # Returns (vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler).
    # Three VQ-VAE encoders with identical hyper-parameters; each encodes one
    # stage's map context into [B, L, d_z] latents that are concatenated and
    # quantized by the shared quantizer.
    vq_kwargs = dict(
        num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
        nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE
    )
    vq1 = GinkaVQVAE(**vq_kwargs).to(device) # stage1 context (floor/wall)
    vq2 = GinkaVQVAE(**vq_kwargs).to(device) # stage2 context (door/monster/entrance)
    vq3 = GinkaVQVAE(**vq_kwargs).to(device) # stage3 context (resource)
    # Three independent MaskGIT decoders; each receives the full三-stage z_q
    # (all three stages' codes) as its condition.
    mg1 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
        nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    mg2 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
        nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    mg3 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
        nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    # All six models share one optimizer for end-to-end joint training.
    all_params = (
        list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
        list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters())
    )
    # Fix: use the WEIGHT_DECAY constant instead of a hard-coded 1e-4 literal
    # so the hyper-parameter block at the top of the file stays authoritative.
    optimizer = optim.AdamW(all_params, lr=LR, weight_decay=WEIGHT_DECAY)
    # Cosine annealing from LR down to MIN_LR over the whole training run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
    # Shared VectorQuantizer: pure codebook lookup in the forward pass.
    quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z)
    return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
def focal_loss(logits, target):
    # Focal loss over per-tile class logits.
    # logits: [B, L, C]; target: [B, L] integer class IDs.
    # cross_entropy expects class dim second, hence the permute to [B, C, L].
    per_token_ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none')
    # Probability the model assigned to the correct class at each position.
    correct_prob = torch.exp(-per_token_ce)
    # Down-weight confident positions so training focuses on hard tiles.
    weighted = per_token_ce * (1 - correct_prob) ** FOCAL_GAMMA
    return weighted.mean()
def random_struct(device: torch.device) -> torch.Tensor:
    # Sample a random structural condition vector for free/unconditional
    # generation. Layout:
    # [cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)]
    bounds = (7, 2, 2, 1) # symmetry type, room tier, branch tier, outer ring
    values = [random.randint(0, upper) for upper in bounds]
    return torch.tensor(values, dtype=torch.long).unsqueeze(0).to(device)
def maskgit_sample(
    model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
    struct: torch.Tensor, steps: int
) -> np.ndarray:
    # Iterative MaskGIT decoding for a single map (batch size 1).
    # inp: [1, MAP_SIZE]; positions equal to MASK_TOKEN are generated, any
    # other position is treated as fixed context and is never re-masked.
    # Returns the decoded map as a [MAP_H, MAP_W] numpy array.
    current = inp.clone()
    for step in range(steps):
        logits = model(current, z, struct)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        sampled = dist.sample()
        confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
        # Cosine schedule: the number of positions kept masked shrinks to 0.
        ratio = math.cos(((step + 1) / steps) * math.pi / 2)
        num_to_mask = math.floor(ratio * MAP_SIZE)
        # Existing non-mask positions (previous-stage context) stay unchanged.
        fixed_mask = (current[0] != MASK_TOKEN)
        sampled[0, fixed_mask] = current[0, fixed_mask]
        confidences[0, fixed_mask] = 1.0
        # Fix: never re-mask more positions than are actually free. Without
        # this clamp, early steps with a large fixed context make top-k spill
        # into fixed positions (confidence 1.0) and overwrite tiles inherited
        # from the previous stage with MASK_TOKEN.
        free_count = int((~fixed_mask).sum().item())
        num_to_mask = min(num_to_mask, free_count)
        if num_to_mask > 0:
            # Re-mask the least confident predictions for the next iteration.
            _, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
            sampled[0].scatter_(0, mask_indices, MASK_TOKEN)
        current = sampled
        if (current[0] == MASK_TOKEN).sum() == 0:
            break
    # Fallback: deterministically fill any residual masked positions
    # (should not happen with the schedule above).
    still_masked = (current[0] == MASK_TOKEN)
    if still_masked.any():
        logits = model(current, z, struct)
        current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
    return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
def full_generate_random_z(
    input: torch.Tensor,
    struct: torch.Tensor,
    models: list[torch.nn.Module],
    device: torch.device
) -> tuple:
    # Three-stage cascaded generation with a freshly sampled latent z.
    # Returns (stage1 map, stage1+2 merged map, stage1+2+3 merged map).
    # Fix: previously this function duplicated the entire cascade of
    # full_generate_specific_z; it now samples z and delegates, so the
    # cascade logic lives in exactly one place.
    quantizer = models[6] # shared VectorQuantizer in the model tuple
    with torch.no_grad():
        z = quantizer.sample(1, VQ_L, device)
    return full_generate_specific_z(input, z, struct, models, device)
def full_generate_specific_z(
    input: torch.Tensor,
    z: torch.Tensor,
    struct: torch.Tensor,
    models: list[torch.nn.Module],
    device: torch.device
) -> tuple:
    # Three-stage cascaded generation using a caller-supplied latent z.
    # Returns (stage1 map, stage1+2 merged map, stage1+2+3 merged map),
    # each as a [MAP_H, MAP_W] numpy array.
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    with torch.no_grad():
        # Stage 1 decodes the floor/wall skeleton from the initial input.
        merged = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
        stage_outputs = [merged]
        for mg in (mg2, mg3):
            # Free (floor) cells become MASK so the next stage can fill them.
            nxt = torch.tensor(
                merged.flatten(), dtype=torch.long, device=device
            ).reshape(1, MAP_SIZE)
            nxt[nxt == 0] = MASK_TOKEN
            pred = maskgit_sample(mg, nxt, z, struct, GENERATE_STEP)
            # Non-zero predictions overwrite the running merged map.
            merged = merged.copy()
            merged[pred != 0] = pred[pred != 0]
            stage_outputs.append(merged)
    return tuple(stage_outputs)
# Validation grid #1 (3x3): row1 = encoder inputs, row2 = masked inputs,
# row3 = per-stage predictions (stage-3 cell shows the merged result).
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
    SEP = 3
    TILE_SIZE = 32
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    def argmax_map(logits):
        # Greedy per-tile prediction for the first sample of the batch.
        return torch.argmax(logits[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
    def batch_map(key):
        return batch[key][0].numpy().reshape(MAP_H, MAP_W)
    pred1 = argmax_map(logits1)
    pred2 = argmax_map(logits2)
    pred3 = argmax_map(logits3)
    # Overlay stage-2/stage-3 non-zero tiles onto the stage-1 skeleton.
    pred3_merged = pred1.copy()
    pred3_merged[pred2 != 0] = pred2[pred2 != 0]
    pred3_merged[pred3 != 0] = pred3[pred3 != 0]
    rows = [
        [to_img(batch_map("encoder_stage1")), to_img(batch_map("encoder_stage2")),
         to_img(batch_map("encoder_stage3"))],
        [to_img(batch_map("input_stage1")), to_img(batch_map("input_stage2")),
         to_img(batch_map("input_stage3"))],
        [to_img(pred1), to_img(pred2), to_img(pred3_merged)]
    ]
    # White canvas with SEP-pixel gutters between and around the cells.
    grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
# Validation grid #2: row1 = ground-truth per-stage encoder maps; row2 =
# stage-1 input plus autoregressive generation using the sample's true z.
def visualize_part2(batch, z_q, models, device, tile_dict):
    SEP = 3
    TILE_SIZE = 32
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    def batch_map(key):
        return batch[key][0].numpy().reshape(MAP_H, MAP_W)
    inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
    struct_t = batch["struct_inject"][0:1].to(device)
    gen1, gen12, gen123 = full_generate_specific_z(
        inp1_t, z_q[0:1], struct_t, models, device
    )
    rows = [
        [to_img(batch_map("encoder_stage1")), to_img(batch_map("encoder_stage2")),
         to_img(batch_map("encoder_stage3"))],
        [to_img(batch_map("input_stage1")), to_img(gen1), to_img(gen12),
         to_img(gen123)]
    ]
    # 4-column canvas; row1 only fills 3 cells, the last stays white.
    grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
# Validation grid #3 (2x3): row1 = reference input + two random-z samples
# sharing the reference struct; row2 = three fully random-struct samples.
def visualize_part3(batch, models, device, tile_dict):
    SEP = 3
    TILE_SIZE = 32
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
    struct_ref = batch["struct_inject"][0:1].to(device)
    inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
    first_row = [to_img(inp1_np)]
    for _ in range(2):
        sample = full_generate_random_z(inp1_t, struct_ref, models, device)
        first_row.append(to_img(sample[2])) # fully merged three-stage map
    second_row = []
    for _ in range(3):
        sample = full_generate_random_z(inp1_t, random_struct(device), models, device)
        second_row.append(to_img(sample[2]))
    rows = [first_row, second_row]
    grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
# Validation grid #4 (2x3): seed a few random wall tiles, then generate five
# maps with fully random struct + z; row1 shows the seed plus two samples.
def visualize_part4(models, device, tile_dict):
    SEP = 3
    TILE_SIZE = 32
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    # Between 2% and 6% of the map becomes wall seed; the rest starts masked.
    low = math.floor(MAP_SIZE * 0.02)
    high = math.floor(MAP_SIZE * 0.06)
    n_walls = random.randint(low, high)
    seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls]
    seed[0, wall_pos] = 1
    seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)
    samples = []
    for _ in range(5):
        generated = full_generate_random_z(seed, random_struct(device), models, device)
        samples.append(to_img(generated[2])) # fully merged three-stage map
    rows = [[to_img(seed_np)] + samples[:2], samples[2:]]
    grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
def visualize_validate(
    batch, logits1, logits2, logits3, z_q,
    models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int, batch_idx: int
):
    # Render and write the three per-batch validation images to disk.
    save_dir = f"result/seperated/e{epoch}"
    os.makedirs(save_dir, exist_ok=True)
    teacher_grid = visualize_part1(batch, logits1, logits2, logits3, tile_dict)
    cascade_grid = visualize_part2(batch, z_q, models, device, tile_dict)
    random_grid = visualize_part3(batch, models, device, tile_dict)
    cv2.imwrite(f"{save_dir}/val{batch_idx}.png", teacher_grid)
    cv2.imwrite(f"{save_dir}/full{batch_idx}.png", cascade_grid)
    cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", random_grid)
def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int):
    # Run one validation pass: compute per-stage focal losses and the VQ
    # commit loss over the validation set, and write visualization images for
    # every batch plus one free (unconditional) sample.
    # Returns summed (not averaged) loss tensors:
    # (loss1_total, loss2_total, loss3_total, commit_total); the caller
    # divides by len(dataloader) to report means.
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    # Switch to inference mode (disables dropout / batch-norm stat updates).
    for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
        m.eval()
    # Per-stage loss accumulators, summed across all batches.
    loss1_total = torch.Tensor([0]).to(device)
    loss2_total = torch.Tensor([0]).to(device)
    loss3_total = torch.Tensor([0]).to(device)
    commit_total = torch.Tensor([0]).to(device)
    idx = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, leave=False, desc="Validate Progress", disable=disable_tqdm):
            # Per-stage masked inputs, prediction targets and VQ encoder inputs.
            inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE)
            target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE)
            enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE)
            inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE)
            target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE)
            enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE)
            inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE)
            target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE)
            enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
            struct = batch["struct_inject"].to(device)
            # VQ encoding: each stage is encoded independently, then the
            # latents are concatenated and quantized together.
            z_e1 = vq1(enc1) # [B, L, d_z]
            z_e2 = vq2(enc2)
            z_e3 = vq3(enc3)
            z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
            z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
            # Three-stage MaskGIT forward (each conditioned on the full z_q
            # and the structural condition vector).
            logits1 = mg1(inp1, z_q, struct)
            logits2 = mg2(inp2, z_q, struct)
            logits3 = mg3(inp3, z_q, struct)
            loss1_total += focal_loss(logits1, target1)
            loss2_total += focal_loss(logits2, target2)
            loss3_total += focal_loss(logits3, target3)
            commit_total += commit_loss
            # Emit the three per-batch visualizations (val/full/rand images).
            visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx)
            idx += 1
    # One extra unconditional free-generation image per epoch (independent of
    # any validation batch).
    save_dir = f"result/seperated/e{epoch}"
    os.makedirs(save_dir, exist_ok=True)
    cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict))
    # Restore training mode for the caller.
    for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
        m.train()
    return loss1_total, loss2_total, loss3_total, commit_total
def train(device: torch.device):
    # Main training entry point: builds the six-model set, optionally resumes
    # from a checkpoint, then runs the joint three-stage training loop with
    # periodic validation, visualization and checkpointing.
    args = parse_arguments()
    models = build_model(device)
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    start_epoch = 0
    if args.resume:
        # Resume: restore all model weights and (optionally) training state.
        ckpt = torch.load(args.state, map_location=device)
        vq1.load_state_dict(ckpt["vq1"])
        vq2.load_state_dict(ckpt["vq2"])
        vq3.load_state_dict(ckpt["vq3"])
        mg1.load_state_dict(ckpt["mg1"])
        mg2.load_state_dict(ckpt["mg2"])
        mg3.load_state_dict(ckpt["mg3"])
        quantizer.load_state_dict(ckpt["quantizer"])
        # load_optim=False skips optimizer/scheduler restoration (useful when
        # the learning-rate schedule is changed before continuing training).
        if args.load_optim and "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        if args.load_optim and "scheduler" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt.get("epoch", 0) # continue from the saved epoch
        tqdm.write(f"Resumed from epoch {start_epoch}: {args.state}")
    os.makedirs("result/seperated", exist_ok=True)
    dataset = GinkaSeperatedDataset(
        args.train, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB
    )
    dataloader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    dataset_val = GinkaSeperatedDataset(
        args.validate, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB
    )
    dataloader_val = DataLoader(
        dataset_val, batch_size=min(BATCH_SIZE, len(dataset_val) // 8), shuffle=True
    )
    # Preload tile images keyed by file name (without extension) so the
    # visualizers can map tile IDs to pixels.
    tile_dict = {}
    for f in os.listdir("tiles"):
        name = os.path.splitext(f)[0]
        img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED)
        if img is not None:
            tile_dict[name] = img
    for epoch in tqdm(range(start_epoch, EPOCHS), desc="Seperated Training", disable=disable_tqdm):
        loss_total = torch.Tensor([0]).to(device)
        loss1_total = torch.Tensor([0]).to(device)
        loss2_total = torch.Tensor([0]).to(device)
        loss3_total = torch.Tensor([0]).to(device)
        commit_total = torch.Tensor([0]).to(device)
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            # Per-stage masked inputs, prediction targets and encoder contexts.
            inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE)
            target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE)
            enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE)
            inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE)
            target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE)
            enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE)
            inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE)
            target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE)
            enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
            # Structural condition: [cond_sym, cond_room, cond_branch, cond_outer].
            struct = batch["struct_inject"].to(device)
            optimizer.zero_grad()
            # VQ encoding: each encoder handles its own stage context slice.
            z_e1 = vq1(enc1) # [B, L, d_z]
            z_e2 = vq2(enc2)
            z_e3 = vq3(enc3)
            # Concatenate the three stage encodings, then quantize jointly.
            z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
            z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
            # Three-stage MaskGIT forward (each receives the full z_q).
            logits1 = mg1(inp1, z_q, struct)
            logits2 = mg2(inp2, z_q, struct)
            logits3 = mg3(inp3, z_q, struct)
            # Weighted sum of per-stage focal losses plus the VQ commit loss.
            loss1 = focal_loss(logits1, target1)
            loss2 = focal_loss(logits2, target2)
            loss3 = focal_loss(logits3, target3)
            loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1
            loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2
            loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3
            commit_weighted = VQ_BETA * commit_loss
            loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
            loss.backward()
            optimizer.step()
            # Accumulate detached losses so the autograd graph is not kept
            # alive across iterations.
            loss_total += loss.detach()
            loss1_total += loss1.detach()
            loss2_total += loss2.detach()
            loss3_total += loss3.detach()
            commit_total += commit_loss.detach()
        # Step the cosine-annealing schedule once per epoch.
        scheduler.step()
        data_length = len(dataloader)
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
            f"L1: {loss1_total.item() / data_length:.6f} | "
            f"L2: {loss2_total.item() / data_length:.6f} | "
            f"L3: {loss3_total.item() / data_length:.6f} | "
            f"VQ: {commit_total.item() / data_length:.6f} | "
            f"LR: {scheduler.get_last_lr()[0]:.6f}"
        )
        # Every CHECKPOINT epochs: validate, visualize and save a checkpoint.
        if (epoch + 1) % CHECKPOINT == 0:
            losses = validate(dataloader_val, models, device, tile_dict, epoch + 1)
            loss1_total, loss2_total, loss3_total, commit_total = losses
            loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
            loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
            loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total
            commit_weighted = VQ_BETA * commit_total
            loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
            data_length = len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
                f"L1: {loss1_total.item() / data_length:.6f} | "
                f"L2: {loss2_total.item() / data_length:.6f} | "
                f"L3: {loss3_total.item() / data_length:.6f} | "
                f"VQ: {commit_total.item() / data_length:.6f} | "
            )
            ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
            torch.save({
                "epoch": epoch + 1,
                "vq1": vq1.state_dict(),
                "vq2": vq2.state_dict(),
                "vq3": vq3.state_dict(),
                "mg1": mg1.state_dict(),
                "mg2": mg2.state_dict(),
                "mg3": mg3.state_dict(),
                "quantizer": quantizer.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }, ckpt_path)
            tqdm.write(f"Saved checkpoint: {ckpt_path}")
    # Save the final full weights (with optimizer state) after training, for
    # later resumption or inference.
    final_path = "result/seperated.pth"
    torch.save({
        "epoch": EPOCHS,
        "vq1": vq1.state_dict(),
        "vq2": vq2.state_dict(),
        "vq3": vq3.state_dict(),
        "mg1": mg1.state_dict(),
        "mg2": mg2.state_dict(),
        "mg3": mg3.state_dict(),
        "quantizer": quantizer.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, final_path)
    tqdm.write(f"Training complete. Final model saved: {final_path}")

View File

@ -285,26 +285,26 @@ def make_random_struct_cond():
def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor:
"""
根据阶段构造 MaskGIT 的推理初始地图
根据阶段构造 MaskGIT 的推理初始地图与训练端掩码策略一致
Stage 1: MASK或保留稀疏 wall 种子
Stage 2: 保留 floor/wall 上下文其余 MASK
Stage 3: 保留完整上下文floor/wall/door/monster/entranceresource MASK
Stage 1: MASK
Stage 2: 只保留 wall(1)floor + 功能元素 MASK
Stage 3: 保留 wall(1)/door(2)/monster(4)/entrance(5)floor + resource MASK
"""
init = context_map.clone()
if stage == 1:
# 全 MASK不依赖上下文地图
init = torch.full_like(init, MASK_TOKEN)
elif stage == 2:
# 保留 floor/wall其余 → MASK
mask = ~torch.isin(init, torch.tensor([0, 1], device=init.device))
init[mask] = MASK_TOKEN
# 保留 wall其余全部 → MASK
keep = torch.isin(init, torch.tensor([1], device=init.device))
init[~keep] = MASK_TOKEN
else: # stage == 3
# 保留非 resourceresource → MASK
init[init == 3] = MASK_TOKEN
# 保留 wall + 功能元素floor + resource → MASK
keep = torch.isin(init, torch.tensor([1, 2, 4, 5], device=init.device))
init[~keep] = MASK_TOKEN
return init

View File

@ -1,8 +1,9 @@
import time
import torch
import torch.nn as nn
from .quantize import VectorQuantizer
from typing import Tuple
from ..utils import print_memory
class _DecodeLayer(nn.Module):
"""单个解码层Pre-LN Cross-Attention + Pre-LN FFN。"""
@ -97,62 +98,13 @@ class VQDecodeHead(nn.Module):
class GinkaVQVAE(nn.Module):
"""
VQ-VAE 风格地图编码器
将一张完整的地图[B, H*W] 整数 tile ID 序列编码为 L 个离散码字
输出 z [B, L, d_z] 作为 MaskGIT 模型的生成条件
架构
tile embedding + 位置编码
L 个可学习 summary token拼接到序列头部
Transformer EncoderPre-LN自注意力
取前 L 个输出
线性投影到 d_z
VectorQuantizer直通估计 + 熵最大化正则
设计约束
- 参数量目标 < 1M
- 不含解码器z 的语义由 MaskGIT 端的交叉熵损失间接约束
- z 定位为风格/多样性控制信号而非结构重建指导
"""
def __init__(
self,
num_classes: int = 16,
L: int = 2,
K: int = 16,
d_z: int = 64,
d_model: int = 128,
nhead: int = 4,
num_layers: int = 2,
dim_ff: int = 256,
map_size: int = 13 * 13,
beta: float = 0.25,
gamma: float = 0.1,
vq_temp: float = 1.0,
self, num_classes: int = 16, L: int = 2, K: int = 16, d_z: int = 64, d_model: int = 128,
nhead: int = 4, num_layers: int = 2, dim_ff: int = 256, map_size: int = 13 * 13
):
"""
Args:
num_classes: tile 类别数 MASK token
L: 码字序列长度 z 的序列维度
K: codebook 大小码字总数
d_z: 码字嵌入维度
d_model: Transformer 内部维度
nhead: 注意力头数
num_layers: Transformer 层数
dim_ff: 前馈网络隐层维度
map_size: 地图 token 总数H * W
beta: 承诺损失权重
gamma: 熵正则损失权重
vq_temp: VQ 软分配 softmax 温度
"""
super().__init__()
self.L = L
self.K = K
self.d_z = d_z
self.beta = beta
self.gamma = gamma
# Tile 嵌入
self.tile_embedding = nn.Embedding(num_classes, d_model)
@ -165,13 +117,8 @@ class GinkaVQVAE(nn.Module):
# Pre-LN Transformer Encoder训练更稳定
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_ff,
batch_first=True,
activation='gelu',
norm_first=True, # Pre-LN
dropout=0.0,
d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True,
activation='gelu', norm_first=True, dropout=0.1,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
@ -181,127 +128,43 @@ class GinkaVQVAE(nn.Module):
nn.LayerNorm(d_z),
)
# 向量量化层
self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp)
def encode(self, map: torch.Tensor) -> torch.Tensor:
"""
将地图编码为量化前的连续向量序列
Args:
map: [B, H*W] 整数 tile ID
Returns:
z_e: [B, L, d_z] 量化前的编码向量
"""
B = map.shape[0]
x = self.tile_embedding(map) # [B, H*W, d_model]
x = x + self.pos_embedding # [B, H*W, d_model]
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
x = self.transformer(x) # [B, L+H*W, d_model]
z_e = self.proj(x[:, :self.L]) # [B, L, d_z]
return z_e
def encode_soft(self, soft_emb: torch.Tensor) -> torch.Tensor:
"""
将软嵌入序列编码为量化前的连续向量序列用于一致性约束
encode() 的区别输入是已经过 softmax 加权求和得到的连续嵌入矩阵
[B, H*W, d_model]而非整数 tile ID梯度可完整回传到调用方的 logits
Args:
soft_emb: [B, H*W, d_model] softmax 加权 tile 嵌入已在 d_model 空间
Returns:
z_e: [B, L, d_z] 量化前的编码向量
"""
B = soft_emb.shape[0]
x = soft_emb + self.pos_embedding # [B, H*W, d_model]
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
x = self.transformer(x) # [B, L+H*W, d_model]
z_e = self.proj(x[:, :self.L]) # [B, L, d_z]
return z_e
def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
完整前向传播编码 量化 计算损失
Args:
map: [B, H*W] 整数 tile ID训练时传入完整真实地图
Returns:
z_q: [B, L, d_z] 量化后的 z含直通梯度 MaskGIT 使用
z_e: [B, L, d_z] 量化前的连续编码向量供一致性约束使用
indices: [B, L] 每个位置对应的码字索引
vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss
commit_loss: scalar
entropy_loss: scalar
"""
z_e = self.encode(map)
z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
return z_q, z_e, indices, vq_loss, commit_loss, entropy_loss
def sample(self, B: int, device: torch.device) -> torch.Tensor:
"""
推理阶段 codebook 中随机均匀采样 L 个码字
Args:
B: batch size
device: 目标设备
Returns:
z: [B, L, d_z]
"""
indices = torch.randint(0, self.K, (B, self.L), device=device)
z = self.vq.codebook(indices) # [B, L, d_z]
return z
# map: [B, H * W]
B, L = map.shape
x = self.tile_embedding(map) # [B, H * W, d_model]
x = x + self.pos_embedding # [B, H * W, d_model]
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
x = self.transformer(x)
z_e = self.proj(x[:, :self.L])
return z_e
if __name__ == "__main__":
device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [B=4, 169]
model = GinkaVQVAE(
num_classes=16,
L=2,
K=16,
d_z=64,
d_model=128,
nhead=4,
num_layers=2,
dim_ff=256,
map_size=13 * 13,
num_classes=7, L=2, K=16, d_z=64, d_model=128,
nhead=4, num_layers=2, dim_ff=256, map_size=13 * 13,
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
print_memory(device, "初始化后")
# 分模块参数统计
for name, module in model.named_children():
n = sum(p.numel() for p in module.parameters())
print(f" {name}: {n:,}")
start = time.perf_counter()
z_e = model(map_input)
end = time.perf_counter()
# 前向传播测试
map_input = torch.randint(0, 15, (4, 13 * 13)).to(device) # [B=4, 169]
print_memory(device, "前向传播后")
z_q, z_e, indices, vq_loss, commit_loss, entropy_loss = model(map_input)
print(f"\nz_q shape: {z_q.shape}") # [4, 2, 64]
print(f"z_e shape: {z_e.shape}") # [4, 2, 64]
print(f"indices shape:{indices.shape}") # [4, 2]
print(f"vq_loss: {vq_loss.item():.4f}")
# 推理采样测试
z_sample = model.sample(B=4, device=device)
print(f"sample shape: {z_sample.shape}") # [4, 2, 64]
print(f"推理耗时: {end - start:.4f}s")
print(f"输出形状: z_e={z_e.shape}")
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
print(f"Projection parameters: {sum(p.numel() for p in model.proj.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -3,20 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class VectorQuantizer(nn.Module):
"""
向量量化层Vector Quantization
将连续的编码向量序列映射到离散的 codebook 码字索引
并通过直通估计Straight-Through Estimator保持梯度流
均匀分布正则化采用软分配熵最大化方案
通过对距离做 softmax 得到软分配概率计算平均码字使用率的熵
最小化负熵以鼓励所有码字被均等使用
"""
def __init__(self, K: int, d_z: int, temp: float = 1.0):
def __init__(self, K: int, d_z: int):
"""
Args:
K: codebook 大小码字数量
@ -26,12 +14,12 @@ class VectorQuantizer(nn.Module):
super().__init__()
self.K = K
self.d_z = d_z
self.temp = temp
self.codebook = nn.Embedding(K, d_z)
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# z_e: [B, L * 3, d_z]
"""
Args:
z_e: [B, L, d_z] 编码器输出的连续向量序列
@ -40,28 +28,25 @@ class VectorQuantizer(nn.Module):
z_q_st: [B, L, d_z] 量化后向量直通梯度
indices: [B, L] 每个位置对应的码字索引
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
entropy_loss: scalar 负熵损失最小化 = 最大化码字使用均匀度
"""
B, L, d_z = z_e.shape
# 展平到 [B*L, d_z]
z_flat = z_e.reshape(B * L, d_z)
z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z]
codebook_w = self.codebook.weight # [K, d_z]
# 计算 L2 距离:||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e · e_k
# distances: [B*L, K]
distances = (
(z_flat ** 2).sum(dim=1, keepdim=True) # [B*L, 1]
+ (codebook_w ** 2).sum(dim=1) # [K]
- 2.0 * z_flat @ codebook_w.t() # [B*L, K]
)
ze_square = torch.sum(z_flat ** 2, dim=1, keepdim=True)
ek_square = torch.sum(codebook_w ** 2, dim=1)
mul = z_flat @ codebook_w.t()
distances = ze_square + ek_square - 2 * mul
# Hard assignment取最近码字索引
indices = distances.argmin(dim=1) # [B*L]
indices = distances.argmin(dim=1) # [B*L]
# 量化向量
z_q_flat = self.codebook(indices) # [B*L, d_z]
z_q_flat = self.codebook(indices) # [B*L, d_z]
z_q = z_q_flat.reshape(B, L, d_z)
# 直通估计:前向传 z_q反向传 z_e 的梯度
@ -70,12 +55,14 @@ class VectorQuantizer(nn.Module):
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
commit_loss = F.mse_loss(z_e, z_q.detach())
# 熵最大化正则:通过软分配计算平均码字使用率,最小化负熵
# soft_assign: [B*L, K],对距离做 softmax距离越小概率越大
soft_assign = F.softmax(-distances / self.temp, dim=1)
avg_assign = soft_assign.mean(dim=0) # [K],平均码字使用率
# entropy_loss = -H(p) = sum(p * log(p)),最小化即最大化熵
entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum()
indices = indices.reshape(B, L)
return z_q_st, indices, commit_loss, entropy_loss
return z_q_st, indices, commit_loss
def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
indices1 = torch.randint(0, self.K, (B, L), device=device)
indices2 = torch.randint(0, self.K, (B, L), device=device)
indices3 = torch.randint(0, self.K, (B, L), device=device)
z1 = self.codebook(indices1)
z2 = self.codebook(indices2)
z3 = self.codebook(indices3)
return torch.cat([z1, z2, z3], dim=1)

76
prompt.md Normal file
View File

@ -0,0 +1,76 @@
# Ginka 地图生成器 - Copilot 指引
## 项目概述
本项目是一个基于深度学习的二维网格状地图生成模型用于生成魔塔Magic Tower类网页游戏地图。
- **模型结构**VQ-VAE 风格编码器 + MaskGIT 解码器
- VQ-VAE 编码器将完整地图压缩为离散隐变量 z从 codebook 查得)
- MaskGIT 以 z 为条件,通过迭代掩码预测生成地图
- 推理时直接随机采样 z无需用户输入
- **地图规格**13×13 格子7 类图块
- **目录结构**
- `ginka/` — 模型定义与训练脚本Python
- `data/` — 数据预处理TypeScript因游戏是网页游戏
- `docs/` — 设计文档
- `shared/` — 可视化等共享工具
## 重要约束
### 训练
- **不要在当前设备上运行训练**,训练在其他设备上进行
- 可以运行小规模验证、推理或单步测试,但不要触发完整训练流程
### 代码风格
#### Python
- 不使用三引号注释(`"""..."""`),一律改用 `#` 注释。对于行后的注释,注释的 # 应该在语句后面空一格的地方开始,不要多空,也不要少空,例如 `a = b # 注释内容`
- 不出现连续空行(即空行最多连续出现一行),也不出现连续空格,例如下面的例子就不允许出现:
```python
a = func1()
abcdef = func2()
```
应改为:
```python
a = func1()
abcdef = func2()
```
- 遵循类似 Prettier 的风格,不出现尾逗号。
- 不进行无意义的对齐,例如函数参数定义应该遵循这种风格,到达 80 字符左右换行:
```python
def func(
param1: type, param2: type, param3: type,
param4: type, param5: type
)
```
而不是:
```python
def func(param1: type, param2: type, param3: type,
param4: type, param5: type)
```
- 不使用下划线开头命名任何内容,包括私有方法。
- 不写静态方法。
- 仅允许在文件开头引入内容,不允许其他地方出现任何 `import`
- 文件尾添加空行。
- 不允许出现连等。
- 不允许使用元组语法同时给多个量分别赋值,例如 `a, b, c = d, e, f` 不允许出现;仅允许 `a, b, c = func()` 这种一赋多的场景。
- 不要在文件开头添加注释,开头第一句应该是 `import`。文件注释应该在 `import` 之后写。
#### TypeScript
遵循 Prettier 风格。
### 验证与可视化
- 编写验证代码时,优先输出可视化结果(图片文件),使用 `shared/image.py` 中的工具
- 验证阶段应对不同条件(不同 z 采样)分别生成图片,便于直观对比模型效果