ginka-generator/docs/three-stage-generation-design.md
unanmed 3676958781 feat: 分三阶段训练
Co-authored-by: Copilot <copilot@github.com>
2026-05-07 20:59:22 +08:00

22 KiB
Raw Blame History

三阶段级联地图生成设计文档

背景与问题诊断

当前问题:墙壁过度生成

现有单模型方案VQ-VAE + MaskGIT 联合训练)在实践中呈现出模型严重偏向生成墙壁的现象,具体表现为:

  • 生成地图中墙壁密度显著偏高,可通行空间不足;
  • 门、怪物、入口等功能性元素稀少甚至缺失;
  • 资源类 tile 几乎不出现。

根本原因分析

从 token 分布角度来看13×13 = 169 个格子中,各类 tile 的分布严重不均:

tile 类型 典型比例 训练信号
墙壁wall 40%~60% 强、密集
空地floor 20%~40% 强、密集
door 1%~5% 极弱
怪物 5%~15%
入口 1%~3% 极弱
资源 5%~15%

单模型将以上所有类别的 交叉熵损失混合在一起优化。由于墙壁和空地占据绝大多数 token模型可以通过反复预测墙壁/空地获得低训练损失,而无需真正学习如何放置稀有类别。

这是类别不均衡问题的结构性体现:即使引入 Focal Loss调整 loss 权重也难以从根本上解决三类任务学习难度不匹配的问题——因为结构骨架floor/wall是其他所有元素放置的前提混合训练会导致模型困于局部最优解"把所有位置都预测为墙壁")。


核心思路:三阶段级联生成

将单次全类别生成拆分为三个独立的 MaskGIT 阶段,每阶段只负责一组语义相近、结构约束相似的 tile 类别。后续阶段以前序阶段的输出作为已知上下文。

┌──────────────────────────────────────────────────────────────────────┐
│  阶段一:结构骨架生成                                                  │
│                                                                      │
│  全 MASK  ──► [Stage1-MaskGIT + z₁]  ──► floor/wall 地图            │
└──────────────────────────────────────────────────────────────────────┘
                               │
                               ▼ 已知 floor/wall 上下文
┌──────────────────────────────────────────────────────────────────────┐
│  阶段二:功能元素放置                                                  │
│                                                                      │
│  floor/wall + MASK  ──► [Stage2-MaskGIT + z₂]  ──► door/monster/入口│
└──────────────────────────────────────────────────────────────────────┘
                               │
                               ▼ 已知 floor/wall/door/monster/入口 上下文
┌──────────────────────────────────────────────────────────────────────┐
│  阶段三:资源放置                                                      │
│                                                                      │
│  完整上下文 + MASK  ──► [Stage3-MaskGIT + z₃]  ──► resource          │
└──────────────────────────────────────────────────────────────────────┘

各阶段职责

阶段 负责类别 tile 数 类别数(含 MASK 结构约束强度
阶段一 floor(0)、wall(1) 2 3 极强
阶段二 door(2)、monster(4)、entrance(5) 3 6
阶段三 resource(3) 1 3

注:各阶段模型的词表不需要只含本阶段 tile完整词表7 类)可以保留不变,只是被 MASK 的位置计算 Loss 的位置会有所不同(见训练策略)。


每阶段 tile 映射与 MASK 策略

阶段一

输入:所有位置均填充 MASK tokentile=6)。

目标:预测 floor(0) 和 wall(1);原始地图中所有非 floor/wall 的 tile 在训练目标中被重映射为 floor(0),因为阶段一模型不关心功能性元素的具体种类。

STAGE1_REMAP = {
    0: 0,   # floor → floor
    1: 1,   # wall  → wall
    2: 0,   # door  → floor视作空地让阶段二填充
    3: 0,   # resource → floor
    4: 0,   # monster → floor
    5: 0,   # entrance → floor
    6: 6,   # MASK  → MASK不参与损失计算
}

