mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
Compare commits
2 Commits
6746e96994
...
5f542fb577
| Author | SHA1 | Date | |
|---|---|---|---|
| 5f542fb577 | |||
| 5d95027894 |
30
.github/copilot-instructions.md
vendored
30
.github/copilot-instructions.md
vendored
@ -1,30 +0,0 @@
|
||||
# Ginka 地图生成器 - Copilot 指引
|
||||
|
||||
## 项目概述
|
||||
|
||||
本项目是一个基于深度学习的二维网格状地图生成模型,用于生成魔塔(Magic Tower)类网页游戏地图。
|
||||
|
||||
- **模型结构**:VQ-VAE 风格编码器 + MaskGIT 解码器
|
||||
- VQ-VAE 编码器将完整地图压缩为离散隐变量 z(从 codebook 查得)
|
||||
- MaskGIT 以 z 为条件,通过迭代掩码预测生成地图
|
||||
- 推理时直接随机采样 z,无需用户输入
|
||||
- **地图规格**:13×13 格子,7 类图块
|
||||
- **目录结构**
|
||||
- `ginka/` — 模型定义与训练脚本(Python)
|
||||
- `data/` — 数据预处理(TypeScript,因游戏是网页游戏)
|
||||
- `docs/` — 设计文档
|
||||
- `shared/` — 可视化等共享工具
|
||||
|
||||
## 重要约束
|
||||
|
||||
### 训练
|
||||
- **不要在当前设备上运行训练**,训练在其他设备上进行
|
||||
- 可以运行小规模验证、推理或单步测试,但不要触发完整训练流程
|
||||
|
||||
### 代码风格
|
||||
- **Python**:不使用三引号注释(`"""..."""`),一律改用 `#` 注释;不出现连续空格;遵循 Prettier 风格(缩进 4 空格,行宽 88)
|
||||
- **TypeScript**:遵循 Prettier 默认风格
|
||||
|
||||
### 验证与可视化
|
||||
- 编写验证代码时,优先输出可视化结果(图片文件),使用 `shared/image.py` 中的工具
|
||||
- 验证阶段应对不同条件(不同 z 采样)分别生成图片,便于直观对比模型效果
|
||||
800
ginka/dataset.py
800
ginka/dataset.py
@ -4,715 +4,187 @@ import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
def _compute_map_labels(map_2d) -> dict:
    """
    Derive structural labels from a 2D tile grid (nested list or numpy array).

    Used as a fallback when a dataset entry lacks the roomCount /
    highDegBranchCount / outerWall fields.

    Returns a dict with:
        outerWall: 1 if more than 90% of the outermost ring is wall(1) or
            entry(5) tiles, else 0.
        roomCount: number of 4-connected regions made of floor(0) and
            resource(3) tiles that contain at least 4 cells and whose
            bounding box is at least 2 cells wide and 2 cells tall.
        highDegBranchCount: number of non-wall cells that have 3 or more
            non-wall 4-neighbours (cells outside the grid count as wall).
    """
    grid = np.array(map_2d, dtype=np.int64)  # [H, W]
    rows, cols = grid.shape
    wall_id, entry_id = 1, 5

    # outerWall: share of wall/entry tiles on the outermost ring.
    ring = np.concatenate(
        [grid[0, :], grid[-1, :], grid[1:-1, 0], grid[1:-1, -1]]
    )
    ring_size = ring.size
    if ring_size > 0:
        solid = np.sum((ring == wall_id) | (ring == entry_id))
        outer_wall = int(solid / ring_size > 0.9)
    else:
        outer_wall = 0

    # roomCount: flood-fill every unvisited floor/resource cell and keep only
    # regions that are big enough (>= 4 cells) and not a 1-wide corridor.
    walkable = (grid == 0) | (grid == 3)
    seen = np.zeros((rows, cols), dtype=bool)
    steps = ((-1, 0), (1, 0), (0, -1), (0, 1))
    room_count = 0
    for start_r in range(rows):
        for start_c in range(cols):
            if seen[start_r, start_c] or not walkable[start_r, start_c]:
                continue
            seen[start_r, start_c] = True
            pending = [(start_r, start_c)]
            size = 0
            top = bottom = start_r
            left = right = start_c
            while pending:
                r, c = pending.pop()
                size += 1
                top, bottom = min(top, r), max(bottom, r)
                left, right = min(left, c), max(right, c)
                for dr, dc in steps:
                    nr, nc = r + dr, c + dc
                    if not (0 <= nr < rows and 0 <= nc < cols):
                        continue
                    if seen[nr, nc] or not walkable[nr, nc]:
                        continue
                    seen[nr, nc] = True
                    pending.append((nr, nc))
            # A span of >= 1 between extremes means the box is >= 2 cells.
            if size >= 4 and bottom - top >= 1 and right - left >= 1:
                room_count += 1

    # highDegBranchCount: zero-pad so out-of-grid neighbours read as wall,
    # then sum the four shifted views to get each cell's open-neighbour count.
    open_cells = (grid != wall_id).astype(np.int32)
    framed = np.pad(open_cells, 1, mode='constant', constant_values=0)
    open_neighbours = (
        framed[:-2, 1:-1] + framed[2:, 1:-1]
        + framed[1:-1, :-2] + framed[1:-1, 2:]
    )
    high_deg = int(np.sum((open_cells == 1) & (open_neighbours >= 3)))

    return {
        'outerWall': outer_wall,
        'roomCount': room_count,
        'highDegBranchCount': high_deg,
    }
