feat: 分三阶段训练

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-05-07 20:59:22 +08:00
parent 3b86a78e0b
commit 3676958781
4 changed files with 1538 additions and 6 deletions

View File

@ -0,0 +1,499 @@
# 三阶段级联地图生成设计文档
## 背景与问题诊断
### 当前问题:墙壁过度生成
现有单模型方案VQ-VAE + MaskGIT 联合训练)在实践中呈现出**模型严重偏向生成墙壁**的现象,具体表现为:
- 生成地图中墙壁密度显著偏高,可通行空间不足;
- 门、怪物、入口等功能性元素稀少甚至缺失;
- 资源类 tile 几乎不出现。
### 根本原因分析
从 token 分布角度来看13×13 = 169 个格子中,各类 tile 的分布严重不均:
| tile 类型 | 典型比例 | 训练信号 |
| ------------- | -------- | -------- |
| 墙壁wall | 40%~60% | 强、密集 |
| 空地floor | 20%~40% | 强、密集 |
| 门door | 1%~5% | 极弱 |
| 怪物 | 5%~15% | 弱 |
| 入口 | 1%~3% | 极弱 |
| 资源 | 5%~15% | 弱 |
单模型将以上所有类别的 **交叉熵损失**混合在一起优化。由于墙壁和空地占据绝大多数 token模型可以通过反复预测墙壁/空地获得低训练损失,而无需真正学习如何放置稀有类别。
这是**类别不均衡问题的结构性体现**:即使引入 Focal Loss调整 loss 权重也难以从根本上解决三类任务学习难度不匹配的问题——因为结构骨架floor/wall是其他所有元素放置的前提混合训练会导致模型困于局部最优解"把所有位置都预测为墙壁")。
---
## 核心思路:三阶段级联生成
将单次全类别生成拆分为**三个独立的 MaskGIT 阶段**,每阶段只负责一组语义相近、结构约束相似的 tile 类别。后续阶段以前序阶段的输出作为已知上下文。
```
┌──────────────────────────────────────────────────────────────────────┐
│ 阶段一:结构骨架生成 │
│ │
│ 全 MASK ──► [Stage1-MaskGIT + z₁] ──► floor/wall 地图 │
└──────────────────────────────────────────────────────────────────────┘
▼ 已知 floor/wall 上下文
┌──────────────────────────────────────────────────────────────────────┐
│ 阶段二:功能元素放置 │
│ │
│ floor/wall + MASK ──► [Stage2-MaskGIT + z₂] ──► door/monster/入口│
└──────────────────────────────────────────────────────────────────────┘
▼ 已知 floor/wall/door/monster/入口 上下文
┌──────────────────────────────────────────────────────────────────────┐
│ 阶段三:资源放置 │
│ │
│ 完整上下文 + MASK ──► [Stage3-MaskGIT + z₃] ──► resource │
└──────────────────────────────────────────────────────────────────────┘
```
### 各阶段职责
| 阶段 | 负责类别 | tile 数 | 类别数(含 MASK | 结构约束强度 |
| ------ | -------------------------------- | ------- | ----------------- | ------------ |
| 阶段一 | floor(0)、wall(1) | 2 | 3 | 极强 |
| 阶段二 | door(2)、monster(4)、entrance(5) | 3 | 6 | 强 |
| 阶段三 | resource(3) | 1 | 3 | 弱 |
注:各阶段模型的词表不需要只含本阶段 tile完整词表7 类)可以保留不变,只是**被 MASK 的位置**和**计算 Loss 的位置**会有所不同(见训练策略)。
---
## 每阶段 tile 映射与 MASK 策略
### 阶段一
**输入**:所有位置均填充 MASK token`tile=6`)。
**目标**:预测 floor(0) 和 wall(1);原始地图中所有非 floor/wall 的 tile 在训练目标中被**重映射为 floor(0)**,因为阶段一模型不关心功能性元素的具体种类。
```python
STAGE1_REMAP = {
0: 0, # floor → floor
1: 1, # wall → wall
2: 0, # door → floor视作空地让阶段二填充
3: 0, # resource → floor
4: 0, # monster → floor
5: 0, # entrance → floor
6: 6, # MASK → MASK不参与损失计算
}
```
**Loss 计算范围**:所有非 MASK 位置(即全局所有位置,因为输入均为 MASK
### 阶段二
**输入**:将阶段一输出的 floor/wall 地图作为固定上下文,原始地图中属于 door/monster/entrance 的位置替换为 MASK token其余floor/wall位置保持不变。
```python
STAGE2_MASK_IDS = {2, 4, 5} # 需要在阶段二中被预测的 tile ID
```
**训练时输入构造**(使用 GT 地图中的 floor/wall 作为上下文,不使用阶段一的实际输出):
```python
def make_stage2_input(gt_map: np.ndarray) -> np.ndarray:
"""
gt_map: [H*W] 整数数组,包含完整原始地图
返回: stage2 输入地图door/monster/entrance 位置替换为 MASK
"""
inp = gt_map.copy()
# 先将所有资源归一化为 floor阶段二不负责资源
inp[np.isin(inp, [3])] = 0
# 将 door/monster/entrance 位置 MASK 掉
inp[np.isin(inp, [2, 4, 5])] = 6 # MASK
return inp
```
**目标**:预测 door(2)、monster(4)、entrance(5) 位置的类别(以及它们所在位置原本是否是 floor
**Loss 计算范围**:仅对输入为 MASK即原本是 door/monster/entrance 的位置计算损失floor/wall 位置不参与损失(它们已经确定)。
### 阶段三
**输入**将阶段二输出的完整功能性地图floor/wall/door/monster/entrance作为固定上下文原始地图中资源tile=3位置替换为 MASK token。
```python
def make_stage3_input(gt_map: np.ndarray) -> np.ndarray:
"""
gt_map: [H*W] 整数数组,包含完整原始地图
返回: stage3 输入地图resource 位置替换为 MASK
"""
inp = gt_map.copy()
inp[inp == 3] = 6 # resource → MASK
return inp
```
**Loss 计算范围**:仅对输入为 MASK即原本是 resource 的位置)计算损失。
---
## 模型架构
### 共享基础架构
三个阶段均采用与现有方案**相同的 MaskGIT + VQ-VAE 架构**,不需要设计新的模型结构。核心差异在于:
1. **输入词表**:与现有方案一致,均使用 7 类词表(`NUM_CLASSES=7``MASK_TOKEN=6`
2. **z 来源**:每阶段使用来自对应 VQ 通道的隐变量;
3. **Loss 掩码**:只对该阶段负责的位置计算 CE Loss。
### VQ-VAE z 的阶段分配
现有 VQ-VAE 已采用三通道设计CH1/CH2/CH3与三个生成阶段自然对应
| 通道 | 编码目标 | 供应阶段 | 描述 |
| ---- | ----------------- | -------- | ---------------------------- |
| CH1 | floor/wall 骨架 | 阶段一 | z₁控制墙壁结构多样性 |
| CH2 | door/monster/入口 | 阶段二 | z₂控制功能元素布局多样性 |
| CH3 | resource 分布 | 阶段三 | z₃控制资源密度与位置多样性 |
```
真实地图 ──► CH1 encoder ──► z₁ ──► Stage1 MaskGIT
──► CH2 encoder ──► z₂ ──► Stage2 MaskGIT
──► CH3 encoder ──► z₃ ──► Stage3 MaskGIT
```
**通道专属编码**:各通道编码器在编码前只"看"与自身相关的 tile其余 tile 视作 floor(0)
```python
def ch1_mask(gt_map):
"""只保留 floor/wall其余置 0"""
m = gt_map.copy()
m[~np.isin(m, [0, 1])] = 0
return m
def ch2_mask(gt_map):
"""只保留 door/monster/entrance其余置 0"""
m = gt_map.copy()
m[~np.isin(m, [2, 4, 5])] = 0
return m
def ch3_mask(gt_map):
"""只保留 resource其余置 0"""
m = gt_map.copy()
m[m != 3] = 0
return m
```
### Loss 计算掩码实现
训练时,每阶段额外接收一个 `loss_mask: torch.BoolTensor [B, H*W]`,指示哪些位置需要计算损失:
```python
# 阶段一:所有位置(因为输入全为 MASK
loss_mask_s1 = torch.ones(B, MAP_SIZE, dtype=torch.bool)
# 阶段二:只有原本是 door/monster/entrance 的位置
loss_mask_s2 = torch.isin(raw_map, torch.tensor([2, 4, 5]))
# 阶段三:只有原本是 resource 的位置
loss_mask_s3 = (raw_map == 3)
```
Focal Loss 修改为只对 `loss_mask` 为 True 的位置求和后做归一化:
```python
def stage_focal_loss(logits, targets, loss_mask, gamma=2.0):
# logits: [B, C, H*W], targets: [B, H*W], loss_mask: [B, H*W]
per_token_loss = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W]
masked_loss = per_token_loss[loss_mask]
return masked_loss.mean() if masked_loss.numel() > 0 else per_token_loss.mean()
```
---
## 训练策略
### 训练方式:顺序训练(推荐)
三个阶段**依次训练**,后续阶段训练时使用 GT 地图作为前序阶段的"完美输出"teacher forcing而非使用前序阶段模型的实际推理结果。
```
训练阶段一:
data: (全MASK输入, stage1目标)
loss: focal loss on all positions
训练阶段二:
data: (floor/wall上下文 + MASK输入, stage2目标)
loss: focal loss only on door/monster/entrance positions
训练阶段三:
data: (floor/wall/door/monster/入口上下文 + MASK输入, stage3目标)
loss: focal loss only on resource positions
```
**使用 GT 而非前序模型输出的理由**
- 避免误差级联(前序模型若生成错误的骨架,后续模型的训练将在错误分布上进行);
- 各阶段训练更稳定,收敛更快;
- 阶段之间解耦,方便单独迭代和调试。
### 各阶段训练子集划分
每个阶段均沿用现有 A/B/C/D 子集划分逻辑,但 MASK 策略应用在对应阶段的目标 tile 上:
| 子集 | 阶段一 | 阶段二 | 阶段三 |
| ---- | ----------------------------- | ------------------------------ | --------------------------- |
| A | 随机遮盖部分 floor/wall | 随机遮盖部分 door/monster/入口 | 随机遮盖部分 resource |
| B | 保留全部 wallMASK floor | 给定全部 wallMASK 功能元素 | 给定全部骨架MASK 部分资源 |
| C | 随机保留部分 wallMASK 其余 | 同 B | 同 B |
| D | 保留 wall+entranceMASK 其余 | 给定 wall+entranceMASK 门/怪 | 同 B |
### 各阶段专属损失
每阶段的总损失:
$$\mathcal{L}^{(s)} = \mathcal{L}_{CE}^{(s)} + \beta \cdot \mathcal{L}_{commit}^{(s)} + \gamma \cdot \mathcal{L}_{uniform}^{(s)}$$
其中 $s \in \{1, 2, 3\}$ 表示阶段编号,$\mathcal{L}_{CE}^{(s)}$ 只在该阶段负责的 tile 位置上计算。
---
## 推理流程
### 完整推理管线
```
1. 随机采样 z₁, z₂, z₃各自从对应 codebook 均匀采样 L 个 index
2. 阶段一推理(结构骨架)
初始状态: 全部 169 个位置 = MASK token
迭代 MaskGIT 解码cosine schedule约 18 步):
输入: MASK地图 + z₁
输出: 逐步填充 floor/wall直到无 MASK 位置
结果: floor/wall 骨架地图 M₁
3. 阶段二推理(功能元素放置)
初始状态: 继承 M₁在 floor 位置随机选取候选位置置为 MASK
─── 或 ───
初始状态: 继承 M₁所有 floor 位置均置为 MASK让模型决定放置密度
迭代 MaskGIT 解码:
输入: 含已知 wall 的掩码地图 + z₂
约束: wall 位置不参与 unmask保持不变
输出: 逐步填充 door/monster/entrance
结果: 含功能元素的地图 M₂
4. 阶段三推理(资源放置)
初始状态: 继承 M₂所有 floor 位置(未被阶段二填充的)置为 MASK
迭代 MaskGIT 解码:
输入: 含已知 wall/door/monster/入口的掩码地图 + z₃
约束: 非 floor 位置保持不变
输出: 逐步填充 resource或保持 floor
结果: 完整地图 M₃
```
### 阶段二初始 MASK 策略
阶段二的初始状态有两种选择:
| 策略 | 描述 | 适用场景 |
| ------------- | -------------------------------------------------- | ---------------- |
| 全 floor MASK | 所有 floor 位置均置为 MASK让模型自主决定放置密度 | 完全随机生成 |
| 候选位置 MASK | 只 MASK 用户指定或随机抽取的少数位置 | 用户指定部分位置 |
对于完全随机生成场景,**推荐使用全 floor MASK**——此时 z₂ 决定功能元素的总体风格(密集/稀疏/集中/分散),模型负责在此约束下寻找合理位置。
### 用户交互场景
| 场景 | 阶段一输入 | 阶段二输入 | 阶段三输入 |
| ------------------- | -------------------- | ------------------------ | ---------- |
| 完全随机 | 全 MASK | 继承 M₁ | 继承 M₂ |
| 用户手绘墙壁 | 已知 wall + MASK | 继承 M₁固定 wall | 继承 M₂ |
| 用户指定入口 | 已知 entrance + MASK | 继承 M₁固定 entrance | 继承 M₂ |
| 用户手绘墙+指定入口 | 已知 wall/entrance | 继承 M₁ | 继承 M₂ |
---
## 实现方案
### 方案一:三个独立 GinkaMaskGIT 实例(推荐)
每个阶段分别实例化一个 `GinkaMaskGIT`,共用同一个 `GinkaVQVAE`(其中已含三通道编码器)。各阶段模型独立训练、独立存储。
```python
# 模型结构
vqvae = GinkaVQVAE(...) # 三通道共享编码器
stage1 = GinkaMaskGIT(...) # 结构骨架
stage2 = GinkaMaskGIT(...) # 功能元素
stage3 = GinkaMaskGIT(...) # 资源
```
**优点**
- 各阶段模型完全解耦,可独立调整超参、单独重训;
- 模型大小可以针对任务难度调整(阶段一可以更大,阶段三可以更小);
- 便于调试和增量式开发。
**缺点**
- 三个模型分别保存,推理时需依次加载;
- 总参数量约为原方案的 3 倍(但可以通过缩小各阶段模型来对冲)。
### 方案二:单模型 + 阶段 Embedding备选
复用同一个 `GinkaMaskGIT`,添加阶段 embedding类似 BERT 的 segment embedding
```python
self.stage_embedding = nn.Embedding(3, d_model) # 三个阶段
def forward(self, map, z, stage: int):
x = self.tile_embedding(map) + self.pos_embedding
x = x + self.stage_embedding(torch.tensor(stage)) # 阶段条件
...
```
**优点**:参数量与原方案相同,推理时只需加载一个模型。
**缺点**:三阶段共享所有权重,可能导致阶段间干扰;阶段一的结构任务与阶段三的稀疏资源任务表示空间差异大,单一模型难以同时擅长。
**结论****推荐方案一**,尤其是在当前阶段(验证多阶段框架可行性),可先分别训练三个轻量模型进行快速验证。
### 各阶段模型规模建议
| 阶段 | 任务难度 | 建议 d_model | 建议 num_layers | 参数量估算 |
| ------ | ---------- | ------------ | --------------- | ---------- |
| 阶段一 | 高(结构) | 256 | 6 | ~4M |
| 阶段二 | 中(功能) | 192 | 4 | ~2M |
| 阶段三 | 低(稀疏) | 128 | 3 | ~0.8M |
---
## Dataset 修改方案
### 新增 `GinkaStageDataset`
需要扩展 `dataset.py`,增加针对三阶段的 Dataset 类,或在 `GinkaVQDataset` 中添加 `stage` 参数:
```python
class GinkaStageDataset(Dataset):
"""
三阶段级联训练专用 Dataset。
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
stage_input: LongTensor [H*W] 当前阶段 MaskGIT 输入(含上下文 + MASK
target_map: LongTensor [H*W] CE loss ground truth等同 raw_map
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
subset: str 子集标识
"""
STAGE1_TARGETS = {0, 1} # floor, wall
STAGE2_TARGETS = {2, 4, 5} # door, monster, entrance
STAGE3_TARGETS = {3} # resource
```
### 数据构造函数
```python
def make_stage1_sample(gt_map: np.ndarray, mask_id: int = 6):
"""阶段一:全 MASK 输入,目标是 floor/wall其余归一为 floor"""
stage_input = np.full_like(gt_map, mask_id)
target = gt_map.copy()
target[~np.isin(target, [0, 1])] = 0 # 非结构 tile → floor
loss_mask = np.ones_like(gt_map, dtype=bool)
return stage_input, target, loss_mask
def make_stage2_sample(gt_map: np.ndarray, mask_id: int = 6):
"""阶段二floor/wall 为上下文door/monster/entrance 位置 MASK"""
stage_input = gt_map.copy()
stage_input[stage_input == 3] = 0 # 资源 → floor阶段二不负责
target_ids = np.isin(stage_input, [2, 4, 5])
stage_input[target_ids] = mask_id # 功能元素 → MASK
target = gt_map.copy()
target[target == 3] = 0 # target 中资源也视为 floor
loss_mask = (gt_map != 0) & (gt_map != 1) & (gt_map != 3) # 只计算功能元素位置
return stage_input, target, loss_mask
def make_stage3_sample(gt_map: np.ndarray, mask_id: int = 6):
"""阶段三:全上下文保留,只 MASK 资源位置"""
stage_input = gt_map.copy()
stage_input[stage_input == 3] = mask_id # 资源 → MASK
target = gt_map.copy()
loss_mask = (gt_map == 3) # 只计算资源位置
return stage_input, target, loss_mask
```
---
## 训练脚本设计
### 新增 `ginka/train_stage.py`
```
用法示例:
python -m ginka.train_stage --stage 1
python -m ginka.train_stage --stage 2
python -m ginka.train_stage --stage 3 --resume True --state result/stage3/stage3-10.pth
```
各阶段检查点分别存储到:
- `result/stage1/stage1-{epoch}.pth`
- `result/stage2/stage2-{epoch}.pth`
- `result/stage3/stage3-{epoch}.pth`
### 阶段二/三的 VQ 编码器冻结策略
阶段二训练时CH1 编码器(已在阶段一训练中充分收敛)可选择冻结,只更新 CH2 编码器和 Stage2 MaskGIT
```python
if args.stage == 2:
for p in vqvae.encoder_ch1.parameters():
p.requires_grad_(False) # 冻结 CH1
```
类似地,阶段三训练时可冻结 CH1 和 CH2 编码器。
---
## 与现有方案的对比
| 维度 | 现有单模型方案 | 三阶段级联方案 |
| -------------- | ------------------------ | ---------------------------------------- |
| 墙壁过度生成 | 存在,难以从根本解决 | 阶段一单独训练骨架Loss 聚焦 floor/wall |
| 训练信号均衡性 | 墙壁主导,稀有类欠拟合 | 各阶段 Loss 只计算本阶段 tile信号均衡 |
| 模型可调试性 | 单一模型,各类别相互干扰 | 各阶段独立,可单独分析每阶段表现 |
| 推理速度 | 1 次完整 MaskGIT 解码 | 3 次级联解码(总步数约为原来 3 倍) |
| 误差累积 | 无 | 存在,前序阶段错误会传播到后续阶段 |
| 用户可控性 | 较难(条件混合) | 好(可在任意阶段注入用户约束) |
| 参数量 | ~4M | 约 ~7M可通过缩减各阶段规模控制 |
---
## 预期收益
1. **解决墙壁过度生成**:阶段一专门针对 floor/wall 训练,类别分布从 7 类压缩到 3 类Loss 完全聚焦,模型不再有逃避路径;
2. **功能元素召回率提升**:阶段二以已知骨架为前提生成 door/monster/entrance训练信号不再被墙壁噪声稀释
3. **资源分布更合理**:阶段三在完整上下文下放置资源,能感知到门/怪物位置,避免资源与关键功能元素重叠;
4. **可交互性增强**:用户可在任意阶段注入约束(固定某些 tile天然支持层次化编辑。
---
## 风险与应对
| 风险 | 描述 | 应对策略 |
| ---------------- | ----------------------------------------- | -------------------------------------------------------- |
| 误差累积 | 阶段一骨架不准确会导致阶段二/三布局失真 | 阶段一优先保证质量,推理时对骨架做后处理校验 |
| 推理耗时增加 | 三次 MaskGIT 解码约 3 倍耗时 | 减少 MaskGIT 迭代步数(阶段二/三任务更简单,步数可减半) |
| 阶段二稀疏性问题 | 169 格子中功能元素极少Loss 计算覆盖率低 | 适当提高 stage 2 的 loss_mask 覆盖(周边 floor 也计入) |
| Codebook 对齐 | 三通道 VQ-VAE 的 z 在分阶段训练时各自优化 | 联合训练阶段一+VQ-CH1联合训练阶段二+VQ-CH2以此类推 |
---
## 实施顺序
- [ ] 新增 `GinkaStageDataset`,实现三阶段数据构造函数
- [ ] 新增 `ginka/train_stage.py`,支持 `--stage 1/2/3` 参数
- [ ] 阶段一训练:仅 floor/wall验证骨架生成质量关键里程碑
- [ ] 阶段二训练:以 GT floor/wall 为上下文,验证功能元素召回率
- [ ] 阶段三训练:以 GT 完整地图为上下文,验证资源放置合理性
- [ ] 实现三阶段级联推理脚本,接入现有可视化工具
- [ ] 对比实验:三阶段方案 vs 现有单模型方案在墙壁密度、功能元素召回率、资源分布等指标上的差异

View File

@ -4,14 +4,69 @@ import torch
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
def _compute_map_labels(map_2d) -> dict:
"""
2D 地图列表 numpy 数组推算结构标签
JSON 数据缺少 roomCount / highDegBranchCount / outerWall 字段时调用
"""
arr = np.array(map_2d, dtype=np.int64) # [H, W]
H, W = arr.shape
WALL, ENTRY = 1, 5
# outerWall最外圈中 wall+entry 占比 > 90%
border = np.concatenate([arr[0, :], arr[-1, :], arr[1:-1, 0], arr[1:-1, -1]])
total_b = border.size
outer_wall = int(total_b > 0 and np.sum((border == WALL) | (border == ENTRY)) / total_b > 0.9)
# roomCountBFS 统计 floor(0)+resource(3) 连通区域,
# 需满足:总格子 >= 4外接矩形宽 >= 2 且高 >= 2
FLOOR_SET = (0, 3)
visited = np.zeros((H, W), dtype=bool)
room_count = 0
for sy in range(H):
for sx in range(W):
if arr[sy, sx] not in FLOOR_SET or visited[sy, sx]:
continue
queue = [(sy, sx)]
visited[sy, sx] = True
tiles_y, tiles_x = [sy], [sx]
head = 0
while head < len(queue):
y, x = queue[head]; head += 1
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
ny, nx = y + dy, x + dx
if 0 <= ny < H and 0 <= nx < W and not visited[ny, nx] and arr[ny, nx] in FLOOR_SET:
visited[ny, nx] = True
queue.append((ny, nx))
tiles_y.append(ny); tiles_x.append(nx)
if (len(tiles_y) >= 4
and max(tiles_y) - min(tiles_y) >= 1
and max(tiles_x) - min(tiles_x) >= 1):
room_count += 1
# highDegBranchCount非 wall 格子中4 邻域非 wall 邻居 >= 3 的数量
non_wall = (arr != WALL).astype(np.int32)
padded = np.pad(non_wall, 1, mode='constant', constant_values=0)
nbr_sum = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
padded[1:-1, :-2] + padded[1:-1, 2:])
high_deg = int(np.sum((non_wall == 1) & (nbr_sum >= 3)))
return {'outerWall': outer_wall, 'roomCount': room_count, 'highDegBranchCount': high_deg}
def load_data(path: str): def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f: with open(path, 'r', encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
data_list = [] data_list = []
for value in data["data"].values(): for value in data["data"].values():
# 兼容旧版数据集(缺少结构标签字段)
if 'roomCount' not in value:
labels = _compute_map_labels(value['map'])
value.update(labels)
# symmetry 字段由 __getitem__ 在增强后重新计算,此处不需要从 JSON 读取
data_list.append(value) data_list.append(value)
return data_list return data_list
def _compute_symmetry(target_np: np.ndarray) -> tuple: def _compute_symmetry(target_np: np.ndarray) -> tuple:
@ -326,6 +381,322 @@ class GinkaSplitDataset(Dataset):
} }
# ---------------------------------------------------------------------------
# GinkaStageDataset三阶段级联训练专用数据集
# ---------------------------------------------------------------------------
class GinkaStageDataset(Dataset):
"""
三阶段级联生成训练专用 Dataset
每个阶段只预测特定类别的 tile后续阶段以前序阶段输出作为上下文
训练时统一使用 GT 作为前序上下文teacher forcing避免误差级联
阶段划分
stage=1 结构骨架预测 floor(0) + wall(1)
stage=2 功能元素预测 door(2) + monster(4) + entrance(5) floor/wall 为上下文
stage=3 资源放置预测 resource(3)以完整骨架为上下文
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图 VQ-VAE 编码
vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片
stage_input: LongTensor [H*W] MaskGIT 输入含上下文 + MASK 位置
target_map: LongTensor [H*W] CE loss ground truth
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
subset: str 子集标识 A/B/C/D
struct_cond: LongTensor [4] [sym, room, branch, outer]
"""
FLOOR = 0
WALL = 1
DOOR = 2
RESOURCE = 3
MONSTER = 4
ENTRANCE = 5
MASK_ID = 6
STAGE1_TARGETS = frozenset({0, 1})
STAGE2_TARGETS = frozenset({2, 4, 5})
STAGE3_TARGETS = frozenset({3})
# VQ 切片集合:各阶段编码器只"看"与自身相关的 tile
_VQ_KEEP = {
1: frozenset({0, 1}),
2: frozenset({0, 1, 2, 4, 5}),
3: None, # 完整地图
}
def __init__(
self,
data_path: str,
stage: int,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_ratio: float = 0.3,
room_thresholds: tuple = None,
branch_thresholds: tuple = None,
):
"""
Args:
data_path: JSON 数据文件路径
stage: 生成阶段 1/2/3
subset_weights: 子集 (A, B, C, D) 的采样权重自动归一化
wall_mask_ratio: Subset C 中额外随机 mask wall 比例上限
room_thresholds: 等频分箱阈值None 时自动计算
branch_thresholds: 等频分箱阈值None 时自动计算
"""
assert stage in (1, 2, 3), f"stage 必须是 1/2/3收到 {stage}"
self.stage = stage
self.data = load_data(data_path)
self.wall_mask_ratio = wall_mask_ratio
total_w = sum(subset_weights)
normalized = [x / total_w for x in subset_weights]
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
room_counts = [item['roomCount'] for item in self.data]
branch_counts = [item['highDegBranchCount'] for item in self.data]
if room_thresholds is None:
n = len(room_counts)
rs = sorted(room_counts)
bs = sorted(branch_counts)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
if th1_r == th2_r: th2_r = th1_r + 1
if th1_b == th2_b: th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
else:
self.room_th = room_thresholds
self.branch_th = branch_thresholds
def to_level(v, th):
return 0 if v < th[0] else (1 if v < th[1] else 2)
for item in self.data:
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
def __len__(self):
return len(self.data)
# ------------------------------------------------------------------
# 掩码辅助(与 GinkaVQDataset 相同逻辑)
# ------------------------------------------------------------------
@staticmethod
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
r = np.random.beta(2, 2)
return min_r + (max_r - min_r) * r
@staticmethod
def _random_mask(h: int, w: int) -> np.ndarray:
ratio = GinkaStageDataset._sample_mask_ratio()
total = h * w
idx = np.random.choice(total, int(total * ratio), replace=False)
mask = np.zeros(total, dtype=bool)
mask[idx] = True
return mask
@staticmethod
def _block_mask(h: int, w: int) -> np.ndarray:
ratio = GinkaStageDataset._sample_mask_ratio()
max_block = max(2, min(h, w) // 2)
target = int(h * w * ratio)
mask = np.zeros((h, w), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, max_block + 1)
bw = np.random.randint(2, max_block + 1)
x = np.random.randint(0, max(1, h - bh + 1))
y = np.random.randint(0, max(1, w - bw + 1))
mask[x:x + bh, y:y + bw] = True
return mask.reshape(-1)
def _std_mask(self, h: int, w: int) -> np.ndarray:
return self._random_mask(h, w) if random.random() < 0.5 else self._block_mask(h, w)
# ------------------------------------------------------------------
# 子集选择
# ------------------------------------------------------------------
def _choose_subset(self) -> str:
r = random.random()
if r < self.subset_cumw[0]: return 'A'
if r < self.subset_cumw[1]: return 'B'
if r < self.subset_cumw[2]: return 'C'
return 'D'
# ------------------------------------------------------------------
# 阶段一结构骨架floor + wall
# ------------------------------------------------------------------
def _make_stage1(self, raw_flat: np.ndarray, subset: str):
"""
阶段一预测 floor/wall所有非 floor/wall tile 在目标中重映射为 floor
子集决定向模型提供多少 wall 作为上下文条件
"""
H = W = 13
# 目标:非 floor/wall → floor
target = raw_flat.copy()
target[~np.isin(target, [self.FLOOR, self.WALL])] = self.FLOOR
inp = target.copy()
if subset == 'A':
# 标准随机 mask随机遮盖部分 floor/wall
mask = self._std_mask(H, W)
inp[mask] = self.MASK_ID
elif subset == 'B':
# 保留全部 wallMASK floor
inp[inp == self.FLOOR] = self.MASK_ID
elif subset == 'C':
# 随机保留部分 wallMASK 其余(含全部 floor
inp[inp == self.FLOOR] = self.MASK_ID
wall_idx = np.where(inp == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
inp[chosen] = self.MASK_ID
else: # D与 B 相同(阶段一无 entrance 维度)
inp[inp == self.FLOOR] = self.MASK_ID
loss_mask = (inp == self.MASK_ID)
return inp, target, loss_mask
# ------------------------------------------------------------------
# 阶段二功能元素door + monster + entrance
# ------------------------------------------------------------------
def _make_stage2(self, raw_flat: np.ndarray, subset: str):
"""
阶段二 floor/wall 为上下文预测 door/monster/entrance
resource 在输入与目标中均视为 floor阶段二不负责资源
子集决定 wall 上下文的完整程度与 door/monster/entrance 的掩码方式
"""
# 目标resource → floor
target = raw_flat.copy()
target[target == self.RESOURCE] = self.FLOOR
# 基础输入resource → floor功能元素先保留再按子集处理
inp = raw_flat.copy()
inp[inp == self.RESOURCE] = self.FLOOR
if subset == 'A':
# 随机遮盖部分 door/monster/entrance部分上下文补全
func_idx = np.where(np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE]))[0]
if len(func_idx) > 0:
ratio = random.random() * 0.8 + 0.2 # 20%~100%
n = max(1, int(len(func_idx) * ratio))
chosen = np.random.choice(func_idx, n, replace=False)
inp[chosen] = self.MASK_ID
else:
# B/C/D全部 door/monster/entrance → MASK
inp[np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE])] = self.MASK_ID
if subset == 'C':
# 额外随机 mask 部分 wall降低 wall 上下文质量)
wall_idx = np.where(inp == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
inp[chosen] = self.MASK_ID
# loss_mask阶段二只对 door/monster/entrance 原始位置计算损失,
# 不对被额外 mask 的 wall 位置计算(它们在 target 中已知为 wall
loss_mask = np.isin(raw_flat, [self.DOOR, self.MONSTER, self.ENTRANCE])
return inp, target, loss_mask
# ------------------------------------------------------------------
# 阶段三资源放置resource
# ------------------------------------------------------------------
def _make_stage3(self, raw_flat: np.ndarray, subset: str):
"""
阶段三以完整骨架为上下文预测 resource 位置
所有 resource 位置在输入中替换为 MASK
子集 A 随机保留部分 resource 作为上下文部分补全训练
其余子集始终 MASK 全部 resource
"""
target = raw_flat.copy()
inp = raw_flat.copy()
if subset == 'A':
# 随机遮盖部分 resource部分上下文补全
res_idx = np.where(inp == self.RESOURCE)[0]
if len(res_idx) > 0:
ratio = random.random() * 0.8 + 0.2 # 20%~100%
n = max(1, int(len(res_idx) * ratio))
chosen = np.random.choice(res_idx, n, replace=False)
inp[chosen] = self.MASK_ID
else:
pass # 无 resource 时无需处理
else:
# B/C/D全部 resource → MASK
inp[inp == self.RESOURCE] = self.MASK_ID
loss_mask = (inp == self.MASK_ID)
return inp, target, loss_mask
# ------------------------------------------------------------------
# __getitem__
# ------------------------------------------------------------------
def _augment(self, arr: np.ndarray) -> np.ndarray:
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
return arr
def __getitem__(self, idx):
item = self.data[idx]
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
raw_flat = raw_np.reshape(-1) # [H*W]
subset = self._choose_subset()
if self.stage == 1:
stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset)
elif self.stage == 2:
stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset)
else:
stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset)
# 若 loss_mask 全为 False如地图中无 resource 时的 stage3
# 退回为全图损失,避免 NaN
if not loss_mask_np.any():
loss_mask_np = np.ones_like(loss_mask_np)
# VQ 切片:当前阶段编码器的输入(仅保留相关 tile
raw_t = torch.LongTensor(raw_flat)
vq_keep = self._VQ_KEEP[self.stage]
if vq_keep is None:
vq_slice = raw_t.clone()
else:
vq_slice = make_slice(raw_t, vq_keep)
# 结构标签
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
cond_room = item['roomCountLevel']
cond_branch = item['branchLevel']
cond_outer = item['outerWall']
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
return {
"raw_map": raw_t,
"vq_slice": vq_slice,
"stage_input": torch.LongTensor(stage_input_np),
"target_map": torch.LongTensor(target_np),
"loss_mask": torch.BoolTensor(loss_mask_np),
"subset": subset,
"struct_cond": struct_cond,
}
if __name__ == "__main__": if __name__ == "__main__":
import os import os
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json') data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')