Loss 计算范围:所有非 MASK 位置(即全局所有位置,因为输入均为 MASK

阶段二

输入:将阶段一输出的 floor/wall 地图作为固定上下文,原始地图中属于 door/monster/entrance 的位置替换为 MASK token其余floor/wall位置保持不变。

STAGE2_MASK_IDS = {2, 4, 5}  # 需要在阶段二中被预测的 tile ID

训练时输入构造(使用 GT 地图中的 floor/wall 作为上下文,不使用阶段一的实际输出):

def make_stage2_input(gt_map: np.ndarray) -> np.ndarray:
    """
    gt_map: [H*W] 整数数组,包含完整原始地图
    返回: stage2 输入地图door/monster/entrance 位置替换为 MASK
    """
    inp = gt_map.copy()
    # 先将所有资源归一化为 floor阶段二不负责资源
    inp[np.isin(inp, [3])] = 0
    # 将 door/monster/entrance 位置 MASK 掉
    inp[np.isin(inp, [2, 4, 5])] = 6  # MASK
    return inp

目标:预测 door(2)、monster(4)、entrance(5) 位置的类别(以及它们所在位置原本是否是 floor

Loss 计算范围:仅对输入为 MASK即原本是 door/monster/entrance 的位置计算损失floor/wall 位置不参与损失(它们已经确定)。

阶段三

输入将阶段二输出的完整功能性地图floor/wall/door/monster/entrance作为固定上下文原始地图中资源tile=3位置替换为 MASK token。

def make_stage3_input(gt_map: np.ndarray) -> np.ndarray:
    """
    gt_map: [H*W] 整数数组,包含完整原始地图
    返回: stage3 输入地图resource 位置替换为 MASK
    """
    inp = gt_map.copy()
    inp[inp == 3] = 6  # resource → MASK
    return inp

Loss 计算范围:仅对输入为 MASK即原本是 resource 的位置)计算损失。


模型架构

共享基础架构

三个阶段均采用与现有方案相同的 MaskGIT + VQ-VAE 架构,不需要设计新的模型结构。核心差异在于:

  1. 输入词表:与现有方案一致,均使用 7 类词表(NUM_CLASSES=7MASK_TOKEN=6
  2. z 来源:每阶段使用来自对应 VQ 通道的隐变量;
  3. Loss 掩码:只对该阶段负责的位置计算 CE Loss。

VQ-VAE z 的阶段分配

现有 VQ-VAE 已采用三通道设计CH1/CH2/CH3与三个生成阶段自然对应

通道 编码目标 供应阶段 描述
CH1 floor/wall 骨架 阶段一 z₁控制墙壁结构多样性
CH2 door/monster/入口 阶段二 z₂控制功能元素布局多样性
CH3 resource 分布 阶段三 z₃控制资源密度与位置多样性
真实地图 ──► CH1 encoder ──► z₁  ──► Stage1 MaskGIT
         ──► CH2 encoder ──► z₂  ──► Stage2 MaskGIT
         ──► CH3 encoder ──► z₃  ──► Stage3 MaskGIT

通道专属编码:各通道编码器在编码前只"看"与自身相关的 tile其余 tile 视作 floor(0)

def ch1_mask(gt_map):
    """只保留 floor/wall其余置 0"""
    m = gt_map.copy()
    m[~np.isin(m, [0, 1])] = 0
    return m

def ch2_mask(gt_map):
    """只保留 door/monster/entrance其余置 0"""
    m = gt_map.copy()
    m[~np.isin(m, [2, 4, 5])] = 0
    return m

def ch3_mask(gt_map):
    """只保留 resource其余置 0"""
    m = gt_map.copy()
    m[m != 3] = 0
    return m

Loss 计算掩码实现

训练时,每阶段额外接收一个 loss_mask: torch.BoolTensor [B, H*W],指示哪些位置需要计算损失:

# 阶段一:所有位置(因为输入全为 MASK
loss_mask_s1 = torch.ones(B, MAP_SIZE, dtype=torch.bool)

# 阶段二:只有原本是 door/monster/entrance 的位置
loss_mask_s2 = torch.isin(raw_map, torch.tensor([2, 4, 5]))

# 阶段三:只有原本是 resource 的位置
loss_mask_s3 = (raw_map == 3)

Focal Loss 修改为只对 loss_mask 为 True 的位置求和后做归一化:

def stage_focal_loss(logits, targets, loss_mask, gamma=2.0):
    # logits: [B, C, H*W], targets: [B, H*W], loss_mask: [B, H*W]
    per_token_loss = focal_loss(logits, targets, gamma, reduction='none')  # [B, H*W]
    masked_loss = per_token_loss[loss_mask]
    return masked_loss.mean() if masked_loss.numel() > 0 else per_token_loss.mean()

训练策略

训练方式:顺序训练(推荐)

三个阶段依次训练,后续阶段训练时使用 GT 地图作为前序阶段的"完美输出"teacher forcing而非使用前序阶段模型的实际推理结果。

训练阶段一:
    data:     (全MASK输入, stage1目标)
    loss:     focal loss on all positions

训练阶段二:
    data:     (floor/wall上下文 + MASK输入, stage2目标)
    loss:     focal loss only on door/monster/entrance positions

训练阶段三:
    data:     (floor/wall/door/monster/入口上下文 + MASK输入, stage3目标)
    loss:     focal loss only on resource positions

使用 GT 而非前序模型输出的理由

  • 避免误差级联(前序模型若生成错误的骨架,后续模型的训练将在错误分布上进行);
  • 各阶段训练更稳定,收敛更快;
  • 阶段之间解耦,方便单独迭代和调试。

各阶段训练子集划分

每个阶段均沿用现有 A/B/C/D 子集划分逻辑,但 MASK 策略应用在对应阶段的目标 tile 上:

子集 阶段一 阶段二 阶段三
A 随机遮盖部分 floor/wall 随机遮盖部分 door/monster/入口 随机遮盖部分 resource
B 保留全部 wallMASK floor 给定全部 wallMASK 功能元素 给定全部骨架MASK 部分资源
C 随机保留部分 wallMASK 其余 同 B 同 B
D 保留 wall+entranceMASK 其余 给定 wall+entranceMASK 门/怪 同 B

各阶段专属损失

每阶段的总损失:

\mathcal{L}^{(s)} = \mathcal{L}_{CE}^{(s)} + \beta \cdot \mathcal{L}_{commit}^{(s)} + \gamma \cdot \mathcal{L}_{uniform}^{(s)}

其中 s \in \{1, 2, 3\} 表示阶段编号,\mathcal{L}_{CE}^{(s)} 只在该阶段负责的 tile 位置上计算。


推理流程

完整推理管线

1. 随机采样 z₁, z₂, z₃各自从对应 codebook 均匀采样 L 个 index

2. 阶段一推理(结构骨架)
   初始状态: 全部 169 个位置 = MASK token
   迭代 MaskGIT 解码cosine schedule约 18 步):
     输入: MASK地图 + z₁
     输出: 逐步填充 floor/wall直到无 MASK 位置
   结果: floor/wall 骨架地图 M₁

3. 阶段二推理(功能元素放置)
   初始状态: 继承 M₁在 floor 位置随机选取候选位置置为 MASK
   ─── 或 ───
   初始状态: 继承 M₁所有 floor 位置均置为 MASK让模型决定放置密度
   迭代 MaskGIT 解码:
     输入: 含已知 wall 的掩码地图 + z₂
     约束: wall 位置不参与 unmask保持不变
     输出: 逐步填充 door/monster/entrance
   结果: 含功能元素的地图 M₂

4. 阶段三推理(资源放置)
   初始状态: 继承 M₂所有 floor 位置(未被阶段二填充的)置为 MASK
   迭代 MaskGIT 解码:
     输入: 含已知 wall/door/monster/入口的掩码地图 + z₃
     约束: 非 floor 位置保持不变
     输出: 逐步填充 resource或保持 floor
   结果: 完整地图 M₃

阶段二初始 MASK 策略

阶段二的初始状态有两种选择:

策略 描述 适用场景
全 floor MASK 所有 floor 位置均置为 MASK让模型自主决定放置密度 完全随机生成
候选位置 MASK 只 MASK 用户指定或随机抽取的少数位置 用户指定部分位置

对于完全随机生成场景,推荐使用全 floor MASK——此时 z₂ 决定功能元素的总体风格(密集/稀疏/集中/分散),模型负责在此约束下寻找合理位置。

用户交互场景

场景 阶段一输入 阶段二输入 阶段三输入
完全随机 全 MASK 继承 M₁ 继承 M₂
用户手绘墙壁 已知 wall + MASK 继承 M₁固定 wall 继承 M₂
用户指定入口 已知 entrance + MASK 继承 M₁固定 entrance 继承 M₂
用户手绘墙+指定入口 已知 wall/entrance 继承 M₁ 继承 M₂

实现方案

方案一:三个独立 GinkaMaskGIT 实例(推荐)

每个阶段分别实例化一个 GinkaMaskGIT,共用同一个 GinkaVQVAE(其中已含三通道编码器)。各阶段模型独立训练、独立存储。

# 模型结构
vqvae   = GinkaVQVAE(...)         # 三通道共享编码器
stage1  = GinkaMaskGIT(...)       # 结构骨架
stage2  = GinkaMaskGIT(...)       # 功能元素
stage3  = GinkaMaskGIT(...)       # 资源

优点

  • 各阶段模型完全解耦,可独立调整超参、单独重训;
  • 模型大小可以针对任务难度调整(阶段一可以更大,阶段三可以更小);
  • 便于调试和增量式开发。

缺点

  • 三个模型分别保存,推理时需依次加载;
  • 总参数量约为原方案的 3 倍(但可以通过缩小各阶段模型来对冲)。

方案二:单模型 + 阶段 Embedding备选

复用同一个 GinkaMaskGIT,添加阶段 embedding类似 BERT 的 segment embedding

self.stage_embedding = nn.Embedding(3, d_model)  # 三个阶段

def forward(self, map, z, stage: int):
    x = self.tile_embedding(map) + self.pos_embedding
    x = x + self.stage_embedding(torch.tensor(stage))  # 阶段条件
    ...

优点:参数量与原方案相同,推理时只需加载一个模型。
缺点:三阶段共享所有权重,可能导致阶段间干扰;阶段一的结构任务与阶段三的稀疏资源任务表示空间差异大,单一模型难以同时擅长。

结论推荐方案一,尤其是在当前阶段(验证多阶段框架可行性),可先分别训练三个轻量模型进行快速验证。

各阶段模型规模建议

阶段 任务难度 建议 d_model 建议 num_layers 参数量估算
阶段一 高(结构) 256 6 ~4M
阶段二 中(功能) 192 4 ~2M
阶段三 低(稀疏) 128 3 ~0.8M

Dataset 修改方案

新增 GinkaStageDataset

需要扩展 dataset.py,增加针对三阶段的 Dataset 类,或在 GinkaVQDataset 中添加 stage 参数:

class GinkaStageDataset(Dataset):
    """
    三阶段级联训练专用 Dataset。

    返回 dict:
      raw_map:      LongTensor [H*W]  完整原始地图(供 VQ-VAE 编码)
      stage_input:  LongTensor [H*W]  当前阶段 MaskGIT 输入(含上下文 + MASK
      target_map:   LongTensor [H*W]  CE loss ground truth等同 raw_map
      loss_mask:    BoolTensor [H*W]  只对 True 位置计算损失
      subset:       str               子集标识
    """

    STAGE1_TARGETS = {0, 1}       # floor, wall
    STAGE2_TARGETS = {2, 4, 5}    # door, monster, entrance
    STAGE3_TARGETS = {3}          # resource

数据构造函数

def make_stage1_sample(gt_map: np.ndarray, mask_id: int = 6):
    """阶段一:全 MASK 输入,目标是 floor/wall其余归一为 floor"""
    stage_input = np.full_like(gt_map, mask_id)
    target = gt_map.copy()
    target[~np.isin(target, [0, 1])] = 0  # 非结构 tile → floor
    loss_mask = np.ones_like(gt_map, dtype=bool)
    return stage_input, target, loss_mask

def make_stage2_sample(gt_map: np.ndarray, mask_id: int = 6):
    """阶段二floor/wall 为上下文door/monster/entrance 位置 MASK"""
    stage_input = gt_map.copy()
    stage_input[stage_input == 3] = 0      # 资源 → floor阶段二不负责
    target_ids = np.isin(stage_input, [2, 4, 5])
    stage_input[target_ids] = mask_id      # 功能元素 → MASK
    target = gt_map.copy()
    target[target == 3] = 0               # target 中资源也视为 floor
    loss_mask = (gt_map != 0) & (gt_map != 1) & (gt_map != 3)  # 只计算功能元素位置
    return stage_input, target, loss_mask

def make_stage3_sample(gt_map: np.ndarray, mask_id: int = 6):
    """阶段三:全上下文保留,只 MASK 资源位置"""
    stage_input = gt_map.copy()
    stage_input[stage_input == 3] = mask_id  # 资源 → MASK
    target = gt_map.copy()
    loss_mask = (gt_map == 3)               # 只计算资源位置
    return stage_input, target, loss_mask

训练脚本设计

新增 ginka/train_stage.py

用法示例:
    python -m ginka.train_stage --stage 1
    python -m ginka.train_stage --stage 2
    python -m ginka.train_stage --stage 3 --resume True --state result/stage3/stage3-10.pth

各阶段检查点分别存储到:

  • result/stage1/stage1-{epoch}.pth
  • result/stage2/stage2-{epoch}.pth
  • result/stage3/stage3-{epoch}.pth

阶段二/三的 VQ 编码器冻结策略

阶段二训练时CH1 编码器(已在阶段一训练中充分收敛)可选择冻结,只更新 CH2 编码器和 Stage2 MaskGIT

if args.stage == 2:
    for p in vqvae.encoder_ch1.parameters():
        p.requires_grad_(False)  # 冻结 CH1

类似地,阶段三训练时可冻结 CH1 和 CH2 编码器。


与现有方案的对比

维度 现有单模型方案 三阶段级联方案
墙壁过度生成 存在,难以从根本解决 阶段一单独训练骨架Loss 聚焦 floor/wall
训练信号均衡性 墙壁主导,稀有类欠拟合 各阶段 Loss 只计算本阶段 tile信号均衡
模型可调试性 单一模型,各类别相互干扰 各阶段独立,可单独分析每阶段表现
推理速度 1 次完整 MaskGIT 解码 3 次级联解码(总步数约为原来 3 倍)
误差累积 存在,前序阶段错误会传播到后续阶段
用户可控性 较难(条件混合) 好(可在任意阶段注入用户约束)
参数量 ~4M 约 ~7M可通过缩减各阶段规模控制

预期收益

  1. 解决墙壁过度生成:阶段一专门针对 floor/wall 训练,类别分布从 7 类压缩到 3 类Loss 完全聚焦,模型不再有逃避路径;
  2. 功能元素召回率提升:阶段二以已知骨架为前提生成 door/monster/entrance训练信号不再被墙壁噪声稀释
  3. 资源分布更合理:阶段三在完整上下文下放置资源,能感知到门/怪物位置,避免资源与关键功能元素重叠;
  4. 可交互性增强:用户可在任意阶段注入约束(固定某些 tile天然支持层次化编辑。

风险与应对

风险 描述 应对策略
误差累积 阶段一骨架不准确会导致阶段二/三布局失真 阶段一优先保证质量,推理时对骨架做后处理校验
推理耗时增加 三次 MaskGIT 解码约 3 倍耗时 减少 MaskGIT 迭代步数(阶段二/三任务更简单,步数可减半)
阶段二稀疏性问题 169 格子中功能元素极少Loss 计算覆盖率低 适当提高 stage 2 的 loss_mask 覆盖(周边 floor 也计入)
Codebook 对齐 三通道 VQ-VAE 的 z 在分阶段训练时各自优化 联合训练阶段一+VQ-CH1联合训练阶段二+VQ-CH2以此类推

实施顺序

  • 新增 GinkaStageDataset,实现三阶段数据构造函数
  • 新增 ginka/train_stage.py,支持 --stage 1/2/3 参数
  • 阶段一训练:仅 floor/wall验证骨架生成质量关键里程碑
  • 阶段二训练:以 GT floor/wall 为上下文,验证功能元素召回率
  • 阶段三训练:以 GT 完整地图为上下文,验证资源放置合理性
  • 实现三阶段级联推理脚本,接入现有可视化工具
  • 对比实验:三阶段方案 vs 现有单模型方案在墙壁密度、功能元素召回率、资源分布等指标上的差异