|
||||
|
||||
|
||||
def load_data(path: str):
    """
    Load the dataset JSON at `path` and return its entries as a list of dicts.

    The file must contain a top-level "data" object; its values are returned
    in file order. Entries written by older dataset versions may lack the
    structural labels, so any of outerWall / roomCount / highDegBranchCount
    that is missing is derived from the entry's 'map' and filled in. Labels
    already present in the entry are left untouched.

    Symmetry is not read here: the datasets recompute it in __getitem__ after
    augmentation.
    """
    with open(path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    label_keys = ('outerWall', 'roomCount', 'highDegBranchCount')
    data_list = []
    for value in data["data"].values():
        # Checking every label (not only roomCount) keeps partially labelled
        # entries from failing later with a KeyError.
        if any(key not in value for key in label_keys):
            for key, label in _compute_map_labels(value['map']).items():
                value.setdefault(key, label)
        data_list.append(value)

    return data_list
|
||||
|
||||
def compute_symmetry(target_np: np.ndarray) -> tuple:
    """
    Compute the three mirror symmetries of a 2D map matrix in O(H*W).

    Returns (sym_h, sym_v, sym_c) as ints (0 or 1):
        sym_h: the map equals itself with the columns reversed.
        sym_v: the map equals itself with the rows reversed.
        sym_c: the map equals itself with both axes reversed.
    """
    sym_h = bool(np.all(target_np == target_np[:, ::-1]))
    sym_v = bool(np.all(target_np == target_np[::-1, :]))
    sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
    return int(sym_h), int(sym_v), int(sym_c)


# Both spellings are called elsewhere in this file; keep the underscore name
# as an alias so every call site resolves to the same implementation.
_compute_symmetry = compute_symmetry
|
||||
|
||||
|
||||
class GinkaVQDataset(Dataset):
|
||||
"""
|
||||
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。
|
||||
|
||||
每次 __getitem__ 按权重随机选取以下四种子集之一:
|
||||
A (standard): 标准 MaskGIT 随机掩码,随机遮盖部分 tile
|
||||
B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(6)
|
||||
C (wall-random): 在 B 基础上,再随机 mask 部分 wall tile
|
||||
D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(5),其余全部替换为 MASK(6)
|
||||
|
||||
返回 dict:
|
||||
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
|
||||
masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 6)
|
||||
target_map: LongTensor [H*W] CE loss ground truth(等同 raw_map)
|
||||
subset: str 子集标识,供调试/统计用
|
||||
"""
|
||||
|
||||
FLOOR = 0
|
||||
WALL = 1
|
||||
class GinkaSeperatedDataset(Dataset):
|
||||
FLOOR = 0
|
||||
WALL = 1
|
||||
DOOR = 2
|
||||
RESOURCE = 3
|
||||
MONSTER = 4
|
||||
ENTRANCE = 5
|
||||
MASK_ID = 6
|
||||
MASK_ID = 6
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||||
wall_mask_ratio: float = 0.3,
|
||||
room_thresholds: tuple = None,
|
||||
branch_thresholds: tuple = None,
|
||||
subset_weights: tuple = (0.5, 0.3, 0.2),
|
||||
subset2_wall_prob: float = 0.7
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
data_path: JSON 数据文件路径
|
||||
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
||||
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
|
||||
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
|
||||
room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||||
branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||||
"""
|
||||
self.data = load_data(data_path)
|
||||
self.wall_mask_ratio = wall_mask_ratio
|
||||
self.subset2_wall_prob = subset2_wall_prob
|
||||
total = sum(subset_weights)
|
||||
self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))]
|
||||
|
||||
# 累积权重,用于快速随机子集选择
|
||||
total_w = sum(subset_weights)
|
||||
normalized = [x / total_w for x in subset_weights]
|
||||
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
|
||||
n = len(self.data)
|
||||
rs = sorted(item['roomCount'] for item in self.data)
|
||||
bs = sorted(item['highDegBranchCount'] for item in self.data)
|
||||
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
|
||||
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
|
||||
if th1_r == th2_r: th2_r = th1_r + 1
|
||||
if th1_b == th2_b: th2_b = th1_b + 1
|
||||
self.room_th = (th1_r, th2_r)
|
||||
self.branch_th = (th1_b, th2_b)
|
||||
|
||||
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
|
||||
room_counts = [item['roomCount'] for item in self.data]
|
||||
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||||
|
||||
if room_thresholds is None:
|
||||
n = len(room_counts)
|
||||
rs = sorted(room_counts)
|
||||
bs = sorted(branch_counts)
|
||||
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
|
||||
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
|
||||
# 防止 Medium 等级为空
|
||||
if th1_r == th2_r:
|
||||
th2_r = th1_r + 1
|
||||
if th1_b == th2_b:
|
||||
th2_b = th1_b + 1
|
||||
self.room_th = (th1_r, th2_r)
|
||||
self.branch_th = (th1_b, th2_b)
|
||||
else:
|
||||
self.room_th = room_thresholds
|
||||
self.branch_th = branch_thresholds
|
||||
|
||||
def to_level(v: int, th: tuple) -> int:
|
||||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||
|
||||
# 回填等级字段
|
||||
for item in self.data:
|
||||
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||||
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||
item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th)
|
||||
item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th)
|
||||
|
||||
def to_level(self, v, th):
|
||||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题)
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
|
||||
"""用 Beta(2,2) 分布采样掩码比例,集中在中间值。"""
|
||||
r = np.random.beta(2, 2)
|
||||
return min_r + (max_r - min_r) * r
|
||||
def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
|
||||
# 将指定 tile ID 替换为 floor(0),原地修改
|
||||
for t in tiles:
|
||||
m[m == t] = self.FLOOR
|
||||
return m
|
||||
|
||||
@staticmethod
|
||||
def _random_mask(h: int, w: int) -> np.ndarray:
|
||||
"""纯随机掩码,返回 [H*W] bool。"""
|
||||
ratio = GinkaVQDataset._sample_mask_ratio()
|
||||
total = h * w
|
||||
idx = np.random.choice(total, int(total * ratio), replace=False)
|
||||
mask = np.zeros(total, dtype=bool)
|
||||
mask[idx] = True
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def _block_mask(h: int, w: int) -> np.ndarray:
|
||||
"""矩形分块随机掩码,返回 [H*W] bool。"""
|
||||
ratio = GinkaVQDataset._sample_mask_ratio()
|
||||
max_block = max(2, min(h, w) // 2)
|
||||
target = int(h * w * ratio)
|
||||
mask = np.zeros((h, w), dtype=bool)
|
||||
while mask.sum() < target:
|
||||
bh = np.random.randint(2, max_block + 1)
|
||||
bw = np.random.randint(2, max_block + 1)
|
||||
x = np.random.randint(0, max(1, h - bh + 1))
|
||||
y = np.random.randint(0, max(1, w - bw + 1))
|
||||
mask[x:x + bh, y:y + bw] = True
|
||||
return mask.reshape(-1)
|
||||
|
||||
def _std_mask(self, h: int, w: int) -> np.ndarray:
|
||||
"""标准 MaskGIT 掩码:随机选择纯随机或分块策略。"""
|
||||
def std_mask(self) -> np.ndarray:
|
||||
# Beta(2,2) 采样掩码比例,50% 随机掩码 / 50% 分块掩码,返回 bool[13, 13]
|
||||
ratio = float(np.random.beta(2, 2)) * 0.95 + 0.05
|
||||
if random.random() < 0.5:
|
||||
return self._random_mask(h, w)
|
||||
else:
|
||||
return self._block_mask(h, w)
|
||||
idx = np.random.choice(169, int(169 * ratio), replace=False)
|
||||
mask = np.zeros(169, dtype=bool)
|
||||
mask[idx] = True
|
||||
return mask.reshape(13, 13)
|
||||
target = int(169 * ratio)
|
||||
mask = np.zeros((13, 13), dtype=bool)
|
||||
while mask.sum() < target:
|
||||
bh = np.random.randint(2, 7)
|
||||
bw = np.random.randint(2, 7)
|
||||
x = np.random.randint(0, 14 - bh)
|
||||
y = np.random.randint(0, 14 - bw)
|
||||
mask[x:x + bh, y:y + bw] = True
|
||||
return mask
|
||||
|
||||
def create_degreaded(self, raw: np.ndarray):
|
||||
# 阶段一:生成墙壁和入口
|
||||
target1 = raw.copy()
|
||||
self.degrade_tile(target1, [self.DOOR, self.RESOURCE, self.MONSTER])
|
||||
inp1 = target1.copy()
|
||||
|
||||
# 阶段二:生成怪物、门,同时也允许生成入口以适配结构
|
||||
target2 = raw.copy()
|
||||
self.degrade_tile(target2, [self.RESOURCE, self.WALL])
|
||||
inp2 = raw.copy()
|
||||
self.degrade_tile(inp2, [self.RESOURCE])
|
||||
|
||||
# 阶段三:生成资源
|
||||
target3 = raw.copy()
|
||||
self.degrade_tile(target3, [self.WALL, self.DOOR, self.MONSTER, self.ENTRANCE])
|
||||
inp3 = raw.copy()
|
||||
|
||||
return target1, inp1, target2, inp2, target3, inp3
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
def apply_subset1(self, raw: np.ndarray):
|
||||
# 子集 1:std_mask 随机掩码
|
||||
|
||||
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
|
||||
|
||||
enc1 = target1.copy()
|
||||
enc2 = inp2.copy()
|
||||
enc3 = raw.copy()
|
||||
|
||||
# stage1:对整图 std_mask
|
||||
inp1[self.std_mask()] = self.MASK_ID
|
||||
|
||||
# stage2:对 floor+功能元素区域 std_mask
|
||||
need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE])
|
||||
inp2[need_mask & self.std_mask()] = self.MASK_ID
|
||||
|
||||
# stage3:对 floor+resource 区域 std_mask
|
||||
need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE])
|
||||
inp3[need_mask & self.std_mask()] = self.MASK_ID
|
||||
|
||||
return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
|
||||
|
||||
def apply_subset2(self, raw: np.ndarray):
|
||||
# 子集 2:掩码所有内容,墙壁随机掩码,不掩码入口
|
||||
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
|
||||
|
||||
enc1 = target1.copy()
|
||||
enc2 = inp2.copy()
|
||||
enc3 = raw.copy()
|
||||
|
||||
if np.random.random() < self.subset2_wall_prob:
|
||||
inp1[self.std_mask()] = self.MASK_ID
|
||||
need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE])
|
||||
inp2[need_mask] = self.MASK_ID
|
||||
need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE])
|
||||
inp3[need_mask] = self.MASK_ID
|
||||
|
||||
return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
|
||||
|
||||
def apply_subset3(self, raw: np.ndarray):
|
||||
# 子集 3:在 2 的基础上掩码入口
|
||||
out = self.apply_subset2(raw)
|
||||
out[0][out[0] == self.ENTRANCE] = self.MASK_ID
|
||||
return out
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
map_np = np.array(item['map'], dtype=np.int64)
|
||||
|
||||
def _augment(self, arr: np.ndarray) -> np.ndarray:
|
||||
"""随机旋转 / 翻转数据增强,返回新 array。"""
|
||||
if np.random.rand() > 0.5:
|
||||
k = np.random.randint(1, 4)
|
||||
arr = np.rot90(arr, k).copy()
|
||||
map_np = np.rot90(map_np, k).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.fliplr(arr).copy()
|
||||
map_np = np.fliplr(map_np).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.flipud(arr).copy()
|
||||
return arr
|
||||
map_np = np.flipud(map_np).copy()
|
||||
|
||||
def _choose_subset(self) -> str:
|
||||
r = random.random()
|
||||
if r < self.subset_cumw[0]:
|
||||
return 'A'
|
||||
out = self.apply_subset1(map_np)
|
||||
elif r < self.subset_cumw[1]:
|
||||
return 'B'
|
||||
elif r < self.subset_cumw[2]:
|
||||
return 'C'
|
||||
out = self.apply_subset2(map_np)
|
||||
else:
|
||||
return 'D'
|
||||
out = self.apply_subset3(map_np)
|
||||
|
||||
def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray:
|
||||
"""
|
||||
根据子集类型生成 masked_map。
|
||||
|
||||
Args:
|
||||
raw: [H, W] int64 完整原始地图
|
||||
subset: 'A' | 'B' | 'C' | 'D'
|
||||
|
||||
Returns:
|
||||
[H*W] int64,被遮盖位置値为 MASK_ID(6)
|
||||
"""
|
||||
H, W = raw.shape
|
||||
|
||||
if subset == 'A':
|
||||
# 标准随机 mask:纯随机或分块策略
|
||||
mask = self._std_mask(H, W) # [H*W] bool
|
||||
flat = raw.reshape(-1).copy()
|
||||
flat[mask] = self.MASK_ID
|
||||
return flat
|
||||
|
||||
elif subset == 'B':
|
||||
# 仅保留 wall(1),floor(0) 和其他非墙内容全部 mask
|
||||
flat = raw.reshape(-1).copy()
|
||||
keep = (flat == self.WALL)
|
||||
flat[~keep] = self.MASK_ID
|
||||
return flat
|
||||
|
||||
elif subset == 'C':
|
||||
# Subset B + 随机 mask 部分 wall
|
||||
flat = raw.reshape(-1).copy()
|
||||
keep = (flat == self.WALL)
|
||||
flat[~keep] = self.MASK_ID
|
||||
|
||||
wall_idx = np.where(flat == self.WALL)[0]
|
||||
if len(wall_idx) > 0:
|
||||
ratio = random.random() * self.wall_mask_ratio
|
||||
n = max(1, int(len(wall_idx) * ratio))
|
||||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||||
flat[chosen] = self.MASK_ID
|
||||
return flat
|
||||
|
||||
else: # D
|
||||
# 仅保留 wall(1) 和 entrance(5),floor(0) 和其他非墙内容全部 mask
|
||||
flat = raw.reshape(-1).copy()
|
||||
keep = (flat == self.WALL) | (flat == self.ENTRANCE)
|
||||
flat[~keep] = self.MASK_ID
|
||||
|
||||
# 随机 mask 部分 wall(模拟真实场景,与子集 C 一致)
|
||||
wall_idx = np.where(flat == self.WALL)[0]
|
||||
if len(wall_idx) > 0:
|
||||
ratio = random.random() * self.wall_mask_ratio
|
||||
n = max(1, int(len(wall_idx) * ratio))
|
||||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||||
flat[chosen] = self.MASK_ID
|
||||
return flat
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
|
||||
subset = self._choose_subset()
|
||||
masked_np = self._apply_subset(raw_np, subset) # [H*W]
|
||||
raw_flat = raw_np.reshape(-1) # [H*W]
|
||||
|
||||
# 对称性:在增强后重新计算
|
||||
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
|
||||
|
||||
# 其余结构标签:增强不改变拓扑结构,直接读取
|
||||
cond_room = item['roomCountLevel'] # 0/1/2
|
||||
cond_branch = item['branchLevel'] # 0/1/2
|
||||
cond_outer = item['outerWall'] # 0/1
|
||||
|
||||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
raw_t = torch.LongTensor(raw_flat)
|
||||
return {
|
||||
"raw_map": raw_t, # VQ-VAE 编码器输入
|
||||
"slice1": make_slice(raw_t, {0, 1}), # 通道 1:floor+wall
|
||||
"slice2": make_slice(raw_t, {0, 1, 2, 4, 5}), # 通道 2:floor+wall+门+怪+入口
|
||||
"slice3": raw_t.clone(), # 通道 3:完整地图
|
||||
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
||||
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
||||
"subset": subset, # 调试/统计用
|
||||
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# make_slice:按保留集合切割地图,其余位置替换为 floor(0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor:
    """
    Keep only the tile types listed in keep_set; turn everything else into
    floor(0).

    Args:
        map_flat: tensor of tile IDs for the full map.
        keep_set: iterable of int tile IDs to preserve.

    Returns:
        A new tensor shaped like map_flat in which every position whose tile
        is not in keep_set holds 0. The input tensor is not modified.
    """
    keep = torch.zeros_like(map_flat, dtype=torch.bool)
    for tile_id in keep_set:
        keep = keep | (map_flat == tile_id)
    # masked_fill is out-of-place, so this also provides the defensive copy.
    return map_flat.masked_fill(~keep, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GinkaSplitDataset:三通道分拆预训练专用数据集
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GinkaSplitDataset(Dataset):
    """
    Dataset for the three-channel split pre-training scheme ("plan B").

    Each sample provides the full map and its three slices; no MaskGIT
    masking is applied. The slices are cumulative:
        slice1 = floor(0) + wall(1)
        slice2 = floor(0) + wall(1) + door(2) + mob(4) + entrance(5)
        slice3 = the full map (all tiles)

    __getitem__ returns a dict of flattened LongTensors:
        raw_map: the full (augmented) map
        slice1:  channel 1 slice (floor + wall)
        slice2:  channel 2 slice (floor + wall + door + mob + entrance)
        slice3:  channel 3 slice (full map)
    """

    def __init__(self, data_path: str):
        # Entries come from the shared JSON loader.
        self.data = load_data(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        grid = np.array(self.data[idx]['map'], dtype=np.int64)  # [H, W]

        # Augmentation: each transform is applied with probability 1/2.
        # The rotation amount is only drawn when the rotation is applied.
        if np.random.rand() > 0.5:
            quarter_turns = np.random.randint(1, 4)
            grid = np.rot90(grid, quarter_turns).copy()
        for flip in (np.fliplr, np.flipud):
            if np.random.rand() > 0.5:
                grid = flip(grid).copy()

        full_map = torch.LongTensor(grid.reshape(-1))  # [H*W]
        sample = {"raw_map": full_map}
        sample["slice1"] = make_slice(full_map, {0, 1})
        sample["slice2"] = make_slice(full_map, {0, 1, 2, 4, 5})
        sample["slice3"] = full_map.clone()
        return sample
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GinkaStageDataset:三阶段级联训练专用数据集
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GinkaStageDataset(Dataset):
|
||||
"""
|
||||
三阶段级联生成训练专用 Dataset。
|
||||
|
||||
每个阶段只预测特定类别的 tile,后续阶段以前序阶段输出作为上下文。
|
||||
训练时统一使用 GT 作为前序上下文(teacher forcing),避免误差级联。
|
||||
|
||||
阶段划分:
|
||||
stage=1 结构骨架:预测 floor(0) + wall(1)
|
||||
stage=2 功能元素:预测 door(2) + monster(4) + entrance(5),以 floor/wall 为上下文
|
||||
stage=3 资源放置:预测 resource(3),以完整骨架为上下文
|
||||
|
||||
返回 dict:
|
||||
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
|
||||
vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片
|
||||
stage_input: LongTensor [H*W] MaskGIT 输入(含上下文 + MASK 位置)
|
||||
target_map: LongTensor [H*W] CE loss ground truth
|
||||
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
|
||||
subset: str 子集标识 A/B/C/D
|
||||
struct_cond: LongTensor [4] [sym, room, branch, outer]
|
||||
"""
|
||||
|
||||
FLOOR = 0
|
||||
WALL = 1
|
||||
DOOR = 2
|
||||
RESOURCE = 3
|
||||
MONSTER = 4
|
||||
ENTRANCE = 5
|
||||
MASK_ID = 6
|
||||
|
||||
STAGE1_TARGETS = frozenset({0, 1})
|
||||
STAGE2_TARGETS = frozenset({2, 4, 5})
|
||||
STAGE3_TARGETS = frozenset({3})
|
||||
|
||||
# VQ 切片集合:各阶段编码器只"看"与自身相关的 tile
|
||||
_VQ_KEEP = {
|
||||
1: frozenset({0, 1}),
|
||||
2: frozenset({0, 1, 2, 4, 5}),
|
||||
3: None, # 完整地图
|
||||
}
|
||||
|
||||
def __init__(
    self,
    data_path: str,
    stage: int,
    subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
    wall_mask_ratio: float = 0.3,
    room_thresholds: tuple = None,
    branch_thresholds: tuple = None,
):
    """
    Args:
        data_path: path of the JSON dataset file.
        stage: generation stage, one of 1 / 2 / 3.
        subset_weights: sampling weights of subsets (A, B, C, D); they are
            normalised here.
        wall_mask_ratio: upper bound of the share of wall tiles additionally
            masked in subset C.
        room_thresholds: (th1, th2) bin edges for roomCount; computed from
            the loaded data when None.
        branch_thresholds: (th1, th2) bin edges for highDegBranchCount;
            computed from the loaded data when None.

    Each threshold pair is resolved on its own, so supplying only one of
    them is valid: the supplied pair is used as given and the other one is
    computed from the data.
    """
    assert stage in (1, 2, 3), f"stage 必须是 1/2/3,收到 {stage}"
    self.stage = stage
    self.data = load_data(data_path)
    self.wall_mask_ratio = wall_mask_ratio

    # Cumulative normalised weights, consumed by _choose_subset.
    total_w = sum(subset_weights)
    normalized = [x / total_w for x in subset_weights]
    self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]

    def tercile_thresholds(values):
        # Equal-frequency edges at the 1/3 and 2/3 positions of the sorted
        # values; the upper edge is bumped when both coincide so that the
        # middle level is never empty.
        ordered = sorted(values)
        n = len(ordered)
        low, high = ordered[n // 3], ordered[2 * n // 3]
        if low == high:
            high = low + 1
        return (low, high)

    if room_thresholds is None:
        room_thresholds = tercile_thresholds(
            item['roomCount'] for item in self.data
        )
    if branch_thresholds is None:
        branch_thresholds = tercile_thresholds(
            item['highDegBranchCount'] for item in self.data
        )
    self.room_th = room_thresholds
    self.branch_th = branch_thresholds

    def to_level(v, th):
        # 0 below the first edge, 1 below the second, 2 otherwise.
        return 0 if v < th[0] else (1 if v < th[1] else 2)

    # Back-fill the discrete level fields read later when building samples.
    for item in self.data:
        item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
        item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||
|
||||
def __len__(self):
    """Return the number of loaded map entries."""
    entry_count = len(self.data)
    return entry_count
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 掩码辅助(与 GinkaVQDataset 相同逻辑)
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
    """
    Draw a mask ratio in [min_r, max_r] by rescaling one Beta(2, 2) sample.
    """
    span = max_r - min_r
    return min_r + span * np.random.beta(2, 2)
|
||||
|
||||
@staticmethod
def _random_mask(h: int, w: int) -> np.ndarray:
    """
    Build a uniformly random mask over h*w cells.

    The number of masked cells is the sampled mask ratio times h*w, rounded
    down. Returns a flat bool array of length h*w, True where masked.
    """
    cell_count = h * w
    masked_count = int(cell_count * GinkaStageDataset._sample_mask_ratio())
    picked = np.random.choice(cell_count, masked_count, replace=False)
    flags = np.zeros(cell_count, dtype=bool)
    flags[picked] = True
    return flags
|
||||
|
||||
@staticmethod
def _block_mask(h: int, w: int) -> np.ndarray:
    """
    Build a mask out of random rectangles.

    Rectangles with side lengths from 2 up to max(2, min(h, w) // 2) are
    stamped at random positions until at least the sampled share of the h*w
    cells is covered. Returns a flat bool array of length h*w.
    """
    coverage_goal = int(h * w * GinkaStageDataset._sample_mask_ratio())
    largest_side = max(2, min(h, w) // 2)
    canvas = np.zeros((h, w), dtype=bool)
    while canvas.sum() < coverage_goal:
        # Draw order (height, width, row, column) is kept fixed on purpose:
        # it determines the result for a given RNG state.
        block_h = np.random.randint(2, largest_side + 1)
        block_w = np.random.randint(2, largest_side + 1)
        row = np.random.randint(0, max(1, h - block_h + 1))
        col = np.random.randint(0, max(1, w - block_w + 1))
        canvas[row:row + block_h, col:col + block_w] = True
    return canvas.reshape(-1)
|
||||
|
||||
def _std_mask(self, h: int, w: int) -> np.ndarray:
    """
    Standard mask: pick the pure-random strategy or the block strategy with
    equal probability and return its result.
    """
    if random.random() < 0.5:
        return self._random_mask(h, w)
    return self._block_mask(h, w)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 子集选择
|
||||
# ------------------------------------------------------------------
|
||||
def _choose_subset(self) -> str:
    """
    Pick a subset label by comparing one uniform draw against the cumulative
    weights: the first of 'A', 'B', 'C' whose bound exceeds the draw wins,
    otherwise 'D'.
    """
    draw = random.random()
    bounds = self.subset_cumw
    for position, label in enumerate('ABC'):
        if draw < bounds[position]:
            return label
    return 'D'
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶段一:结构骨架(floor + wall)
|
||||
# ------------------------------------------------------------------
|
||||
def _make_stage1(self, raw_flat: np.ndarray, subset: str):
    """
    Stage 1 (structural skeleton): predict floor / wall.

    In the target every tile that is neither floor nor wall is remapped to
    floor. The subset decides how much of the skeleton stays visible:
        'A': a standard random/block mask over the 13x13 grid.
        'C': every floor is masked, plus a random share of the walls
             (at least one wall when any exist).
        any other value ('B', 'D'): every floor is masked, walls are kept.

    Returns (stage_input, target, loss_mask) where loss_mask is True exactly
    at the masked input positions.
    """
    side = 13

    skeleton = np.isin(raw_flat, [self.FLOOR, self.WALL])
    target = np.where(skeleton, raw_flat, self.FLOOR)
    stage_input = target.copy()

    if subset == 'A':
        stage_input[self._std_mask(side, side)] = self.MASK_ID
    else:
        # B, C and D all start by hiding every floor tile.
        stage_input[stage_input == self.FLOOR] = self.MASK_ID
        if subset == 'C':
            wall_positions = np.where(stage_input == self.WALL)[0]
            if len(wall_positions) > 0:
                fraction = random.random() * self.wall_mask_ratio
                hidden = max(1, int(len(wall_positions) * fraction))
                picked = np.random.choice(wall_positions, hidden, replace=False)
                stage_input[picked] = self.MASK_ID

    return stage_input, target, stage_input == self.MASK_ID
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶段二:功能元素(door + monster + entrance)
|
||||
# ------------------------------------------------------------------
|
||||
def _make_stage2(self, raw_flat: np.ndarray, subset: str):
    """
    Stage 2 (functional elements): predict door / monster / entrance with
    floor and wall as context.

    Resource tiles read as floor in both the input and the target. The
    subset decides the masking:
        'A': a random 20%-100% share of the functional tiles is masked
             (at least one when any exist).
        any other value: every functional tile is masked.
        'C' additionally masks a random share of the walls.

    Returns (stage_input, target, loss_mask). loss_mask marks the positions
    that hold a functional tile in raw_flat, so additionally masked walls
    are not part of the loss.
    """
    functional = [self.DOOR, self.MONSTER, self.ENTRANCE]

    # Input and target share the same resource-free starting point.
    target = np.where(raw_flat == self.RESOURCE, self.FLOOR, raw_flat)
    stage_input = target.copy()
    is_functional = np.isin(stage_input, functional)

    if subset != 'A':
        stage_input[is_functional] = self.MASK_ID
    else:
        candidates = np.where(is_functional)[0]
        if len(candidates) > 0:
            fraction = random.random() * 0.8 + 0.2  # 20%~100%
            hidden = max(1, int(len(candidates) * fraction))
            picked = np.random.choice(candidates, hidden, replace=False)
            stage_input[picked] = self.MASK_ID

    if subset == 'C':
        # Degrade the wall context as well.
        wall_positions = np.where(stage_input == self.WALL)[0]
        if len(wall_positions) > 0:
            fraction = random.random() * self.wall_mask_ratio
            hidden = max(1, int(len(wall_positions) * fraction))
            picked = np.random.choice(wall_positions, hidden, replace=False)
            stage_input[picked] = self.MASK_ID

    return stage_input, target, np.isin(raw_flat, functional)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶段三:资源放置(resource)
|
||||
# ------------------------------------------------------------------
|
||||
def _make_stage3(self, raw_flat: np.ndarray, subset: str):
    """
    Stage 3 (resource placement): predict resource tiles with the rest of
    the map as context.

    The target is the unchanged map. The subset decides the masking:
        'A': a random 20%-100% share of the resource tiles is masked
             (at least one when any exist); the rest stay as context.
        any other value: every resource tile is masked.

    Returns (stage_input, target, loss_mask) where loss_mask is True exactly
    at the masked input positions (all False when the map has no resource).
    """
    target = raw_flat.copy()
    stage_input = raw_flat.copy()
    resource_positions = np.where(stage_input == self.RESOURCE)[0]

    if subset != 'A':
        stage_input[resource_positions] = self.MASK_ID
    elif len(resource_positions) > 0:
        fraction = random.random() * 0.8 + 0.2  # 20%~100%
        hidden = max(1, int(len(resource_positions) * fraction))
        picked = np.random.choice(resource_positions, hidden, replace=False)
        stage_input[picked] = self.MASK_ID

    return stage_input, target, stage_input == self.MASK_ID
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# __getitem__
|
||||
# ------------------------------------------------------------------
|
||||
def _augment(self, arr: np.ndarray) -> np.ndarray:
    """
    Random rotation / flip augmentation.

    Independently, each with probability 1/2: rotate by 1-3 quarter turns,
    flip left-right, flip up-down. Transformed results are copies; the input
    array is never modified.
    """
    result = arr
    if np.random.rand() > 0.5:
        # The turn count is drawn only when the rotation is applied.
        quarter_turns = np.random.randint(1, 4)
        result = np.rot90(result, quarter_turns).copy()
    for flip in (np.fliplr, np.flipud):
        if np.random.rand() > 0.5:
            result = flip(result).copy()
    return result
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
|
||||
raw_flat = raw_np.reshape(-1) # [H*W]
|
||||
subset = self._choose_subset()
|
||||
|
||||
if self.stage == 1:
|
||||
stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset)
|
||||
elif self.stage == 2:
|
||||
stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset)
|
||||
else:
|
||||
stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset)
|
||||
|
||||
# 若 loss_mask 全为 False(如地图中无 resource 时的 stage3),
|
||||
# 退回为全图损失,避免 NaN
|
||||
if not loss_mask_np.any():
|
||||
loss_mask_np = np.ones_like(loss_mask_np)
|
||||
|
||||
# VQ 切片:当前阶段编码器的输入(仅保留相关 tile)
|
||||
raw_t = torch.LongTensor(raw_flat)
|
||||
vq_keep = self._VQ_KEEP[self.stage]
|
||||
if vq_keep is None:
|
||||
vq_slice = raw_t.clone()
|
||||
else:
|
||||
vq_slice = make_slice(raw_t, vq_keep)
|
||||
|
||||
# 结构标签
|
||||
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
||||
cond_room = item['roomCountLevel']
|
||||
sym_h, sym_v, sym_c = compute_symmetry(map_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
||||
cond_room = item['roomCountLevel']
|
||||
cond_branch = item['branchLevel']
|
||||
cond_outer = item['outerWall']
|
||||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
cond_outer = item['outerWall']
|
||||
struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
return {
|
||||
"raw_map": raw_t,
|
||||
"vq_slice": vq_slice,
|
||||
"stage_input": torch.LongTensor(stage_input_np),
|
||||
"target_map": torch.LongTensor(target_np),
|
||||
"loss_mask": torch.BoolTensor(loss_mask_np),
|
||||
"subset": subset,
|
||||
"struct_cond": struct_cond,
|
||||
"input_stage1": torch.LongTensor(out[0]),
|
||||
"target_stage1": torch.LongTensor(out[1]),
|
||||
"encoder_stage1": torch.LongTensor(out[2]),
|
||||
"input_stage2": torch.LongTensor(out[3]),
|
||||
"target_stage2": torch.LongTensor(out[4]),
|
||||
"encoder_stage2": torch.LongTensor(out[5]),
|
||||
"input_stage3": torch.LongTensor(out[6]),
|
||||
"target_stage3": torch.LongTensor(out[7]),
|
||||
"encoder_stage3": torch.LongTensor(out[8]),
|
||||
"struct_inject": struct_inject
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
|
||||
ds = GinkaVQDataset(data_path)
|
||||
print(f"数据集大小: {len(ds)}")
|
||||
|
||||
subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
|
||||
for i in range(200):
|
||||
sample = ds[i % len(ds)]
|
||||
subset_count[sample['subset']] += 1
|
||||
|
||||
raw = sample['raw_map']
|
||||
masked = sample['masked_map']
|
||||
target = sample['target_map']
|
||||
print(f"raw_map shape={raw.shape}, dtype={raw.dtype}")
|
||||
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
|
||||
print(f"target_map shape={target.shape}, dtype={target.dtype}")
|
||||
print(f"被 mask 的位置数: {(masked == 6).sum().item()} / {masked.numel()}")
|
||||
print(f"\n200 次采样子集分布: {subset_count}")
|
||||
|
||||
@ -4,53 +4,28 @@ import torch.nn as nn
|
||||
from ..utils import print_memory
|
||||
from .maskGIT import Transformer
|
||||
|
||||
# 结构标签词表大小(最后一个索引为无条件占位符 null)
|
||||
SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-6,7 = null
|
||||
ROOM_VOCAB = 4 # roomCountLevel 0-2,3 = null
|
||||
BRANCH_VOCAB = 4 # branchLevel 0-2,3 = null
|
||||
OUTER_VOCAB = 3 # outerWall 0-1,2 = null
|
||||
|
||||
# 结构标签词表大小
|
||||
SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-7
|
||||
ROOM_VOCAB = 3 # roomCountLevel 0-2
|
||||
BRANCH_VOCAB = 3 # branchLevel 0-2
|
||||
OUTER_VOCAB = 2 # outerWall 0-1
|
||||
|
||||
class GinkaMaskGIT(nn.Module):
|
||||
"""
|
||||
改造后的 MaskGIT 地图生成模型。
|
||||
|
||||
以掩码地图序列和 VQ-VAE 输出的离散隐变量 z 为输入,
|
||||
通过 Transformer encoder-decoder 结构预测被遮盖位置的 tile 类别。
|
||||
|
||||
z 通过 cross-attention 注入到 Transformer decoder,
|
||||
作为风格/多样性控制信号,而非结构重建指导。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int = 16,
|
||||
d_model: int = 192,
|
||||
d_z: int = 64,
|
||||
dim_ff: int = 512,
|
||||
nhead: int = 8,
|
||||
num_layers: int = 4,
|
||||
map_size: int = 13 * 13,
|
||||
z_dropout: float = 0.1,
|
||||
struct_dropout: float = 0.15,
|
||||
self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512,
|
||||
nhead: int = 8, num_layers: int = 4, map_size: int = 13 * 13, d_z: int = 64
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: tile 类别数(含 MASK token=15)
|
||||
d_model: Transformer 内部维度
|
||||
d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致
|
||||
dim_ff: 前馈网络隐层维度
|
||||
nhead: 注意力头数
|
||||
num_layers: Transformer 层数
|
||||
map_size: 地图 token 总数(H * W)
|
||||
z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性)
|
||||
struct_dropout: 训练时以此概率将结构标签替换为 null(无条件占位),
|
||||
实现 classifier-free guidance 兼容训练
|
||||
"""
|
||||
super().__init__()
|
||||
self.z_dropout = z_dropout
|
||||
self.struct_dropout_prob = struct_dropout
|
||||
|
||||
|
||||
# Tile 嵌入 + 位置编码
|
||||
self.tile_embedding = nn.Embedding(num_classes, d_model)
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
|
||||
@ -67,10 +42,10 @@ class GinkaMaskGIT(nn.Module):
|
||||
|
||||
# 结构标签嵌入(编码到 d_z 维度)
|
||||
# 注意:结构标签与 VQ 码字语义不同,使用独立投影层避免混用
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
self.struct_proj = nn.Sequential(
|
||||
nn.Linear(d_z, d_model * 2),
|
||||
@ -92,108 +67,75 @@ class GinkaMaskGIT(nn.Module):
|
||||
self,
|
||||
map: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct_cond: torch.Tensor | None = None,
|
||||
dropout_struct: bool = False,
|
||||
struct: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
||||
z: [B, L_total, d_z] VQ-VAE 量化后的离散隐变量;
|
||||
方案 B 中 L_total = L1+L2+L3(三路 z 拼接)
|
||||
struct_cond: [B, 4] 结构标签 LongTensor,顺序为
|
||||
[cond_sym, cond_room, cond_branch, cond_outer];
|
||||
为 None 时等价于全 null(无条件模式)
|
||||
dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成)
|
||||
# map: [B, H * W]
|
||||
# z: [B, L * 3, d_z]
|
||||
# struct: [B, 4]
|
||||
|
||||
Returns:
|
||||
logits: [B, H*W, num_classes]
|
||||
"""
|
||||
B = z.shape[0]
|
||||
|
||||
# z dropout:训练时以一定概率将 z 替换为随机均匀噪声,
|
||||
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
|
||||
if self.training and self.z_dropout > 0:
|
||||
mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout
|
||||
rand_z = torch.randn_like(z)
|
||||
z = torch.where(mask, rand_z, z)
|
||||
|
||||
# 结构标签嵌入
|
||||
# struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引
|
||||
if struct_cond is None or dropout_struct:
|
||||
sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
else:
|
||||
sc = struct_cond.to(z.device)
|
||||
sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3]
|
||||
|
||||
# 训练时对各标签独立做 struct dropout
|
||||
if self.training and self.struct_dropout_prob > 0:
|
||||
def _drop(idx, null_val):
|
||||
drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob
|
||||
return torch.where(drop_mask, torch.full_like(idx, null_val), idx)
|
||||
sym_idx = _drop(sym_idx, SYM_VOCAB - 1)
|
||||
room_idx = _drop(room_idx, ROOM_VOCAB - 1)
|
||||
branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1)
|
||||
outer_idx = _drop(outer_idx, OUTER_VOCAB - 1)
|
||||
sym_idx = struct[:, 0]
|
||||
room_idx = struct[:, 1]
|
||||
branch_idx = struct[:, 2]
|
||||
outer_idx = struct[:, 3]
|
||||
|
||||
# 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾
|
||||
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
|
||||
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
|
||||
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
|
||||
|
||||
# VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接
|
||||
z_mem_vq = self.z_proj(z) # [B, L, d_model]
|
||||
z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model]
|
||||
z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L+4, d_model]
|
||||
z_mem_vq = self.z_proj(z) # [B, L, d_model]
|
||||
z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model]
|
||||
z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L * 3 + 4, d_model]
|
||||
|
||||
# tile embedding + 位置编码
|
||||
x = self.tile_embedding(map) # [B, H*W, d_model]
|
||||
x = x + self.pos_embedding # [B, H*W, d_model]
|
||||
x = self.tile_embedding(map) # [B, H * W, d_model]
|
||||
x = x + self.pos_embedding # [B, H * W, d_model]
|
||||
|
||||
# Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct
|
||||
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
|
||||
x = self.transformer(x, memory=z_mem) # [B, H * W, d_model]
|
||||
|
||||
logits = self.output_fc(x) # [B, H*W, num_classes]
|
||||
logits = self.output_fc(x) # [B, H * W, num_classes]
|
||||
return logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [4, 169]
|
||||
z_input = torch.randn(4, 2, 64).to(device) # [4, 2, 64]
|
||||
struct_input = torch.tensor([
|
||||
[3, 1, 0, 1],
|
||||
[0, 2, 1, 0],
|
||||
[5, 1, 2, 1],
|
||||
[1, 0, 1, 0],
|
||||
], dtype=torch.long).to(device) # [4, 4]
|
||||
|
||||
model = GinkaMaskGIT(
|
||||
num_classes=16,
|
||||
num_classes=7,
|
||||
d_model=192,
|
||||
d_z=64,
|
||||
dim_ff=512,
|
||||
dim_ff=2048,
|
||||
nhead=8,
|
||||
num_layers=4,
|
||||
num_layers=6,
|
||||
map_size=13 * 13,
|
||||
).to(device)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
|
||||
for name, module in model.named_children():
|
||||
n = sum(p.numel() for p in module.parameters())
|
||||
print(f" {name}: {n:,}")
|
||||
print_memory(device, "初始化后")
|
||||
|
||||
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
|
||||
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
|
||||
struct_input = torch.tensor([[3, 1, 0, 1],
|
||||
[0, 2, 1, 0],
|
||||
[7, 3, 3, 2],
|
||||
[1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4]
|
||||
|
||||
model.train()
|
||||
logits = model(map_input, z_input, struct_cond=struct_input)
|
||||
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
|
||||
|
||||
# 无条件模式测试
|
||||
logits_uncond = model(map_input, z_input, struct_cond=None)
|
||||
print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16]
|
||||
start = time.perf_counter()
|
||||
logits = model(map_input, z_input, struct_input)
|
||||
end = time.perf_counter()
|
||||
|
||||
print_memory(device, "前向传播后")
|
||||
|
||||
print(f"推理耗时: {end - start:.4f}s")
|
||||
print(f"输出形状: logits={logits.shape}")
|
||||
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
|
||||
print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
|
||||
print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}")
|
||||
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
|
||||
print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
666
ginka/train_seperated.py
Normal file
666
ginka/train_seperated.py
Normal file
@ -0,0 +1,666 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .vqvae.quantize import VectorQuantizer
|
||||
from .vqvae.model import GinkaVQVAE
|
||||
from .maskGIT.model import GinkaMaskGIT
|
||||
from .dataset import GinkaSeperatedDataset
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
# 三阶段级联地图生成训练脚本
|
||||
#
|
||||
# 整体架构:
|
||||
# VQ-VAE(三组独立编码器 vq1/vq2/vq3)将三阶段地图上下文分别编码为离散潜变量,
|
||||
# 再由共用 VectorQuantizer 统一量化为 z_q;
|
||||
# 三个独立 MaskGIT(mg1/mg2/mg3)分别以 z_q 和 struct_inject 为条件,
|
||||
# 逐阶段迭代解码地图图块序列。
|
||||
#
|
||||
# 三阶段生成目标:
|
||||
# stage1 → floor / wall(地图骨架)
|
||||
# stage2 → door / monster / entrance(功能性实体)
|
||||
# stage3 → resource(资源点)
|
||||
|
||||
# 图块 ID 定义:
|
||||
# 0. 空地 1. 墙壁 2. 门 3. 资源 4. 怪物 5. 入口 6. 掩码(MASK_TOKEN)
|
||||
|
||||
# Shared VQ-VAE hyperparameters.
# The three encoders (vq1/vq2/vq3) use identical settings; each one independently
# encodes the context map of a single stage.
VQ_L = 2  # codes per encoder output (the three outputs are concatenated to L * 3)
VQ_K = 8  # codebook size (number of discrete entries)
VQ_D_Z = 64  # code dimension
VQ_BETA = 0.5  # commit-loss weight (keeps encoder outputs close to the codebook)
VQ_GAMMA = 0.0  # entropy-loss weight (currently disabled)
VQ_LAYERS = 3  # number of VQ-VAE Transformer layers
VQ_DIM_FF = 512  # VQ-VAE feed-forward hidden size
VQ_D_MODEL = 64  # VQ-VAE Transformer model width
VQ_NHEAD = 8  # number of VQ-VAE attention heads