View File

@ -57,8 +57,12 @@ class GinkaMaskGIT(nn.Module):
# z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用 # z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用
self.z_proj = nn.Sequential( self.z_proj = nn.Sequential(
nn.Linear(d_z, d_model), nn.Linear(d_z, d_model * 2),
nn.LayerNorm(d_model), nn.LayerNorm(d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.LayerNorm(d_model)
) )
# 结构标签嵌入(编码到 d_z 维度) # 结构标签嵌入(编码到 d_z 维度)
@ -69,8 +73,12 @@ class GinkaMaskGIT(nn.Module):
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
self.struct_proj = nn.Sequential( self.struct_proj = nn.Sequential(
nn.Linear(d_z, d_model), nn.Linear(d_z, d_model * 2),
nn.LayerNorm(d_model), nn.LayerNorm(d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.LayerNorm(d_model)
) )
# Transformerencoder 做 map token 自注意力decoder 做与 z 的 cross-attention # Transformerencoder 做 map token 自注意力decoder 做与 z 的 cross-attention

654
ginka/train_stage.py Normal file
View File

@ -0,0 +1,654 @@
"""
三阶段级联训练脚本各阶段独立训练使用 GinkaStageDataset
总损失 = L_CE只对本阶段负责的 tile 位置计算+ beta * L_commit + gamma * L_entropy
各阶段分工
stage=1 结构骨架floor(0) + wall(1)
stage=2 功能元素door(2) + monster(4) + entrance(5)
stage=3 资源放置resource(3)
用法示例
python -m ginka.train_stage --stage 1
python -m ginka.train_stage --stage 2
python -m ginka.train_stage --stage 3
python -m ginka.train_stage --stage 1 --resume True --state result/stage1/stage1-10.pth
python -m ginka.train_stage --stage 2 --pretrain_vq result/joint/joint-50.pth
"""
import argparse
import math
import os
import sys
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from .vqvae.model import GinkaVQVAE
from .maskGIT.model import GinkaMaskGIT
from .dataset import GinkaStageDataset
from shared.image import matrix_to_image_cv
# ---------------------------------------------------------------------------
# 各阶段配置
# ---------------------------------------------------------------------------
# 共用 VQ-VAE 超参
VQ_L = 2 # 码字序列长度
VQ_K = 8 # codebook 大小
VQ_D_Z = 64 # 码字维度
VQ_BETA = 0.5 # commit loss 权重
VQ_GAMMA = 0.0 # entropy loss 权重
VQ_LAYERS = 3
VQ_DIM_FF = 512
VQ_D_MODEL = 64
VQ_NHEAD = 8
# 各阶段 MaskGIT 超参(按任务复杂度差异化配置)
STAGE_MG_CONFIGS = {
1: dict(d_model=256, nhead=8, num_layers=6, dim_ff=2048), # 结构骨架,最重要
2: dict(d_model=192, nhead=8, num_layers=4, dim_ff=1024), # 功能元素
3: dict(d_model=128, nhead=8, num_layers=3, dim_ff=512), # 资源放置,最简单
}
# 各阶段监控的 tile 集合(用于分类别召回率统计)
STAGE_TILE_SETS = {
1: {0: "floor", 1: "wall"},
2: {2: "door", 4: "monster", 5: "entrance"},
3: {3: "resource"},
}
# 各阶段损失权重(可单独调节 CE 与 VQ 损失的平衡)
# stage3 的 resource 极稀疏,大幅上调 ce_weight 以补偿类别不均衡
STAGE_LOSS_CONFIG = {
1: dict(ce_weight=1.0, vq_weight=1.0), # 结构骨架,标准权重
2: dict(ce_weight=1.5, vq_weight=0.5), # 功能元素较稀疏,上调 CE
3: dict(ce_weight=3.0, vq_weight=0.5), # resource 极稀疏,显著上调 CE
}
NUM_CLASSES = 7
MASK_TOKEN = 6
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
FOCAL_GAMMA = 2.0
GENERATE_STEP = 18
BATCH_SIZE = 64
WALL_MASK_RATIO = 0.8
MG_Z_DROPOUT = 0.1
MG_STRUCT_DROPOUT = 0.1
SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1)
device = torch.device(
"cuda:1" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# 参数解析
# ---------------------------------------------------------------------------
def _str2bool(v):
if isinstance(v, bool): return v
if v.lower() in ('true', '1', 'yes'): return True
if v.lower() in ('false', '0', 'no'): return False
raise argparse.ArgumentTypeError(f"布尔值应为 True/False收到: {v!r}")
def parse_arguments():
parser = argparse.ArgumentParser(description="三阶段级联训练")
parser.add_argument("--stage", type=int, required=True, choices=[1, 2, 3])
parser.add_argument("--resume", type=_str2bool, default=False)
parser.add_argument(
"--state", type=str, default="",
help="续训时加载的检查点路径(自动推断 stage{N}/stage{N}-*.pth",
)
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default="ginka-eval.json")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--checkpoint", type=int, default=5)
parser.add_argument("--load_optim", type=_str2bool, default=True)
parser.add_argument(
"--freeze_vq", type=_str2bool, default=False,
help="冻结 VQ 编码器,只训练 MaskGIT适合加载预训练编码器后热身",
)
parser.add_argument(
"--pretrain_vq", type=str, default="",
help="从 train_vq.py 的联合训练检查点中导入对应通道的 VQ 编码器权重",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Focal Loss与 train_vq.py 一致)
# ---------------------------------------------------------------------------
def focal_loss(logits, targets, gamma=FOCAL_GAMMA, reduction='none'):
ce = F.cross_entropy(logits, targets, reduction='none')
pt = torch.exp(-ce)
fl = (1.0 - pt) ** gamma * ce
if reduction == 'mean': return fl.mean()
if reduction == 'sum': return fl.sum()
return fl
def masked_focal_loss(logits, targets, loss_mask, gamma=FOCAL_GAMMA):
"""
只对 loss_mask True 的位置计算 focal loss 均值
Args:
logits: [B, C, H*W]
targets: [B, H*W]
loss_mask: [B, H*W] bool
"""
per_token = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W]
selected = per_token[loss_mask]
if selected.numel() == 0:
return per_token.mean()
return selected.mean()
# ---------------------------------------------------------------------------
# MaskGIT 推理cosine schedule
# ---------------------------------------------------------------------------
@torch.no_grad()
def maskgit_generate(
model_mg: GinkaMaskGIT,
z: torch.Tensor,
steps: int = GENERATE_STEP,
init_map: torch.Tensor = None,
struct_cond: torch.Tensor = None,
) -> torch.Tensor:
"""
迭代生成地图cosine schedule unmasking
Args:
init_map: 可选初始地图 MASK 位置在生成中保持不变
Returns:
[B, MAP_SIZE] LongTensor
"""
B = z.shape[0]
map_seq = (
torch.full((B, MAP_SIZE), MASK_TOKEN, device=device)
if init_map is None else init_map.clone().to(device)
)
generatable = (map_seq == MASK_TOKEN)
for step in range(steps):
if not generatable.any():
break
logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C]
probs = F.softmax(logits, dim=-1)
dist = torch.distributions.Categorical(probs)
sampled = dist.sample()
confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
confidences = confidences.masked_fill(~generatable, float('inf'))
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
new_map = map_seq.clone()
for b in range(B):
n_gen = int(generatable[b].sum().item())
n_keep = int(ratio * n_gen)
if n_keep > 0:
_, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False)
pred_b = sampled[b].clone()
pred_b[keep_idx] = MASK_TOKEN
new_map[b] = torch.where(generatable[b], pred_b, map_seq[b])
else:
new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b])
map_seq = new_map
return map_seq
# ---------------------------------------------------------------------------
# 可视化工具(与 train_vq.py 保持一致)
# ---------------------------------------------------------------------------
def make_map_image(map_flat, tile_dict):
arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W)
return matrix_to_image_cv(arr, tile_dict)
def hstack_images(imgs, gap=4, color=(255, 255, 255)):
max_h = max(img.shape[0] for img in imgs)
def _pad(img):
dh = max_h - img.shape[0]
return img if dh == 0 else np.concatenate(
[img, np.full((dh, img.shape[1], 3), color, dtype=np.uint8)], axis=0)
vline = np.full((max_h, gap, 3), color, dtype=np.uint8)
result = _pad(imgs[0])
for img in imgs[1:]:
result = np.concatenate([result, vline, _pad(img)], axis=1)
return result
def grid_images(imgs, gap=4, bg=(255, 255, 255)):
n = len(imgs)
if n == 0: return np.zeros((1, 1, 3), dtype=np.uint8)
if n == 1: return imgs[0]
mid = math.ceil(n / 2)
top = hstack_images(imgs[:mid], gap, bg)
bot_imgs = imgs[mid:]
if not bot_imgs: return top
bot = hstack_images(bot_imgs, gap, bg)
tw, bw = top.shape[1], bot.shape[1]
if tw > bw:
bot = np.concatenate(
[bot, np.full((bot.shape[0], tw - bw, 3), bg, dtype=np.uint8)], axis=1)
elif bw > tw:
top = np.concatenate(
[top, np.full((top.shape[0], bw - tw, 3), bg, dtype=np.uint8)], axis=1)
hline = np.full((gap, top.shape[1], 3), bg, dtype=np.uint8)
return np.concatenate([top, hline, bot], axis=0)
def label_image(img, text, font_scale=0.45):
bar = np.full((16, img.shape[1], 3), (40, 40, 40), dtype=np.uint8)
cv2.putText(
bar, text, (2, 13), cv2.FONT_HERSHEY_SIMPLEX,
font_scale, (200, 200, 200), 1, cv2.LINE_AA,
)
return np.concatenate([bar, img], axis=0)
def make_random_struct_cond():
from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB
return torch.tensor([[
random.randint(0, SYM_VOCAB - 2),
random.randint(0, ROOM_VOCAB - 2),
random.randint(0, BRANCH_VOCAB - 2),
random.randint(0, OUTER_VOCAB - 2),
]], dtype=torch.long, device=device)
# ---------------------------------------------------------------------------
# 按阶段构造推理初始地图
# ---------------------------------------------------------------------------
def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor:
"""
根据阶段构造 MaskGIT 的推理初始地图
Stage 1: MASK或保留稀疏 wall 种子
Stage 2: 保留 floor/wall 上下文其余 MASK
Stage 3: 保留完整上下文floor/wall/door/monster/entranceresource MASK
"""
init = context_map.clone()
if stage == 1:
# 全 MASK不依赖上下文地图
init = torch.full_like(init, MASK_TOKEN)
elif stage == 2:
# 保留 floor/wall其余 → MASK
mask = ~torch.isin(init, torch.tensor([0, 1], device=init.device))
init[mask] = MASK_TOKEN
else: # stage == 3
# 保留非 resourceresource → MASK
init[init == 3] = MASK_TOKEN
return init
def make_random_wall_seed(ratio_min=0.02, ratio_max=0.08):
ratio = random.uniform(ratio_min, ratio_max)
n_wall = max(2, int(MAP_SIZE * ratio))
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
idx = torch.randperm(MAP_SIZE)[:n_wall]
seed[0, idx] = 1
return seed
# ---------------------------------------------------------------------------
# 验证函数
# ---------------------------------------------------------------------------
@torch.no_grad()
def validate(
stage: int,
enc: GinkaVQVAE,
model_mg: GinkaMaskGIT,
dataloader_val: DataLoader,
tile_dict: dict,
epoch: int,
n_rand: int = 3,
):
enc.eval()
model_mg.eval()
epoch_dir = f"result/stage{stage}_img/e{epoch:04d}"
os.makedirs(epoch_dir, exist_ok=True)
val_loss_total = 0.0
val_steps = 0
captured = {s: None for s in ('A', 'B', 'C', 'D')}
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
raw_map = batch["raw_map"].to(device)
vq_slice = batch["vq_slice"].to(device)
stage_input = batch["stage_input"].to(device)
target_map = batch["target_map"].to(device)
loss_mask = batch["loss_mask"].to(device)
struct_cond = batch["struct_cond"].to(device)
subsets = batch["subset"]
z_q, _, _, vq_loss, _, _ = enc(vq_slice)
logits = model_mg(stage_input, z_q, struct_cond=struct_cond)
ce = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
val_loss_total += (ce + vq_loss).item()
val_steps += 1
for i in range(raw_map.shape[0]):
s = subsets[i]
if captured[s] is None:
captured[s] = {
"raw": raw_map[i:i+1].clone(),
"stage_input": stage_input[i:i+1].clone(),
"z_q": z_q[i:i+1].clone(),
"struct_cond": struct_cond[i:i+1].clone(),
}
if all(v is not None for v in captured.values()):
break
# ---- 可视化:每个子集一张图 ----------------------------------------
for sub, cap in captured.items():
if cap is None:
continue
raw_img = label_image(make_map_image(cap["raw"][0], tile_dict), "GT")
inp_img = label_image(make_map_image(cap["stage_input"][0], tile_dict), f"stage{stage} input")
# 真实 z 的迭代生成
init = make_stage_init(stage, cap["stage_input"][0].unsqueeze(0))
gen = maskgit_generate(
model_mg, cap["z_q"],
init_map=init, struct_cond=cap["struct_cond"],
)
gen_img = label_image(make_map_image(gen[0], tile_dict), "z_real gen")
# 随机 z 的生成
rand_imgs = []
for i in range(n_rand):
z_r = enc.sample(1, device)
sc_r = make_random_struct_cond()
init2 = make_stage_init(stage, cap["raw"][0].unsqueeze(0))
gen_r = maskgit_generate(model_mg, z_r, init_map=init2, struct_cond=sc_r)
rand_imgs.append(label_image(make_map_image(gen_r[0], tile_dict), f"z_rand_{i+1}"))
row = [raw_img, inp_img, gen_img] + rand_imgs
cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row))
# ---- 场景:完全自主生成 -----------------------------------------------
# stage1从随机稀疏墙壁种子出发完全不依赖 GT
# stage2以验证集中采样的 floor/wall 结构为上下文,随机 z₂模拟级联推理
# stage3以验证集中采样的完整功能地图为上下文随机 z₃模拟级联推理
context_pool = [cap["raw"][0] for cap in captured.values() if cap is not None]
rand_free = []
for i in range(n_rand + 1):
z_r = enc.sample(1, device)
sc_r = make_random_struct_cond()
if stage == 1:
# 稀疏 wall 种子作为提示,模型自主补全 floor/wall
init = make_random_wall_seed()
else:
# 从验证集上下文池中轮流取一张图作为前序阶段的输出
ctx = context_pool[i % len(context_pool)].unsqueeze(0)
# make_stage_init 会自动将本阶段负责的 tile 位置替换为 MASK
init = make_stage_init(stage, ctx)
gen = maskgit_generate(model_mg, z_r, init_map=init, struct_cond=sc_r)
rand_free.append(label_image(make_map_image(gen[0], tile_dict), f"free_{i+1}"))
cv2.imwrite(f"{epoch_dir}/scene_free_random.png", grid_images(rand_free))
return val_loss_total / max(val_steps, 1)
# ---------------------------------------------------------------------------
# 主训练函数
# ---------------------------------------------------------------------------
def train():
print(f"Using device: {device}")
args = parse_arguments()
stage = args.stage
result_dir = f"result/stage{stage}"
result_img_dir = f"result/stage{stage}_img"
os.makedirs(result_dir, exist_ok=True)
os.makedirs(result_img_dir, exist_ok=True)
# ---- VQ 编码器(单路)----
mg_cfg = STAGE_MG_CONFIGS[stage]
enc = GinkaVQVAE(
num_classes=NUM_CLASSES,
L=VQ_L,
K=VQ_K,
d_z=VQ_D_Z,
d_model=VQ_D_MODEL,
nhead=VQ_NHEAD,
num_layers=VQ_LAYERS,
dim_ff=VQ_DIM_FF,
map_size=MAP_SIZE,
beta=VQ_BETA,
gamma=VQ_GAMMA,
).to(device)
model_mg = GinkaMaskGIT(
num_classes=NUM_CLASSES,
d_model=mg_cfg["d_model"],
d_z=VQ_D_Z,
dim_ff=mg_cfg["dim_ff"],
nhead=mg_cfg["nhead"],
num_layers=mg_cfg["num_layers"],
map_size=MAP_SIZE,
z_dropout=MG_Z_DROPOUT,
struct_dropout=MG_STRUCT_DROPOUT,
).to(device)
enc_params = sum(p.numel() for p in enc.parameters())
mg_params = sum(p.numel() for p in model_mg.parameters())
print(f"[Stage {stage}] VQ Encoder 参数量: {enc_params:,} ({enc_params/1e6:.3f}M)")
print(f"[Stage {stage}] MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)")
# ---- 数据集 ----
dataset_train = GinkaStageDataset(
args.train,
stage=stage,
subset_weights=SUBSET_WEIGHTS,
wall_mask_ratio=WALL_MASK_RATIO,
)
dataset_val = GinkaStageDataset(
args.validate,
stage=stage,
subset_weights=SUBSET_WEIGHTS,
room_thresholds=dataset_train.room_th,
branch_thresholds=dataset_train.branch_th,
wall_mask_ratio=WALL_MASK_RATIO,
)
dataloader_train = DataLoader(
dataset_train,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
pin_memory=(device.type == "cuda"),
)
dataloader_val = DataLoader(
dataset_val,
batch_size=8,
shuffle=True,
num_workers=0,
)
# ---- 优化器 ----
all_params = list(enc.parameters()) + list(model_mg.parameters())
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6,
)
# ---- 权重加载 ----
start_epoch = 0
if args.pretrain_vq:
# 从 train_vq.py 的联合训练检查点加载对应通道的 VQ 编码器
ckpt = torch.load(args.pretrain_vq, map_location=device)
enc_key = f"enc{stage}"
if enc_key in ckpt:
enc.load_state_dict(ckpt[enc_key], strict=False)
print(f"已从 {args.pretrain_vq} 加载 {enc_key} 权重。")
else:
print(f"警告:检查点中未找到 {enc_key},跳过权重加载。")
if args.resume:
state_path = args.state or f"{result_dir}/stage{stage}-latest.pth"
ckpt = torch.load(state_path, map_location=device)
enc.load_state_dict(ckpt["enc"], strict=False)
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
if args.load_optim and ckpt.get("optim_state") is not None:
optimizer.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt.get("epoch", 0)
print(f"从 epoch {start_epoch} 接续训练。")
# ---- tile 贴图 ----
tile_dict = {}
for f in os.listdir("tiles"):
name = os.path.splitext(f)[0]
img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED)
if img is not None:
tile_dict[name] = img
# ---- 冻结 VQ 编码器(可选)----
if args.freeze_vq:
for p in enc.parameters():
p.requires_grad_(False)
print(f"[Stage {stage}] VQ 编码器已冻结。")
# ---- 训练循环 ----
for epoch in tqdm(
range(start_epoch, start_epoch + args.epochs),
desc=f"Stage{stage} Training",
disable=disable_tqdm,
):
enc.train()
model_mg.train()
loss_total = 0.0
ce_total = 0.0
vq_loss_total = 0.0
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
# 按 tile 统计召回率(用于监控各类 tile 的预测准确性)
tile_correct = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
tile_total = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
for batch in tqdm(
dataloader_train,
leave=False,
desc="Epoch Progress",
disable=disable_tqdm,
):
raw_map = batch["raw_map"].to(device)
vq_slice = batch["vq_slice"].to(device)
stage_input = batch["stage_input"].to(device)
target_map = batch["target_map"].to(device)
loss_mask = batch["loss_mask"].to(device)
struct_cond = batch["struct_cond"].to(device)
for s in batch["subset"]:
subset_stats[s] = subset_stats.get(s, 0) + 1
# ---- 前向传播 ----
z_q, _, _, vq_loss, commit_loss, entropy_loss = enc(vq_slice)
logits = model_mg(stage_input, z_q, struct_cond=struct_cond) # [B, S, C]
# ---- 仅对本阶段 tile 位置计算 focal loss ----
ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
loss_cfg = STAGE_LOSS_CONFIG[stage]
loss = loss_cfg["ce_weight"] * ce_loss + loss_cfg["vq_weight"] * vq_loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
optimizer.step()
loss_total += loss.detach().item()
ce_total += ce_loss.detach().item()
vq_loss_total += vq_loss.detach().item()
# ---- 分 tile 召回率统计 ----
with torch.no_grad():
preds = logits.argmax(dim=-1) # [B, S]
for tid in STAGE_TILE_SETS[stage]:
gt_mask = (target_map == tid) & loss_mask
tile_total[tid] += gt_mask.sum().item()
tile_correct[tid] += (preds[gt_mask] == tid).sum().item()
scheduler.step()
n = len(dataloader_train)
recall_str = " ".join(
f"{STAGE_TILE_SETS[stage][tid]}={tile_correct[tid]/(tile_total[tid]+1e-6):.2%}"
for tid in STAGE_TILE_SETS[stage]
)
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"Epoch {epoch + 1:4d} | "
f"Loss {loss_total/n:.5f} "
f"Focal {ce_total/n:.5f} "
f"VQ {vq_loss_total/n:.5f} | "
f"Recall: {recall_str} | "
f"LR {scheduler.get_last_lr()[0]:.6f} | "
f"Subsets {subset_stats}"
)
# ---- 检查点 + 验证 ----
if (epoch + 1) % args.checkpoint == 0:
ckpt_path = f"{result_dir}/stage{stage}-{epoch + 1}.pth"
torch.save({
"epoch": epoch + 1,
"stage": stage,
"enc": enc.state_dict(),
"mg_state": model_mg.state_dict(),
"optim_state": optimizer.state_dict(),
}, ckpt_path)
tqdm.write(f" 检查点已保存: {ckpt_path}")
val_loss = validate(stage, enc, model_mg, dataloader_val, tile_dict, epoch + 1)
tqdm.write(f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}")
enc.train()
model_mg.train()
# ---- 最终存档 ----
torch.save({
"epoch": start_epoch + args.epochs,
"stage": stage,
"enc": enc.state_dict(),
"mg_state": model_mg.state_dict(),
}, f"{result_dir}/stage{stage}_final.pth")
print(f"[Stage {stage}] 训练结束。")
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.set_num_threads(4)
train()