# Stage-1 MaskGIT hyperparameters
STAGE1_MG_DMODEL = 192
STAGE1_MG_NHEAD = 8
STAGE1_MG_NUM_LAYERS = 6
STAGE1_MG_DIM_FF = 1024

# Stage-2 MaskGIT hyperparameters
STAGE2_MG_DMODEL = 192
STAGE2_MG_NHEAD = 8
STAGE2_MG_NUM_LAYERS = 6
STAGE2_MG_DIM_FF = 1024

# Stage-3 MaskGIT hyperparameters
STAGE3_MG_DMODEL = 192
STAGE3_MG_NHEAD = 8
STAGE3_MG_NUM_LAYERS = 6
STAGE3_MG_DIM_FF = 1024

# Per-stage focal-loss weights (relative contribution of each stage to the total loss)
STAGE1_FOCAL_WEIGHT = 1.0
STAGE2_FOCAL_WEIGHT = 1.0
STAGE3_FOCAL_WEIGHT = 1.0

# Per-stage VQ commit-loss weights (currently not applied separately; VQ_BETA is
# used for the combined commit loss instead)
STAGE1_VQ_WEIGHT = 0.5
STAGE2_VQ_WEIGHT = 0.5
STAGE3_VQ_WEIGHT = 0.5

# Global parameters
NUM_CLASSES = 7  # number of tile types (including the mask tile)
MASK_TOKEN = 6  # tile id used as the mask placeholder
MAP_W = 13  # map width in tiles
MAP_H = 13  # map height in tiles
MAP_SIZE = MAP_W * MAP_H  # number of tiles per map
GENERATE_STEP = 18  # MaskGIT sampling steps
SUBSET2_WALL_PROB = 0.7  # probability of masking walls in subset 2
SUBSET_WEIGHTS = (0.5, 0.3, 0.2)  # sampling probability of each subset

MG_Z_DROPOUT = 0.1  # dropout probability for the latent z
MG_STRUCT_DROPOUT = 0.1  # dropout probability for the structure labels

# Loss parameters
FOCAL_GAMMA = 2.0  # focal-loss focusing exponent
# The commit-loss weight VQ_BETA is defined once, with the VQ-VAE hyperparameters;
# the second, identical definition that used to live here was removed.

# Training hyperparameters
BATCH_SIZE = 64  # samples per batch
LR = 1e-4  # initial AdamW learning rate
MIN_LR = 1e-6  # lowest learning rate reached by cosine annealing
WEIGHT_DECAY = 1e-4  # weight-decay coefficient
EPOCHS = 400  # total number of training epochs
CHECKPOINT = 20  # validate and save a checkpoint every this many epochs
||||
|
||||
# Pick the compute device. The second CUDA GPU (index 1) is preferred, but a host
# with a single GPU has no "cuda:1", so fall back to "cuda:0" there instead of
# failing on the first tensor transfer. Without CUDA, use Apple MPS, then the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Progress bars are only drawn when stdout is an interactive terminal.
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
def _str2bool(v: str):
    # argparse `type=` converter for boolean flags.
    # Real booleans pass through unchanged; strings are matched case-insensitively
    # against the accepted spellings. Anything else raises ArgumentTypeError so
    # argparse can report a usage error.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    accepted = (
        (('true', '1', 'yes'), True),
        (('false', '0', 'no'), False),
    )
    for spellings, value in accepted:
        if lowered in spellings:
            return value
    raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
def parse_arguments():
    # Parse the command-line options of the training script and return the
    # resulting argparse namespace (resume, state, train, validate, load_optim).
    parser = argparse.ArgumentParser(description="三阶段级联训练")
    option_specs = (
        ("--resume", dict(type=_str2bool, default=False)),
        ("--state", dict(type=str, default="", help="续训时检查点路径")),
        ("--train", dict(type=str, default="ginka-dataset.json")),
        ("--validate", dict(type=str, default="ginka-eval.json")),
        ("--load_optim", dict(type=_str2bool, default=True)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
||||
|
||||
def build_model(device: torch.device):
    # Build all networks, the optimizer and the LR scheduler used for training.
    # Returns the 9-tuple
    # (vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler).

    # Three VQ-VAE encoders with identical hyperparameters; each one encodes the
    # context map of one stage (encoder_stage1/2/3). Their outputs are later
    # concatenated and passed through the shared quantizer.
    # NOTE(review): d_z is not passed here, so GinkaVQVAE presumably defaults to a
    # code dimension equal to VQ_D_Z - confirm against its constructor.
    vq_kwargs = dict(
        num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
        nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE
    )
    vq1 = GinkaVQVAE(**vq_kwargs).to(device)  # stage-1 context (floor/wall)
    vq2 = GinkaVQVAE(**vq_kwargs).to(device)  # stage-2 context (door/monster/entrance)
    vq3 = GinkaVQVAE(**vq_kwargs).to(device)  # stage-3 context (resource)

    # Three independent MaskGIT decoders; each is conditioned on the full z_q.
    mg1 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z,
        dim_ff=STAGE1_MG_DIM_FF, nhead=STAGE1_MG_NHEAD,
        num_layers=STAGE1_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    mg2 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z,
        dim_ff=STAGE2_MG_DIM_FF, nhead=STAGE2_MG_NHEAD,
        num_layers=STAGE2_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    mg3 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z,
        dim_ff=STAGE3_MG_DIM_FF, nhead=STAGE3_MG_NHEAD,
        num_layers=STAGE3_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)

    # All six networks share one optimizer so they are trained jointly end to end.
    all_params = [
        param
        for net in (vq1, vq2, vq3, mg1, mg2, mg3)
        for param in net.parameters()
    ]
    # Use the WEIGHT_DECAY constant instead of repeating the literal value.
    optimizer = optim.AdamW(all_params, lr=LR, weight_decay=WEIGHT_DECAY)
    # Cosine annealing from LR down to MIN_LR over the whole training run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS, eta_min=MIN_LR
    )

    # Shared VectorQuantizer. Its parameters are deliberately not handed to the
    # optimizer. It is moved to `device` like the other networks so its codebook
    # lives on the same device as the encoder outputs it quantizes.
    quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)

    return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
|
||||
|
||||
def focal_loss(logits, target, gamma=None):
    # Focal loss averaged over every position.
    #   logits: [B, L, C] unnormalised class scores
    #   target: [B, L] class indices
    #   gamma:  focusing exponent; None (the default) uses the module-level
    #           FOCAL_GAMMA, so existing two-argument calls behave as before.
    #           gamma = 0 reduces to plain mean cross-entropy.
    if gamma is None:
        gamma = FOCAL_GAMMA
    # cross_entropy expects the class axis second: [B, C, L].
    ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none')
    # exp(-ce) is the probability the model assigns to the correct class.
    pt = torch.exp(-ce)
    # Down-weight confident predictions so hard positions dominate the loss.
    focal = ((1 - pt) ** gamma) * ce
    return focal.mean()
|
||||
|
||||
def random_struct(device: torch.device) -> torch.Tensor:
    # Draw one random structure-condition vector for free generation.
    # Layout: [cond_sym (0-7), cond_room (0-2), cond_branch (0-2), cond_outer (0-1)].
    # Returns a LongTensor of shape [1, 4] on `device`.
    upper_bounds = (7, 2, 2, 1)  # inclusive maximum of each label, in layout order
    labels = [random.randint(0, upper) for upper in upper_bounds]
    return torch.LongTensor(labels).unsqueeze(0).to(device)
|
||||
|
||||
def maskgit_sample(
    model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
    struct: torch.Tensor, steps: int
) -> np.ndarray:
    # Iteratively decode the masked positions of one map with a MaskGIT model.
    #   model:  called as model(tokens, z, struct), returns per-position logits
    #   inp:    token sequence; only batch element 0 is decoded and returned.
    #           Positions equal to MASK_TOKEN are generated, all others are fixed.
    #   z:      latent condition passed through to the model
    #   struct: structure condition passed through to the model
    #   steps:  maximum number of refinement steps
    # Returns the decoded map as a [MAP_H, MAP_W] numpy array.
    # The caller is expected to wrap the call in torch.no_grad().
    current = inp.clone()

    # The re-mask schedule is relative to the number of positions that were unknown
    # in the input. Scheduling against the whole map would ask top-k for more
    # positions than are unknown whenever the input is partly filled, and the
    # surplus would be taken from the fixed positions, erasing them.
    total_unknown = int((current[0] == MASK_TOKEN).sum().item())

    for step in range(steps):
        known = current[0] != MASK_TOKEN
        if bool(known.all()):
            break

        logits = model(current, z, struct)
        probs = F.softmax(logits, dim=-1)

        sampled = torch.distributions.Categorical(probs).sample()
        confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)

        # Already decided positions keep their token and can never be re-masked:
        # an infinite confidence sorts them after every freshly sampled position.
        sampled[0, known] = current[0, known]
        confidences[0, known] = float("inf")

        # Cosine schedule: the share of unknown positions kept masked shrinks to 0
        # at the last step.
        ratio = math.cos(((step + 1) / steps) * math.pi / 2)
        num_unknown_now = int((~known).sum().item())
        num_to_mask = min(math.floor(ratio * total_unknown), num_unknown_now)

        if num_to_mask > 0:
            # Re-mask the least confident new predictions for the next step.
            _, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
            sampled[0].scatter_(0, mask_indices, MASK_TOKEN)

        current = sampled

    # Fallback: fill any position that is still masked with the argmax prediction.
    still_masked = (current[0] == MASK_TOKEN)
    if still_masked.any():
        logits = model(current, z, struct)
        current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)

    return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
|
||||
def full_generate_random_z(
    input: torch.Tensor,
    struct: torch.Tensor,
    models: list[torch.nn.Module],
    device: torch.device
) -> tuple:
    # Run the three-stage cascade with a latent z sampled from the codebook.
    #   input:  stage-1 token sequence, shape [1, MAP_SIZE]
    #   struct: structure condition for all three decoders
    #   models: the 9-tuple returned by build_model
    # Returns (stage-1 map, stage-1+2 merged map, stage-1+2+3 merged map).
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models

    with torch.no_grad():
        # During training the decoders are conditioned on the concatenation of the
        # three encoder outputs, i.e. VQ_L * 3 codes, so the same number of codes
        # is sampled here.
        # NOTE(review): assumes quantizer.sample(batch, length, device) returns
        # [batch, length, d_z] - confirm against VectorQuantizer.sample.
        z = quantizer.sample(1, VQ_L * 3, device)

    # The cascade itself is identical to the fixed-z variant, so reuse it instead
    # of keeping a second copy of the stage-merging logic.
    return full_generate_specific_z(input, z, struct, models, device)
|
||||
|
||||
def full_generate_specific_z(
    input: torch.Tensor,
    z: torch.Tensor,
    struct: torch.Tensor,
    models: list[torch.nn.Module],
    device: torch.device
) -> tuple:
    # Run the three-stage cascade with a caller-supplied latent z.
    # Stage 1 decodes `input`; each later stage receives the merged result so far
    # with every tile of id 0 turned back into MASK_TOKEN, and its non-zero
    # predictions overwrite the merged map.
    # Returns (stage-1 map, stage-1+2 merged map, stage-1+2+3 merged map).
    _, _, _, mg1, mg2, mg3, _, _, _ = models

    def as_next_input(tile_map):
        # Flatten a decoded map into a [1, MAP_SIZE] token tensor and re-mask the
        # tiles of id 0 so the next stage may fill them.
        tokens = torch.tensor(
            tile_map.flatten(), dtype=torch.long, device=device
        ).reshape(1, MAP_SIZE)
        tokens[tokens == 0] = MASK_TOKEN
        return tokens

    def overlay(base, update):
        # Copy `base` and overwrite it wherever `update` is non-zero.
        merged = base.copy()
        filled = update != 0
        merged[filled] = update[filled]
        return merged

    with torch.no_grad():
        stage1_map = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)

        stage2_map = maskgit_sample(
            mg2, as_next_input(stage1_map), z, struct, GENERATE_STEP
        )
        merged_12 = overlay(stage1_map, stage2_map)

        stage3_map = maskgit_sample(
            mg3, as_next_input(merged_12), z, struct, GENERATE_STEP
        )
        merged_123 = overlay(merged_12, stage3_map)

    return stage1_map, merged_12, merged_123
|
||||
|
||||
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
||||
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
    # Build a 3x3 image grid for the first sample of a validation batch:
    #   row 1: encoder inputs of stages 1-3
    #   row 2: masked decoder inputs of stages 1-3
    #   row 3: argmax predictions of stages 1 and 2, then all three stages merged
    # Returns the grid as a uint8 image with a white background.
    SEP = 3
    TILE_SIZE = 32
    img_h, img_w = MAP_H * TILE_SIZE, MAP_W * TILE_SIZE

    def render(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)

    def first_sample(key):
        return batch[key][0].numpy().reshape(MAP_H, MAP_W)

    def argmax_map(logits):
        return torch.argmax(logits[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)

    stage1_pred = argmax_map(logits1)
    stage2_pred = argmax_map(logits2)
    stage3_pred = argmax_map(logits3)

    # Later stages overwrite earlier ones wherever they predict a non-zero tile.
    merged_pred = stage1_pred.copy()
    for later in (stage2_pred, stage3_pred):
        filled = later != 0
        merged_pred[filled] = later[filled]

    encoder_keys = ("encoder_stage1", "encoder_stage2", "encoder_stage3")
    input_keys = ("input_stage1", "input_stage2", "input_stage3")
    cells = [
        [render(first_sample(key)) for key in encoder_keys],
        [render(first_sample(key)) for key in input_keys],
        [render(stage1_pred), render(stage2_pred), render(merged_pred)],
    ]

    canvas = np.full((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), 255, dtype=np.uint8)
    for row_idx, row_cells in enumerate(cells):
        top = SEP + row_idx * (img_h + SEP)
        for col_idx, cell in enumerate(row_cells):
            left = SEP + col_idx * (img_w + SEP)
            canvas[top:top + img_h, left:left + img_w] = cell
    return canvas
|
||||
|
||||
# 验证可视化 part2:行1=真实地图三阶段,行2=stage1 输入与使用真实 z 自回归生成的各阶段结果
|
||||
def visualize_part2(batch, z_q, models, device, tile_dict):
    # Build a 2-row image grid for the first sample of a validation batch:
    #   row 1: encoder inputs of stages 1-3 (the ground-truth context maps)
    #   row 2: stage-1 input, then the cascade outputs generated with the sample's
    #          own quantized latent z_q (stage 1, stages 1+2, stages 1+2+3)
    # Returns the grid as a uint8 image with a white background.
    SEP = 3
    TILE_SIZE = 32
    img_h, img_w = MAP_H * TILE_SIZE, MAP_W * TILE_SIZE

    def render(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)

    def first_sample(key):
        return batch[key][0].numpy().reshape(MAP_H, MAP_W)

    stage1_input = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
    struct_t = batch["struct_inject"][0:1].to(device)
    generated = full_generate_specific_z(
        stage1_input, z_q[0:1], struct_t, models, device
    )

    encoder_keys = ("encoder_stage1", "encoder_stage2", "encoder_stage3")
    top_row = [render(first_sample(key)) for key in encoder_keys]
    bottom_row = [render(first_sample("input_stage1"))]
    bottom_row.extend(render(stage_map) for stage_map in generated)

    # The canvas is four cells wide; the top row only fills the first three.
    canvas = np.full((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), 255, dtype=np.uint8)
    for row_idx, row_cells in enumerate((top_row, bottom_row)):
        top = SEP + row_idx * (img_h + SEP)
        for col_idx, cell in enumerate(row_cells):
            left = SEP + col_idx * (img_w + SEP)
            canvas[top:top + img_h, left:left + img_w] = cell
    return canvas
|
||||
|
||||
# 验证可视化 part3:2×3 网格;行1=参考输入+相同 struct 随机 z 生成,行2=随机 struct 生成
|
||||
def visualize_part3(batch, models, device, tile_dict):
    # Build a 2x3 image grid from the first sample of a validation batch:
    #   row 1: stage-1 input, then two generations with the sample's own struct
    #          label and a randomly sampled z
    #   row 2: three generations with a random struct label and a random z
    # Returns the grid as a uint8 image with a white background.
    SEP = 3
    TILE_SIZE = 32
    img_h, img_w = MAP_H * TILE_SIZE, MAP_W * TILE_SIZE

    def render(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)

    def generate(struct):
        # Only the fully merged map (third output) is shown.
        return full_generate_random_z(stage1_input, struct, models, device)[2]

    stage1_input = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
    reference_struct = batch["struct_inject"][0:1].to(device)
    stage1_input_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)

    top_row = [render(stage1_input_np)]
    top_row += [render(generate(reference_struct)) for _ in range(2)]
    bottom_row = [render(generate(random_struct(device))) for _ in range(3)]

    canvas = np.full((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), 255, dtype=np.uint8)
    for row_idx, row_cells in enumerate((top_row, bottom_row)):
        top = SEP + row_idx * (img_h + SEP)
        for col_idx, cell in enumerate(row_cells):
            left = SEP + col_idx * (img_w + SEP)
            canvas[top:top + img_h, left:left + img_w] = cell
    return canvas
|
||||
|
||||
# 验证可视化 part4:2×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成
|
||||
def visualize_part4(models, device, tile_dict):
    # Build a 2x3 image grid of free generations that need no dataset sample.
    # A fully masked map is seeded with a few randomly placed walls (2%-6% of the
    # tiles); five maps are then generated from it with random struct and random z.
    #   row 1: the seed, generations 1-2
    #   row 2: generations 3-5
    # Returns the grid as a uint8 image with a white background.
    SEP = 3
    TILE_SIZE = 32
    img_h, img_w = MAP_H * TILE_SIZE, MAP_W * TILE_SIZE

    def render(mat):
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)

    wall_count = random.randint(
        math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06)
    )
    seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    wall_positions = torch.randperm(MAP_SIZE, device=device)[:wall_count]
    seed[0, wall_positions] = 1  # tile id 1 = wall
    seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)

    generated = []
    for _ in range(5):
        final_map = full_generate_random_z(
            seed, random_struct(device), models, device
        )[2]
        generated.append(render(final_map))

    layout = [[render(seed_np), *generated[:2]], generated[2:]]

    canvas = np.full((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), 255, dtype=np.uint8)
    for row_idx, row_cells in enumerate(layout):
        top = SEP + row_idx * (img_h + SEP)
        for col_idx, cell in enumerate(row_cells):
            left = SEP + col_idx * (img_w + SEP)
            canvas[top:top + img_h, left:left + img_w] = cell
    return canvas
|
||||
|
||||
def visualize_validate(
    batch, logits1, logits2, logits3, z_q,
    models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int, batch_idx: int
):
    # Write the three per-batch validation images into result/seperated/e{epoch}:
    #   val{idx}.png  - inputs and teacher-forced predictions (part 1)
    #   full{idx}.png - cascade generation with the sample's own z (part 2)
    #   rand{idx}.png - cascade generation with random z / struct (part 3)
    save_dir = f"result/seperated/e{epoch}"
    os.makedirs(save_dir, exist_ok=True)

    teacher_forced = visualize_part1(batch, logits1, logits2, logits3, tile_dict)
    cv2.imwrite(f"{save_dir}/val{batch_idx}.png", teacher_forced)

    cascade = visualize_part2(batch, z_q, models, device, tile_dict)
    cv2.imwrite(f"{save_dir}/full{batch_idx}.png", cascade)

    randomized = visualize_part3(batch, models, device, tile_dict)
    cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", randomized)
|
||||
|
||||
def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int):
    # Evaluate all three stages on the validation set and write visualisations.
    # Returns (loss1_total, loss2_total, loss3_total, commit_total): the per-batch
    # losses SUMMED over all batches, each a tensor of shape [1]. No averaging is
    # done here, so the caller has to divide by the number of batches.
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    networks = (vq1, vq2, vq3, mg1, mg2, mg3)

    # Evaluation mode for every network while validating.
    for net in networks:
        net.eval()

    loss1_total = torch.Tensor([0]).to(device)
    loss2_total = torch.Tensor([0]).to(device)
    loss3_total = torch.Tensor([0]).to(device)
    commit_total = torch.Tensor([0]).to(device)

    def fetch(batch, key):
        # Move one batch entry to the device and flatten each map to MAP_SIZE tokens.
        return batch[key].to(device).reshape(-1, MAP_SIZE)

    with torch.no_grad():
        progress = tqdm(dataloader, leave=False, desc="Validate Progress", disable=disable_tqdm)
        for idx, batch in enumerate(progress):
            # Masked input, prediction target and encoder input of each stage.
            inp1 = fetch(batch, "input_stage1")
            target1 = fetch(batch, "target_stage1")
            enc1 = fetch(batch, "encoder_stage1")

            inp2 = fetch(batch, "input_stage2")
            target2 = fetch(batch, "target_stage2")
            enc2 = fetch(batch, "encoder_stage2")

            inp3 = fetch(batch, "input_stage3")
            target3 = fetch(batch, "target_stage3")
            enc3 = fetch(batch, "encoder_stage3")

            struct = batch["struct_inject"].to(device)

            # Encode each stage separately, concatenate along the code axis, quantize.
            z_e_all = torch.cat([vq1(enc1), vq2(enc2), vq3(enc3)], dim=1)
            z_q, _, commit_loss = quantizer(z_e_all)

            # Every decoder is conditioned on the full z_q and the struct labels.
            logits1 = mg1(inp1, z_q, struct)
            logits2 = mg2(inp2, z_q, struct)
            logits3 = mg3(inp3, z_q, struct)

            loss1_total += focal_loss(logits1, target1)
            loss2_total += focal_loss(logits2, target2)
            loss3_total += focal_loss(logits3, target3)
            commit_total += commit_loss

            # Three images (val/full/rand) are written for every batch.
            visualize_validate(
                batch, logits1, logits2, logits3, z_q,
                models, device, tile_dict, epoch, idx
            )

        # One extra free-generation image per validation run, independent of any batch.
        save_dir = f"result/seperated/e{epoch}"
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict))

    # Back to training mode before returning to the training loop.
    for net in networks:
        net.train()

    return loss1_total, loss2_total, loss3_total, commit_total
|
||||
|
||||
def train(device: torch.device):
|
||||
args = parse_arguments()
|
||||
|
||||
models = build_model(device)
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
|
||||
start_epoch = 0
|
||||
|
||||
if args.resume:
|
||||
# 从指定检查点恢复:加载所有模型权重及训练状态
|
||||
ckpt = torch.load(args.state, map_location=device)
|
||||
vq1.load_state_dict(ckpt["vq1"])
|
||||
vq2.load_state_dict(ckpt["vq2"])
|
||||
vq3.load_state_dict(ckpt["vq3"])
|
||||
mg1.load_state_dict(ckpt["mg1"])
|
||||
mg2.load_state_dict(ckpt["mg2"])
|
||||
mg3.load_state_dict(ckpt["mg3"])
|
||||
quantizer.load_state_dict(ckpt["quantizer"])
|
||||
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
||||
if args.load_optim and "optimizer" in ckpt:
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
if args.load_optim and "scheduler" in ckpt:
|
||||
scheduler.load_state_dict(ckpt["scheduler"])
|
||||
start_epoch = ckpt.get("epoch", 0) # 从上次保存的 epoch 继续
|
||||
tqdm.write(f"Resumed from epoch {start_epoch}: {args.state}")
|
||||
|
||||
os.makedirs("result/seperated", exist_ok=True)
|
||||
|
||||
dataset = GinkaSeperatedDataset(
|
||||
args.train, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
|
||||
dataset_val = GinkaSeperatedDataset(
|
||||
args.validate, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB
|
||||
)
|
||||
dataloader_val = DataLoader(
|
||||
dataset_val, batch_size=min(BATCH_SIZE, len(dataset_val) // 8), shuffle=True
|
||||
)
|
||||
|
||||
# 预加载图块图像,键为文件名(不含扩展名),用于可视化时将 ID 映射为像素图
|
||||
tile_dict = {}
|
||||
for f in os.listdir("tiles"):
|
||||
name = os.path.splitext(f)[0]
|
||||
img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED)
|
||||
if img is not None:
|
||||
tile_dict[name] = img
|
||||
|
||||
for epoch in tqdm(range(start_epoch, EPOCHS), desc="Seperated Training", disable=disable_tqdm):
|
||||
loss_total = torch.Tensor([0]).to(device)
|
||||
loss1_total = torch.Tensor([0]).to(device)
|
||||
loss2_total = torch.Tensor([0]).to(device)
|
||||
loss3_total = torch.Tensor([0]).to(device)
|
||||
commit_total = torch.Tensor([0]).to(device)
|
||||
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
# 三阶段各自的掩码输入序列、预测目标和编码器上下文
|
||||
inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE)
|
||||
target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE)
|
||||
enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE)
|
||||
|
||||
inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE)
|
||||
target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE)
|
||||
enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE)
|
||||
|
||||
inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
|
||||
# 结构条件向量:[cond_sym, cond_room, cond_branch, cond_outer]
|
||||
struct = batch["struct_inject"].to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# VQ 编码:各阶段编码器分别处理各自上下文切片
|
||||
z_e1 = vq1(enc1) # [B, L, d_z]
|
||||
z_e2 = vq2(enc2)
|
||||
z_e3 = vq3(enc3)
|
||||
|
||||
# 合并三阶段编码后量化
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
|
||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q)
|
||||
logits1 = mg1(inp1, z_q, struct)
|
||||
logits2 = mg2(inp2, z_q, struct)
|
||||
logits3 = mg3(inp3, z_q, struct)
|
||||
|
||||
# 三阶段 Focal Loss + VQ commit loss 加权求和
|
||||
loss1 = focal_loss(logits1, target1)
|
||||
loss2 = focal_loss(logits2, target2)
|
||||
loss3 = focal_loss(logits3, target3)
|
||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1
|
||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2
|
||||
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3
|
||||
commit_weighted = VQ_BETA * commit_loss
|
||||
loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# detach 后累加,避免保留计算图占用显存
|
||||
loss_total += loss.detach()
|
||||
loss1_total += loss1.detach()
|
||||
loss2_total += loss2.detach()
|
||||
loss3_total += loss3.detach()
|
||||
commit_total += commit_loss.detach()
|
||||
|
||||
# 每个 epoch 结束后更新学习率
|
||||
scheduler.step()
|
||||
|
||||
data_length = len(dataloader)
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
||||
f"L1: {loss1_total.item() / data_length:.6f} | "
|
||||
f"L2: {loss2_total.item() / data_length:.6f} | "
|
||||
f"L3: {loss3_total.item() / data_length:.6f} | "
|
||||
f"VQ: {commit_total.item() / data_length:.6f} | "
|
||||
f"LR: {scheduler.get_last_lr()[0]:.6f}"
|
||||
)
|
||||
|
||||
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
||||
if (epoch + 1) % CHECKPOINT == 0:
|
||||
losses = validate(dataloader_val, models, device, tile_dict, epoch + 1)
|
||||
loss1_total, loss2_total, loss3_total, commit_total = losses
|
||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
|
||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
|
||||
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total
|
||||
commit_weighted = VQ_BETA * commit_total
|
||||
loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
||||
|
||||
data_length = len(dataloader_val)
|
||||
tqdm.write(
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
||||
f"L1: {loss1_total.item() / data_length:.6f} | "
|
||||
f"L2: {loss2_total.item() / data_length:.6f} | "
|
||||
f"L3: {loss3_total.item() / data_length:.6f} | "
|
||||
f"VQ: {commit_total.item() / data_length:.6f} | "
|
||||
)
|
||||
|
||||
ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
|
||||
torch.save({
|
||||
"epoch": epoch + 1,
|
||||
"vq1": vq1.state_dict(),
|
||||
"vq2": vq2.state_dict(),
|
||||
"vq3": vq3.state_dict(),
|
||||
"mg1": mg1.state_dict(),
|
||||
"mg2": mg2.state_dict(),
|
||||
"mg3": mg3.state_dict(),
|
||||
"quantizer": quantizer.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
}, ckpt_path)
|
||||
tqdm.write(f"Saved checkpoint: {ckpt_path}")
|
||||
|
||||
# 训练结束后保存最终完整权重(含优化器状态,可用于后续续训或推理)
|
||||
final_path = "result/seperated.pth"
|
||||
torch.save({
|
||||
"epoch": EPOCHS,
|
||||
"vq1": vq1.state_dict(),
|
||||
"vq2": vq2.state_dict(),
|
||||
"vq3": vq3.state_dict(),
|
||||
"mg1": mg1.state_dict(),
|
||||
"mg2": mg2.state_dict(),
|
||||
"mg3": mg3.state_dict(),
|
||||
"quantizer": quantizer.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
}, final_path)
|
||||
tqdm.write(f"Training complete. Final model saved: {final_path}")
|
||||
@ -285,26 +285,26 @@ def make_random_struct_cond():
|
||||
|
||||
def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
根据阶段构造 MaskGIT 的推理初始地图。
|
||||
根据阶段构造 MaskGIT 的推理初始地图(与训练端掩码策略一致)。
|
||||
|
||||
Stage 1: 全 MASK(或保留稀疏 wall 种子)
|
||||
Stage 2: 保留 floor/wall 上下文,其余 → MASK
|
||||
Stage 3: 保留完整上下文(floor/wall/door/monster/entrance),resource → MASK
|
||||
Stage 1: 全 MASK
|
||||
Stage 2: 只保留 wall(1),floor + 功能元素 → MASK
|
||||
Stage 3: 保留 wall(1)/door(2)/monster(4)/entrance(5),floor + resource → MASK
|
||||
"""
|
||||
init = context_map.clone()
|
||||
|
||||
if stage == 1:
|
||||
# 全 MASK(不依赖上下文地图)
|
||||
init = torch.full_like(init, MASK_TOKEN)
|
||||
|
||||
elif stage == 2:
|
||||
# 保留 floor/wall,其余 → MASK
|
||||
mask = ~torch.isin(init, torch.tensor([0, 1], device=init.device))
|
||||
init[mask] = MASK_TOKEN
|
||||
# 只保留 wall,其余全部 → MASK
|
||||
keep = torch.isin(init, torch.tensor([1], device=init.device))
|
||||
init[~keep] = MASK_TOKEN
|
||||
|
||||
else: # stage == 3
|
||||
# 保留非 resource,resource → MASK
|
||||
init[init == 3] = MASK_TOKEN
|
||||
# 保留 wall + 功能元素,floor + resource → MASK
|
||||
keep = torch.isin(init, torch.tensor([1, 2, 4, 5], device=init.device))
|
||||
init[~keep] = MASK_TOKEN
|
||||
|
||||
return init
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .quantize import VectorQuantizer
|
||||
from typing import Tuple
|
||||
|
||||
from ..utils import print_memory
|
||||
|
||||
class _DecodeLayer(nn.Module):
|
||||
"""单个解码层:Pre-LN Cross-Attention + Pre-LN FFN。"""
|
||||
@ -97,62 +98,13 @@ class VQDecodeHead(nn.Module):
|
||||
|
||||
|
||||
class GinkaVQVAE(nn.Module):
|
||||
"""
|
||||
VQ-VAE 风格地图编码器。
|
||||
|
||||
将一张完整的地图([B, H*W] 整数 tile ID 序列)编码为 L 个离散码字,
|
||||
输出 z [B, L, d_z] 作为 MaskGIT 模型的生成条件。
|
||||
|
||||
架构:
|
||||
tile embedding + 位置编码
|
||||
→ L 个可学习 summary token(拼接到序列头部)
|
||||
→ Transformer Encoder(Pre-LN,自注意力)
|
||||
→ 取前 L 个输出
|
||||
→ 线性投影到 d_z
|
||||
→ VectorQuantizer(直通估计 + 熵最大化正则)
|
||||
|
||||
设计约束:
|
||||
- 参数量目标 < 1M
|
||||
- 不含解码器,z 的语义由 MaskGIT 端的交叉熵损失间接约束
|
||||
- z 定位为风格/多样性控制信号,而非结构重建指导
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int = 16,
|
||||
L: int = 2,
|
||||
K: int = 16,
|
||||
d_z: int = 64,
|
||||
d_model: int = 128,
|
||||
nhead: int = 4,
|
||||
num_layers: int = 2,
|
||||
dim_ff: int = 256,
|
||||
map_size: int = 13 * 13,
|
||||
beta: float = 0.25,
|
||||
gamma: float = 0.1,
|
||||
vq_temp: float = 1.0,
|
||||
self, num_classes: int = 16, L: int = 2, K: int = 16, d_z: int = 64, d_model: int = 128,
|
||||
nhead: int = 4, num_layers: int = 2, dim_ff: int = 256, map_size: int = 13 * 13
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: tile 类别数(含 MASK token)
|
||||
L: 码字序列长度,即 z 的序列维度
|
||||
K: codebook 大小(码字总数)
|
||||
d_z: 码字嵌入维度
|
||||
d_model: Transformer 内部维度
|
||||
nhead: 注意力头数
|
||||
num_layers: Transformer 层数
|
||||
dim_ff: 前馈网络隐层维度
|
||||
map_size: 地图 token 总数(H * W)
|
||||
beta: 承诺损失权重
|
||||
gamma: 熵正则损失权重
|
||||
vq_temp: VQ 软分配 softmax 温度
|
||||
"""
|
||||
super().__init__()
|
||||
self.L = L
|
||||
self.K = K
|
||||
self.d_z = d_z
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
|
||||
# Tile 嵌入
|
||||
self.tile_embedding = nn.Embedding(num_classes, d_model)
|
||||
@ -165,13 +117,8 @@ class GinkaVQVAE(nn.Module):
|
||||
|
||||
# Pre-LN Transformer Encoder(训练更稳定)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_ff,
|
||||
batch_first=True,
|
||||
activation='gelu',
|
||||
norm_first=True, # Pre-LN
|
||||
dropout=0.0,
|
||||
d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True,
|
||||
activation='gelu', norm_first=True, dropout=0.1,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
@ -181,127 +128,43 @@ class GinkaVQVAE(nn.Module):
|
||||
nn.LayerNorm(d_z),
|
||||
)
|
||||
|
||||
# 向量量化层
|
||||
self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp)
|
||||
|
||||
def encode(self, map: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将地图编码为量化前的连续向量序列。
|
||||
|
||||
Args:
|
||||
map: [B, H*W] 整数 tile ID
|
||||
|
||||
Returns:
|
||||
z_e: [B, L, d_z] 量化前的编码向量
|
||||
"""
|
||||
B = map.shape[0]
|
||||
|
||||
x = self.tile_embedding(map) # [B, H*W, d_model]
|
||||
x = x + self.pos_embedding # [B, H*W, d_model]
|
||||
|
||||
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
|
||||
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
|
||||
|
||||
x = self.transformer(x) # [B, L+H*W, d_model]
|
||||
|
||||
z_e = self.proj(x[:, :self.L]) # [B, L, d_z]
|
||||
return z_e
|
||||
|
||||
def encode_soft(self, soft_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将软嵌入序列编码为量化前的连续向量序列(用于一致性约束)。
|
||||
|
||||
与 encode() 的区别:输入是已经过 softmax 加权求和得到的连续嵌入矩阵
|
||||
[B, H*W, d_model],而非整数 tile ID。梯度可完整回传到调用方的 logits。
|
||||
|
||||
Args:
|
||||
soft_emb: [B, H*W, d_model] softmax 加权 tile 嵌入(已在 d_model 空间)
|
||||
|
||||
Returns:
|
||||
z_e: [B, L, d_z] 量化前的编码向量
|
||||
"""
|
||||
B = soft_emb.shape[0]
|
||||
|
||||
x = soft_emb + self.pos_embedding # [B, H*W, d_model]
|
||||
|
||||
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
|
||||
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
|
||||
|
||||
x = self.transformer(x) # [B, L+H*W, d_model]
|
||||
|
||||
z_e = self.proj(x[:, :self.L]) # [B, L, d_z]
|
||||
return z_e
|
||||
|
||||
def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
完整前向传播:编码 → 量化 → 计算损失。
|
||||
|
||||
Args:
|
||||
map: [B, H*W] 整数 tile ID(训练时传入完整真实地图)
|
||||
|
||||
Returns:
|
||||
z_q: [B, L, d_z] 量化后的 z(含直通梯度),供 MaskGIT 使用
|
||||
z_e: [B, L, d_z] 量化前的连续编码向量,供一致性约束使用
|
||||
indices: [B, L] 每个位置对应的码字索引
|
||||
vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss
|
||||
commit_loss: scalar
|
||||
entropy_loss: scalar
|
||||
"""
|
||||
z_e = self.encode(map)
|
||||
z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
|
||||
|
||||
vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
|
||||
return z_q, z_e, indices, vq_loss, commit_loss, entropy_loss
|
||||
|
||||
def sample(self, B: int, device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
推理阶段:从 codebook 中随机均匀采样 L 个码字。
|
||||
|
||||
Args:
|
||||
B: batch size
|
||||
device: 目标设备
|
||||
|
||||
Returns:
|
||||
z: [B, L, d_z]
|
||||
"""
|
||||
indices = torch.randint(0, self.K, (B, self.L), device=device)
|
||||
z = self.vq.codebook(indices) # [B, L, d_z]
|
||||
return z
|
||||
# map: [B, H * W]
|
||||
B, L = map.shape
|
||||
|
||||
x = self.tile_embedding(map) # [B, H * W, d_model]
|
||||
x = x + self.pos_embedding # [B, H * W, d_model]
|
||||
|
||||
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
|
||||
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
z_e = self.proj(x[:, :self.L])
|
||||
|
||||
return z_e
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [B=4, 169]
|
||||
|
||||
model = GinkaVQVAE(
|
||||
num_classes=16,
|
||||
L=2,
|
||||
K=16,
|
||||
d_z=64,
|
||||
d_model=128,
|
||||
nhead=4,
|
||||
num_layers=2,
|
||||
dim_ff=256,
|
||||
map_size=13 * 13,
|
||||
num_classes=7, L=2, K=16, d_z=64, d_model=128,
|
||||
nhead=4, num_layers=2, dim_ff=256, map_size=13 * 13,
|
||||
).to(device)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
|
||||
print_memory(device, "初始化后")
|
||||
|
||||
# 分模块参数统计
|
||||
for name, module in model.named_children():
|
||||
n = sum(p.numel() for p in module.parameters())
|
||||
print(f" {name}: {n:,}")
|
||||
start = time.perf_counter()
|
||||
z_e = model(map_input)
|
||||
end = time.perf_counter()
|
||||
|
||||
# 前向传播测试
|
||||
map_input = torch.randint(0, 15, (4, 13 * 13)).to(device) # [B=4, 169]
|
||||
print_memory(device, "前向传播后")
|
||||
|
||||
z_q, z_e, indices, vq_loss, commit_loss, entropy_loss = model(map_input)
|
||||
|
||||
print(f"\nz_q shape: {z_q.shape}") # [4, 2, 64]
|
||||
print(f"z_e shape: {z_e.shape}") # [4, 2, 64]
|
||||
print(f"indices shape:{indices.shape}") # [4, 2]
|
||||
print(f"vq_loss: {vq_loss.item():.4f}")
|
||||
|
||||
# 推理采样测试
|
||||
z_sample = model.sample(B=4, device=device)
|
||||
print(f"sample shape: {z_sample.shape}") # [4, 2, 64]
|
||||
print(f"推理耗时: {end - start:.4f}s")
|
||||
print(f"输出形状: z_e={z_e.shape}")
|
||||
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
|
||||
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
|
||||
print(f"Projection parameters: {sum(p.numel() for p in model.proj.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
@ -3,20 +3,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
向量量化层(Vector Quantization)。
|
||||
|
||||
将连续的编码向量序列映射到离散的 codebook 码字索引,
|
||||
并通过直通估计(Straight-Through Estimator)保持梯度流。
|
||||
|
||||
均匀分布正则化采用软分配熵最大化方案:
|
||||
通过对距离做 softmax 得到软分配概率,计算平均码字使用率的熵,
|
||||
最小化负熵以鼓励所有码字被均等使用。
|
||||
"""
|
||||
|
||||
def __init__(self, K: int, d_z: int, temp: float = 1.0):
|
||||
def __init__(self, K: int, d_z: int):
|
||||
"""
|
||||
Args:
|
||||
K: codebook 大小(码字数量)
|
||||
@ -26,12 +14,12 @@ class VectorQuantizer(nn.Module):
|
||||
super().__init__()
|
||||
self.K = K
|
||||
self.d_z = d_z
|
||||
self.temp = temp
|
||||
|
||||
self.codebook = nn.Embedding(K, d_z)
|
||||
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
||||
|
||||
def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# z_e: [B, L * 3, d_z]
|
||||
"""
|
||||
Args:
|
||||
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
||||
@ -40,28 +28,25 @@ class VectorQuantizer(nn.Module):
|
||||
z_q_st: [B, L, d_z] 量化后向量(直通梯度)
|
||||
indices: [B, L] 每个位置对应的码字索引
|
||||
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
|
||||
entropy_loss: scalar 负熵损失(最小化 = 最大化码字使用均匀度)
|
||||
"""
|
||||
B, L, d_z = z_e.shape
|
||||
|
||||
# 展平到 [B*L, d_z]
|
||||
z_flat = z_e.reshape(B * L, d_z)
|
||||
z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z]
|
||||
|
||||
codebook_w = self.codebook.weight # [K, d_z]
|
||||
|
||||
# 计算 L2 距离:||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e · e_k
|
||||
# distances: [B*L, K]
|
||||
distances = (
|
||||
(z_flat ** 2).sum(dim=1, keepdim=True) # [B*L, 1]
|
||||
+ (codebook_w ** 2).sum(dim=1) # [K]
|
||||
- 2.0 * z_flat @ codebook_w.t() # [B*L, K]
|
||||
)
|
||||
ze_square = torch.sum(z_flat ** 2, dim=1, keepdim=True)
|
||||
ek_square = torch.sum(codebook_w ** 2, dim=1)
|
||||
mul = z_flat @ codebook_w.t()
|
||||
distances = ze_square + ek_square - 2 * mul
|
||||
|
||||
# Hard assignment:取最近码字索引
|
||||
indices = distances.argmin(dim=1) # [B*L]
|
||||
indices = distances.argmin(dim=1) # [B*L]
|
||||
|
||||
# 量化向量
|
||||
z_q_flat = self.codebook(indices) # [B*L, d_z]
|
||||
z_q_flat = self.codebook(indices) # [B*L, d_z]
|
||||
z_q = z_q_flat.reshape(B, L, d_z)
|
||||
|
||||
# 直通估计:前向传 z_q,反向传 z_e 的梯度
|
||||
@ -70,12 +55,14 @@ class VectorQuantizer(nn.Module):
|
||||
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
|
||||
commit_loss = F.mse_loss(z_e, z_q.detach())
|
||||
|
||||
# 熵最大化正则:通过软分配计算平均码字使用率,最小化负熵
|
||||
# soft_assign: [B*L, K],对距离做 softmax(距离越小,概率越大)
|
||||
soft_assign = F.softmax(-distances / self.temp, dim=1)
|
||||
avg_assign = soft_assign.mean(dim=0) # [K],平均码字使用率
|
||||
# entropy_loss = -H(p) = sum(p * log(p)),最小化即最大化熵
|
||||
entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum()
|
||||
|
||||
indices = indices.reshape(B, L)
|
||||
return z_q_st, indices, commit_loss, entropy_loss
|
||||
return z_q_st, indices, commit_loss
|
||||
|
||||
def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
|
||||
indices1 = torch.randint(0, self.K, (B, L), device=device)
|
||||
indices2 = torch.randint(0, self.K, (B, L), device=device)
|
||||
indices3 = torch.randint(0, self.K, (B, L), device=device)
|
||||
z1 = self.codebook(indices1)
|
||||
z2 = self.codebook(indices2)
|
||||
z3 = self.codebook(indices3)
|
||||
return torch.cat([z1, z2, z3], dim=1)
|
||||
|
||||
76
prompt.md
Normal file
76
prompt.md
Normal file
@ -0,0 +1,76 @@
|
||||
# Ginka 地图生成器 - Copilot 指引
|
||||
|
||||
## 项目概述
|
||||
|
||||
本项目是一个基于深度学习的二维网格状地图生成模型,用于生成魔塔(Magic Tower)类网页游戏地图。
|
||||
|
||||
- **模型结构**:VQ-VAE 风格编码器 + MaskGIT 解码器
|
||||
- VQ-VAE 编码器将完整地图压缩为离散隐变量 z(从 codebook 查得)
|
||||
- MaskGIT 以 z 为条件,通过迭代掩码预测生成地图
|
||||
- 推理时直接随机采样 z,无需用户输入
|
||||
- **地图规格**:13×13 格子,7 类图块
|
||||
- **目录结构**
|
||||
- `ginka/` — 模型定义与训练脚本(Python)
|
||||
- `data/` — 数据预处理(TypeScript,因游戏是网页游戏)
|
||||
- `docs/` — 设计文档
|
||||
- `shared/` — 可视化等共享工具
|
||||
|
||||
## 重要约束
|
||||
|
||||
### 训练
|
||||
|
||||
- **不要在当前设备上运行训练**,训练在其他设备上进行
|
||||
- 可以运行小规模验证、推理或单步测试,但不要触发完整训练流程
|
||||
|
||||
### 代码风格
|
||||
|
||||
#### Python
|
||||
|
||||
- 不使用三引号注释(`"""..."""`),一律改用 `#` 注释。对于行后的注释,注释的 # 应该在语句后面空一格的地方开始,不要多空,也不要少空,例如 `a = b # 注释内容`。
|
||||
- 不出现连续空行(即空行仅允许连续出现一行);不出现连续空格,例如下面的例子就不允许出现:
|
||||
|
||||
```python
|
||||
a      = func1()
|
||||
abcdef = func2()
|
||||
```
|
||||
|
||||
应改为:
|
||||
|
||||
```python
|
||||
a = func1()
|
||||
abcdef = func2()
|
||||
```
|
||||
|
||||
- 遵循类似 Prettier 的风格,不出现尾逗号。
|
||||
- 不进行无意义的对齐,例如函数参数定义应该遵循这种风格,到达 80 字符左右换行:
|
||||
|
||||
```python
|
||||
def func(
|
||||
param1: type, param2: type, param3: type,
|
||||
param4: type, param5: type
|
||||
)
|
||||
```
|
||||
|
||||
而不是:
|
||||
|
||||
```python
|
||||
def func(param1: type, param2: type, param3: type,
|
||||
param4: type, param5: type)
|
||||
```
|
||||
|
||||
- 不使用下划线开头命名任何内容,包括私有方法。
|
||||
- 不写静态方法。
|
||||
- 仅允许在文件开头引入内容,不允许其他地方出现任何 `import`。
|
||||
- 文件尾添加空行。
|
||||
- 不允许出现连等赋值(例如 `a = b = c`)。
|
||||
- 不允许使用元组语法同时给多个量分别赋值,比如 `a, b, c = d, e, f` 不允许出现,仅允许 `a, b, c = func()` 这种一赋多的场景。
|
||||
- 不要在文件开头添加注释,开头第一句应该是 `import`。文件注释应该在 `import` 之后写。
|
||||
|
||||
#### TypeScript
|
||||
|
||||
遵循 Prettier 风格。
|
||||
|
||||
### 验证与可视化
|
||||
|
||||
- 编写验证代码时,优先输出可视化结果(图片文件),使用 `shared/image.py` 中的工具
|
||||
- 验证阶段应对不同条件(不同 z 采样)分别生成图片,便于直观对比模型效果
|
||||
Loading…
Reference in New Issue
Block a user