mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-13 20:32:44 +08:00
feat: 分三阶段训练
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
3b86a78e0b
commit
3676958781
499
docs/three-stage-generation-design.md
Normal file
@ -0,0 +1,499 @@
|
||||
# 三阶段级联地图生成设计文档
|
||||
|
||||
## 背景与问题诊断
|
||||
|
||||
### 当前问题:墙壁过度生成
|
||||
|
||||
现有单模型方案(VQ-VAE + MaskGIT 联合训练)在实践中呈现出**模型严重偏向生成墙壁**的现象,具体表现为:
|
||||
|
||||
- 生成地图中墙壁密度显著偏高,可通行空间不足;
|
||||
- 门、怪物、入口等功能性元素稀少甚至缺失;
|
||||
- 资源类 tile 几乎不出现。
|
||||
|
||||
### 根本原因分析
|
||||
|
||||
从 token 分布角度来看,13×13 = 169 个格子中,各类 tile 的分布严重不均:
|
||||
|
||||
| tile 类型 | 典型比例 | 训练信号 |
|
||||
| ------------- | -------- | -------- |
|
||||
| 墙壁(wall) | 40%~60% | 强、密集 |
|
||||
| 空地(floor) | 20%~40% | 强、密集 |
|
||||
| 门(door) | 1%~5% | 极弱 |
|
||||
| 怪物 | 5%~15% | 弱 |
|
||||
| 入口 | 1%~3% | 极弱 |
|
||||
| 资源 | 5%~15% | 弱 |
|
||||
|
||||
单模型将以上所有类别的 **交叉熵损失**混合在一起优化。由于墙壁和空地占据绝大多数 token,模型可以通过反复预测墙壁/空地获得低训练损失,而无需真正学习如何放置稀有类别。
|
||||
|
||||
这是**类别不均衡问题的结构性体现**:即使引入 Focal Loss,调整 loss 权重,也难以从根本上解决三类任务学习难度不匹配的问题——因为结构骨架(floor/wall)是其他所有元素放置的前提,混合训练会导致模型困于局部最优解("把所有位置都预测为墙壁")。
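
一个简化的数值示例可以直观说明这一点(以下比例为假设值,仅用于说明,不对应真实数据集统计):

```python
# 假设的 tile 占比:若模型退化为"只在墙壁/空地之间选择"且对这两类基本预测正确,
# 整体 token 准确率依然很高,而稀有类别的召回率为 0
ratio = {"wall": 0.50, "floor": 0.30, "door": 0.03,
         "monster": 0.10, "entrance": 0.02, "resource": 0.05}

degenerate_acc = ratio["wall"] + ratio["floor"]   # ≈ 0.80,训练损失已经很低
rare_recall = 0.0                                  # door/monster/entrance/resource 全部漏检
print(f"退化策略 token 准确率 ≈ {degenerate_acc:.0%},稀有类别召回率 = {rare_recall:.0%}")
```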
|
||||
|
||||
---
|
||||
|
||||
## 核心思路:三阶段级联生成
|
||||
|
||||
将单次全类别生成拆分为**三个独立的 MaskGIT 阶段**,每阶段只负责一组语义相近、结构约束相似的 tile 类别。后续阶段以前序阶段的输出作为已知上下文。
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ 阶段一:结构骨架生成 │
|
||||
│ │
|
||||
│ 全 MASK ──► [Stage1-MaskGIT + z₁] ──► floor/wall 地图 │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼ 已知 floor/wall 上下文
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ 阶段二:功能元素放置 │
|
||||
│ │
|
||||
│ floor/wall + MASK ──► [Stage2-MaskGIT + z₂] ──► door/monster/入口│
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼ 已知 floor/wall/door/monster/入口 上下文
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ 阶段三:资源放置 │
|
||||
│ │
|
||||
│ 完整上下文 + MASK ──► [Stage3-MaskGIT + z₃] ──► resource │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 各阶段职责
|
||||
|
||||
| 阶段 | 负责类别 | tile 数 | 类别数(含 MASK) | 结构约束强度 |
|
||||
| ------ | -------------------------------- | ------- | ----------------- | ------------ |
|
||||
| 阶段一 | floor(0)、wall(1) | 2 | 3 | 极强 |
|
||||
| 阶段二 | door(2)、monster(4)、entrance(5) | 3 | 6 | 强 |
|
||||
| 阶段三 | resource(3) | 1 | 3 | 弱 |
|
||||
|
||||
注:各阶段模型的词表不需要只含本阶段 tile,完整词表(7 类)可以保留不变,只是**被 MASK 的位置**和**计算 Loss 的位置**会有所不同(见训练策略)。
|
||||
|
||||
---
|
||||
|
||||
## 每阶段 tile 映射与 MASK 策略
|
||||
|
||||
### 阶段一
|
||||
|
||||
**输入**:所有位置均填充 MASK token(`tile=6`)。
|
||||
|
||||
**目标**:预测 floor(0) 和 wall(1);原始地图中所有非 floor/wall 的 tile 在训练目标中被**重映射为 floor(0)**,因为阶段一模型不关心功能性元素的具体种类。
|
||||
|
||||
```python
|
||||
STAGE1_REMAP = {
|
||||
0: 0, # floor → floor
|
||||
1: 1, # wall → wall
|
||||
2: 0, # door → floor(视作空地,让阶段二填充)
|
||||
3: 0, # resource → floor
|
||||
4: 0, # monster → floor
|
||||
5: 0, # entrance → floor
|
||||
6: 6, # MASK → MASK(不参与损失计算)
|
||||
}
|
||||
```
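
按上表做重映射时,可以用查表数组一次性完成(以下为示意写法,假设 `gt_map` 为取值 0~6 的整型数组):

```python
import numpy as np

# 将 STAGE1_REMAP 转为查表数组,对整张地图做向量化重映射
REMAP_TABLE = np.array([STAGE1_REMAP[i] for i in range(7)], dtype=np.int64)

def stage1_remap(gt_map: np.ndarray) -> np.ndarray:
    """返回阶段一训练目标:非 floor/wall 的 tile 全部映射为 floor"""
    return REMAP_TABLE[gt_map]
```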
|
||||
|
||||
**Loss 计算范围**:全部 169 个位置(输入全为 MASK,每个位置都需要预测,损失覆盖全图)。
|
||||
|
||||
### 阶段二
|
||||
|
||||
**输入**:将阶段一输出的 floor/wall 地图作为固定上下文,原始地图中属于 door/monster/entrance 的位置替换为 MASK token,其余(floor/wall)位置保持不变。
|
||||
|
||||
```python
|
||||
STAGE2_MASK_IDS = {2, 4, 5} # 需要在阶段二中被预测的 tile ID
|
||||
```
|
||||
|
||||
**训练时输入构造**(使用 GT 地图中的 floor/wall 作为上下文,不使用阶段一的实际输出):
|
||||
|
||||
```python
|
||||
def make_stage2_input(gt_map: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
gt_map: [H*W] 整数数组,包含完整原始地图
|
||||
返回: stage2 输入地图,door/monster/entrance 位置替换为 MASK
|
||||
"""
|
||||
inp = gt_map.copy()
|
||||
# 先将所有资源归一化为 floor(阶段二不负责资源)
|
||||
inp[np.isin(inp, [3])] = 0
|
||||
# 将 door/monster/entrance 位置 MASK 掉
|
||||
inp[np.isin(inp, [2, 4, 5])] = 6 # MASK
|
||||
return inp
|
||||
```
|
||||
|
||||
**目标**:预测被 MASK 位置应放置的类别——door(2)、monster(4)、entrance(5),或保持为 floor(0)(表示该位置不放置功能元素)。
|
||||
|
||||
**Loss 计算范围**:仅对输入为 MASK(即原本是 door/monster/entrance 的位置)计算损失,floor/wall 位置不参与损失(它们已经确定)。
|
||||
|
||||
### 阶段三
|
||||
|
||||
**输入**:将阶段二输出的完整功能性地图(floor/wall/door/monster/entrance)作为固定上下文,原始地图中资源(tile=3)位置替换为 MASK token。
|
||||
|
||||
```python
|
||||
def make_stage3_input(gt_map: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
gt_map: [H*W] 整数数组,包含完整原始地图
|
||||
返回: stage3 输入地图,resource 位置替换为 MASK
|
||||
"""
|
||||
inp = gt_map.copy()
|
||||
inp[inp == 3] = 6 # resource → MASK
|
||||
return inp
|
||||
```
|
||||
|
||||
**Loss 计算范围**:仅对输入为 MASK(即原本是 resource 的位置)计算损失。
|
||||
|
||||
---
|
||||
|
||||
## 模型架构
|
||||
|
||||
### 共享基础架构
|
||||
|
||||
三个阶段均采用与现有方案**相同的 MaskGIT + VQ-VAE 架构**,不需要设计新的模型结构。核心差异在于:
|
||||
|
||||
1. **输入词表**:与现有方案一致,均使用 7 类词表(`NUM_CLASSES=7`,`MASK_TOKEN=6`);
|
||||
2. **z 来源**:每阶段使用来自对应 VQ 通道的隐变量;
|
||||
3. **Loss 掩码**:只对该阶段负责的位置计算 CE Loss。
|
||||
|
||||
### VQ-VAE z 的阶段分配
|
||||
|
||||
现有 VQ-VAE 已采用三通道设计(CH1/CH2/CH3),与三个生成阶段自然对应:
|
||||
|
||||
| 通道 | 编码目标 | 供应阶段 | 描述 |
|
||||
| ---- | ----------------- | -------- | ---------------------------- |
|
||||
| CH1 | floor/wall 骨架 | 阶段一 | z₁,控制墙壁结构多样性 |
|
||||
| CH2 | door/monster/入口 | 阶段二 | z₂,控制功能元素布局多样性 |
|
||||
| CH3 | resource 分布 | 阶段三 | z₃,控制资源密度与位置多样性 |
|
||||
|
||||
```
|
||||
真实地图 ──► CH1 encoder ──► z₁ ──► Stage1 MaskGIT
|
||||
──► CH2 encoder ──► z₂ ──► Stage2 MaskGIT
|
||||
──► CH3 encoder ──► z₃ ──► Stage3 MaskGIT
|
||||
```
|
||||
|
||||
**通道专属编码**:各通道编码器在编码前只"看"与自身相关的 tile,其余 tile 视作 floor(0):
|
||||
|
||||
```python
|
||||
def ch1_mask(gt_map):
|
||||
"""只保留 floor/wall,其余置 0"""
|
||||
m = gt_map.copy()
|
||||
m[~np.isin(m, [0, 1])] = 0
|
||||
return m
|
||||
|
||||
def ch2_mask(gt_map):
|
||||
"""只保留 door/monster/entrance,其余置 0"""
|
||||
m = gt_map.copy()
|
||||
m[~np.isin(m, [2, 4, 5])] = 0
|
||||
return m
|
||||
|
||||
def ch3_mask(gt_map):
|
||||
"""只保留 resource,其余置 0"""
|
||||
m = gt_map.copy()
|
||||
m[m != 3] = 0
|
||||
return m
|
||||
```
|
||||
|
||||
### Loss 计算掩码实现
|
||||
|
||||
训练时,每阶段额外接收一个 `loss_mask: torch.BoolTensor [B, H*W]`,指示哪些位置需要计算损失:
|
||||
|
||||
```python
|
||||
# 阶段一:所有位置(因为输入全为 MASK)
|
||||
loss_mask_s1 = torch.ones(B, MAP_SIZE, dtype=torch.bool)
|
||||
|
||||
# 阶段二:只有原本是 door/monster/entrance 的位置
|
||||
loss_mask_s2 = torch.isin(raw_map, torch.tensor([2, 4, 5]))
|
||||
|
||||
# 阶段三:只有原本是 resource 的位置
|
||||
loss_mask_s3 = (raw_map == 3)
|
||||
```
|
||||
|
||||
Focal Loss 修改为只对 `loss_mask` 为 True 的位置求和后做归一化:
|
||||
|
||||
```python
|
||||
def stage_focal_loss(logits, targets, loss_mask, gamma=2.0):
|
||||
# logits: [B, C, H*W], targets: [B, H*W], loss_mask: [B, H*W]
|
||||
per_token_loss = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W]
|
||||
masked_loss = per_token_loss[loss_mask]
|
||||
return masked_loss.mean() if masked_loss.numel() > 0 else per_token_loss.mean()
|
||||
```
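
调用方式示意(`logits`、`targets`、`raw_map` 的形状为假设的 batch 配置,以实际训练代码为准):

```python
# 阶段二的一次损失计算:logits [B, 7, 169],targets/raw_map [B, 169]
loss_mask_s2 = torch.isin(raw_map, torch.tensor([2, 4, 5], device=raw_map.device))
loss_s2 = stage_focal_loss(logits, targets, loss_mask_s2, gamma=2.0)
```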
|
||||
|
||||
---
|
||||
|
||||
## 训练策略
|
||||
|
||||
### 训练方式:顺序训练(推荐)
|
||||
|
||||
三个阶段**依次训练**,后续阶段训练时使用 GT 地图作为前序阶段的"完美输出"(teacher forcing),而非使用前序阶段模型的实际推理结果。
|
||||
|
||||
```
|
||||
训练阶段一:
|
||||
data: (全MASK输入, stage1目标)
|
||||
loss: focal loss on all positions
|
||||
|
||||
训练阶段二:
|
||||
data: (floor/wall上下文 + MASK输入, stage2目标)
|
||||
loss: focal loss only on door/monster/entrance positions
|
||||
|
||||
训练阶段三:
|
||||
data: (floor/wall/door/monster/入口上下文 + MASK输入, stage3目标)
|
||||
loss: focal loss only on resource positions
|
||||
```
|
||||
|
||||
**使用 GT 而非前序模型输出的理由**:
|
||||
|
||||
- 避免误差级联(前序模型若生成错误的骨架,后续模型的训练将在错误分布上进行);
|
||||
- 各阶段训练更稳定,收敛更快;
|
||||
- 阶段之间解耦,方便单独迭代和调试。
|
||||
|
||||
### 各阶段训练子集划分
|
||||
|
||||
每个阶段均沿用现有 A/B/C/D 子集划分逻辑,但 MASK 策略应用在对应阶段的目标 tile 上:
|
||||
|
||||
| 子集 | 阶段一 | 阶段二 | 阶段三 |
|
||||
| ---- | ----------------------------- | ------------------------------ | --------------------------- |
|
||||
| A | 随机遮盖部分 floor/wall | 随机遮盖部分 door/monster/入口 | 随机遮盖部分 resource |
|
||||
| B | 保留全部 wall,MASK floor | 给定全部 wall,MASK 功能元素 | 给定全部骨架,MASK 部分资源 |
|
||||
| C | 随机保留部分 wall,MASK 其余 | 同 B | 同 B |
|
||||
| D | 保留 wall+entrance,MASK 其余 | 给定 wall+entrance,MASK 门/怪 | 同 B |
|
||||
|
||||
### 各阶段专属损失
|
||||
|
||||
每阶段的总损失:
|
||||
|
||||
$$\mathcal{L}^{(s)} = \mathcal{L}_{CE}^{(s)} + \beta \cdot \mathcal{L}_{commit}^{(s)} + \gamma \cdot \mathcal{L}_{uniform}^{(s)}$$
|
||||
|
||||
其中 $s \in \{1, 2, 3\}$ 表示阶段编号,$\mathcal{L}_{CE}^{(s)}$ 只在该阶段负责的 tile 位置上计算。
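
对应到代码层面,单阶段总损失大致可按如下方式组合(`commit_loss`、`uniform_loss` 假设由 VQ-VAE 前向返回,β、γ 的取值以实际实现为准,此处仅为示意):

```python
def stage_total_loss(logits, targets, loss_mask, commit_loss, uniform_loss,
                     beta=0.5, gamma=0.0):
    """单阶段总损失示意:CE(仅本阶段位置)+ beta*commit + gamma*uniform"""
    ce = stage_focal_loss(logits, targets, loss_mask)
    return ce + beta * commit_loss + gamma * uniform_loss
```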
|
||||
|
||||
---
|
||||
|
||||
## 推理流程
|
||||
|
||||
### 完整推理管线
|
||||
|
||||
```
|
||||
1. 随机采样 z₁, z₂, z₃(各自从对应 codebook 均匀采样 L 个 index)
|
||||
|
||||
2. 阶段一推理(结构骨架)
|
||||
初始状态: 全部 169 个位置 = MASK token
|
||||
迭代 MaskGIT 解码(cosine schedule,约 18 步):
|
||||
输入: MASK地图 + z₁
|
||||
输出: 逐步填充 floor/wall,直到无 MASK 位置
|
||||
结果: floor/wall 骨架地图 M₁
|
||||
|
||||
3. 阶段二推理(功能元素放置)
|
||||
初始状态: 继承 M₁,在 floor 位置随机选取候选位置置为 MASK
|
||||
─── 或 ───
|
||||
初始状态: 继承 M₁,所有 floor 位置均置为 MASK(让模型决定放置密度)
|
||||
迭代 MaskGIT 解码:
|
||||
输入: 含已知 wall 的掩码地图 + z₂
|
||||
约束: wall 位置不参与 unmask,保持不变
|
||||
输出: 逐步填充 door/monster/entrance
|
||||
结果: 含功能元素的地图 M₂
|
||||
|
||||
4. 阶段三推理(资源放置)
|
||||
初始状态: 继承 M₂,所有 floor 位置(未被阶段二填充的)置为 MASK
|
||||
迭代 MaskGIT 解码:
|
||||
输入: 含已知 wall/door/monster/入口的掩码地图 + z₃
|
||||
约束: 非 floor 位置保持不变
|
||||
输出: 逐步填充 resource(或保持 floor)
|
||||
结果: 完整地图 M₃
|
||||
```
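
其中 cosine schedule 决定每一步之后仍保持 MASK 的 token 数量,下面是一个示意实现(与训练脚本中 `maskgit_generate` 的调度公式一致,步数与 token 数为示例值):

```python
import math

def cosine_keep_masked(step: int, total_steps: int, n_maskable: int) -> int:
    """第 step 步(从 0 计)结束后仍保持 MASK 的 token 数,按 cosine 衰减"""
    ratio = math.cos(((step + 1) / total_steps) * math.pi / 2)
    return int(ratio * n_maskable)

# 示例:169 个可生成位置、18 步时,每步剩余 MASK 数约为
# [168, 166, 163, ..., 29, 14, 0],最后一步全部揭开
schedule = [cosine_keep_masked(s, 18, 169) for s in range(18)]
```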
|
||||
|
||||
### 阶段二初始 MASK 策略
|
||||
|
||||
阶段二的初始状态有两种选择:
|
||||
|
||||
| 策略 | 描述 | 适用场景 |
|
||||
| ------------- | -------------------------------------------------- | ---------------- |
|
||||
| 全 floor MASK | 所有 floor 位置均置为 MASK,让模型自主决定放置密度 | 完全随机生成 |
|
||||
| 候选位置 MASK | 只 MASK 用户指定或随机抽取的少数位置 | 用户指定部分位置 |
|
||||
|
||||
对于完全随机生成场景,**推荐使用全 floor MASK**——此时 z₂ 决定功能元素的总体风格(密集/稀疏/集中/分散),模型负责在此约束下寻找合理位置。
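
"全 floor MASK"的初始地图构造非常简单(示意写法,假设 `m1` 为阶段一输出的骨架):

```python
import torch

def make_stage2_init_full_floor_mask(m1: torch.Tensor, mask_id: int = 6) -> torch.Tensor:
    """m1: [B, H*W] 阶段一输出的 floor/wall 骨架;把所有 floor 位置置为 MASK"""
    init = m1.clone()
    init[init == 0] = mask_id   # floor → MASK,wall 保持不变
    return init
```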
|
||||
|
||||
### 用户交互场景
|
||||
|
||||
| 场景 | 阶段一输入 | 阶段二输入 | 阶段三输入 |
|
||||
| ------------------- | -------------------- | ------------------------ | ---------- |
|
||||
| 完全随机 | 全 MASK | 继承 M₁ | 继承 M₂ |
|
||||
| 用户手绘墙壁 | 已知 wall + MASK | 继承 M₁(固定 wall) | 继承 M₂ |
|
||||
| 用户指定入口 | 已知 entrance + MASK | 继承 M₁(固定 entrance) | 继承 M₂ |
|
||||
| 用户手绘墙+指定入口 | 已知 wall/entrance | 继承 M₁ | 继承 M₂ |
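
以"用户手绘墙壁"场景为例,阶段一的初始地图可以按如下方式构造(示意写法,`user_wall_mask` 为假设的输入名):

```python
import torch

def make_stage1_init_with_user_walls(user_wall_mask: torch.Tensor,
                                     mask_id: int = 6) -> torch.Tensor:
    """user_wall_mask: [B, H*W] bool,True 表示用户指定为墙壁的位置。
    返回阶段一初始地图:用户墙壁固定为 wall(1),其余位置为 MASK,由模型补全。"""
    init = torch.full_like(user_wall_mask, mask_id, dtype=torch.long)
    init[user_wall_mask] = 1
    return init
```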
|
||||
|
||||
---
|
||||
|
||||
## 实现方案
|
||||
|
||||
### 方案一:三个独立 GinkaMaskGIT 实例(推荐)
|
||||
|
||||
每个阶段分别实例化一个 `GinkaMaskGIT`,共用同一个 `GinkaVQVAE`(其中已含三通道编码器)。各阶段模型独立训练、独立存储。
|
||||
|
||||
```python
|
||||
# 模型结构
|
||||
vqvae = GinkaVQVAE(...) # 三通道共享编码器
|
||||
stage1 = GinkaMaskGIT(...) # 结构骨架
|
||||
stage2 = GinkaMaskGIT(...) # 功能元素
|
||||
stage3 = GinkaMaskGIT(...) # 资源
|
||||
```
|
||||
|
||||
**优点**:
|
||||
|
||||
- 各阶段模型完全解耦,可独立调整超参、单独重训;
|
||||
- 模型大小可以针对任务难度调整(阶段一可以更大,阶段三可以更小);
|
||||
- 便于调试和增量式开发。
|
||||
|
||||
**缺点**:
|
||||
|
||||
- 三个模型分别保存,推理时需依次加载;
|
||||
- 若各阶段沿用原模型规模,总参数量约为原方案的 3 倍(可通过缩小各阶段模型来对冲,见下文规模建议)。
|
||||
|
||||
### 方案二:单模型 + 阶段 Embedding(备选)
|
||||
|
||||
复用同一个 `GinkaMaskGIT`,添加阶段 embedding(类似 BERT 的 segment embedding):
|
||||
|
||||
```python
|
||||
self.stage_embedding = nn.Embedding(3, d_model) # 三个阶段
|
||||
|
||||
def forward(self, map, z, stage: int):
|
||||
x = self.tile_embedding(map) + self.pos_embedding
|
||||
x = x + self.stage_embedding(torch.tensor(stage)) # 阶段条件
|
||||
...
|
||||
```
|
||||
|
||||
**优点**:参数量与原方案相同,推理时只需加载一个模型。
|
||||
**缺点**:三阶段共享所有权重,可能导致阶段间干扰;阶段一的结构任务与阶段三的稀疏资源任务表示空间差异大,单一模型难以同时擅长。
|
||||
|
||||
**结论**:**推荐方案一**,尤其是在当前阶段(验证多阶段框架可行性),可先分别训练三个轻量模型进行快速验证。
|
||||
|
||||
### 各阶段模型规模建议
|
||||
|
||||
| 阶段 | 任务难度 | 建议 d_model | 建议 num_layers | 参数量估算 |
|
||||
| ------ | ---------- | ------------ | --------------- | ---------- |
|
||||
| 阶段一 | 高(结构) | 256 | 6 | ~4M |
|
||||
| 阶段二 | 中(功能) | 192 | 4 | ~2M |
|
||||
| 阶段三 | 低(稀疏) | 128 | 3 | ~0.8M |
|
||||
|
||||
---
|
||||
|
||||
## Dataset 修改方案
|
||||
|
||||
### 新增 `GinkaStageDataset`
|
||||
|
||||
需要扩展 `dataset.py`,增加针对三阶段的 Dataset 类,或在 `GinkaVQDataset` 中添加 `stage` 参数:
|
||||
|
||||
```python
|
||||
class GinkaStageDataset(Dataset):
|
||||
"""
|
||||
三阶段级联训练专用 Dataset。
|
||||
|
||||
返回 dict:
|
||||
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
|
||||
stage_input: LongTensor [H*W] 当前阶段 MaskGIT 输入(含上下文 + MASK)
|
||||
target_map: LongTensor [H*W] CE loss ground truth(raw_map 按本阶段规则重映射后的结果)
|
||||
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
|
||||
subset: str 子集标识
|
||||
"""
|
||||
|
||||
STAGE1_TARGETS = {0, 1} # floor, wall
|
||||
STAGE2_TARGETS = {2, 4, 5} # door, monster, entrance
|
||||
STAGE3_TARGETS = {3} # resource
|
||||
```
|
||||
|
||||
### 数据构造函数
|
||||
|
||||
```python
|
||||
def make_stage1_sample(gt_map: np.ndarray, mask_id: int = 6):
|
||||
"""阶段一:全 MASK 输入,目标是 floor/wall(其余归一为 floor)"""
|
||||
stage_input = np.full_like(gt_map, mask_id)
|
||||
target = gt_map.copy()
|
||||
target[~np.isin(target, [0, 1])] = 0 # 非结构 tile → floor
|
||||
loss_mask = np.ones_like(gt_map, dtype=bool)
|
||||
return stage_input, target, loss_mask
|
||||
|
||||
def make_stage2_sample(gt_map: np.ndarray, mask_id: int = 6):
|
||||
"""阶段二:floor/wall 为上下文,door/monster/entrance 位置 MASK"""
|
||||
stage_input = gt_map.copy()
|
||||
stage_input[stage_input == 3] = 0 # 资源 → floor(阶段二不负责)
|
||||
target_ids = np.isin(stage_input, [2, 4, 5])
|
||||
stage_input[target_ids] = mask_id # 功能元素 → MASK
|
||||
target = gt_map.copy()
|
||||
target[target == 3] = 0 # target 中资源也视为 floor
|
||||
loss_mask = (gt_map != 0) & (gt_map != 1) & (gt_map != 3) # 只计算功能元素位置
|
||||
return stage_input, target, loss_mask
|
||||
|
||||
def make_stage3_sample(gt_map: np.ndarray, mask_id: int = 6):
|
||||
"""阶段三:全上下文保留,只 MASK 资源位置"""
|
||||
stage_input = gt_map.copy()
|
||||
stage_input[stage_input == 3] = mask_id # 资源 → MASK
|
||||
target = gt_map.copy()
|
||||
loss_mask = (gt_map == 3) # 只计算资源位置
|
||||
return stage_input, target, loss_mask
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 训练脚本设计
|
||||
|
||||
### 新增 `ginka/train_stage.py`
|
||||
|
||||
```
|
||||
用法示例:
|
||||
python -m ginka.train_stage --stage 1
|
||||
python -m ginka.train_stage --stage 2
|
||||
python -m ginka.train_stage --stage 3 --resume True --state result/stage3/stage3-10.pth
|
||||
```
|
||||
|
||||
各阶段检查点分别存储到:
|
||||
|
||||
- `result/stage1/stage1-{epoch}.pth`
|
||||
- `result/stage2/stage2-{epoch}.pth`
|
||||
- `result/stage3/stage3-{epoch}.pth`
|
||||
|
||||
### 阶段二/三的 VQ 编码器冻结策略
|
||||
|
||||
阶段二训练时,CH1 编码器(已在阶段一训练中充分收敛)可选择冻结,只更新 CH2 编码器和 Stage2 MaskGIT:
|
||||
|
||||
```python
|
||||
if args.stage == 2:
|
||||
for p in vqvae.encoder_ch1.parameters():
|
||||
p.requires_grad_(False) # 冻结 CH1
|
||||
```
|
||||
|
||||
类似地,阶段三训练时可冻结 CH1 和 CH2 编码器。
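
示意写法如下(`encoder_ch1`/`encoder_ch2` 为假设的属性名,以实际 `GinkaVQVAE` 实现为准):

```python
if args.stage == 3:
    for p in list(vqvae.encoder_ch1.parameters()) + list(vqvae.encoder_ch2.parameters()):
        p.requires_grad_(False)  # 冻结 CH1 与 CH2,只训练 CH3 编码器和 Stage3 MaskGIT
```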
|
||||
|
||||
---
|
||||
|
||||
## 与现有方案的对比
|
||||
|
||||
| 维度 | 现有单模型方案 | 三阶段级联方案 |
|
||||
| -------------- | ------------------------ | ---------------------------------------- |
|
||||
| 墙壁过度生成 | 存在,难以从根本解决 | 阶段一单独训练骨架,Loss 聚焦 floor/wall |
|
||||
| 训练信号均衡性 | 墙壁主导,稀有类欠拟合 | 各阶段 Loss 只计算本阶段 tile,信号均衡 |
|
||||
| 模型可调试性 | 单一模型,各类别相互干扰 | 各阶段独立,可单独分析每阶段表现 |
|
||||
| 推理速度 | 1 次完整 MaskGIT 解码 | 3 次级联解码(总步数约为原来 3 倍) |
|
||||
| 误差累积 | 无 | 存在,前序阶段错误会传播到后续阶段 |
|
||||
| 用户可控性 | 较难(条件混合) | 好(可在任意阶段注入用户约束) |
|
||||
| 参数量 | ~4M | ~7M(可通过缩减各阶段规模控制) |
|
||||
|
||||
---
|
||||
|
||||
## 预期收益
|
||||
|
||||
1. **解决墙壁过度生成**:阶段一专门针对 floor/wall 训练,类别分布从 7 类压缩到 3 类,Loss 完全聚焦,模型不再有逃避路径;
|
||||
2. **功能元素召回率提升**:阶段二以已知骨架为前提生成 door/monster/entrance,训练信号不再被墙壁噪声稀释;
|
||||
3. **资源分布更合理**:阶段三在完整上下文下放置资源,能感知到门/怪物位置,避免资源与关键功能元素重叠;
|
||||
4. **可交互性增强**:用户可在任意阶段注入约束(固定某些 tile),天然支持层次化编辑。
|
||||
|
||||
---
|
||||
|
||||
## 风险与应对
|
||||
|
||||
| 风险 | 描述 | 应对策略 |
|
||||
| ---------------- | ----------------------------------------- | -------------------------------------------------------- |
|
||||
| 误差累积 | 阶段一骨架不准确会导致阶段二/三布局失真 | 阶段一优先保证质量,推理时对骨架做后处理校验 |
|
||||
| 推理耗时增加 | 三次 MaskGIT 解码约 3 倍耗时 | 减少 MaskGIT 迭代步数(阶段二/三任务更简单,步数可减半) |
|
||||
| 阶段二稀疏性问题 | 169 格子中功能元素极少,Loss 计算覆盖率低 | 适当提高 stage 2 的 loss_mask 覆盖(周边 floor 也计入) |
|
||||
| Codebook 对齐 | 三通道 VQ-VAE 的 z 在分阶段训练时各自优化 | 联合训练阶段一+VQ-CH1,联合训练阶段二+VQ-CH2,以此类推 |
|
||||
|
||||
---
|
||||
|
||||
## 实施顺序
|
||||
|
||||
- [ ] 新增 `GinkaStageDataset`,实现三阶段数据构造函数
|
||||
- [ ] 新增 `ginka/train_stage.py`,支持 `--stage 1/2/3` 参数
|
||||
- [ ] 阶段一训练:仅 floor/wall,验证骨架生成质量(关键里程碑)
|
||||
- [ ] 阶段二训练:以 GT floor/wall 为上下文,验证功能元素召回率
|
||||
- [ ] 阶段三训练:以 GT 完整地图为上下文,验证资源放置合理性
|
||||
- [ ] 实现三阶段级联推理脚本,接入现有可视化工具
|
||||
- [ ] 对比实验:三阶段方案 vs 现有单模型方案在墙壁密度、功能元素召回率、资源分布等指标上的差异
|
||||
375
ginka/dataset.py
@ -4,14 +4,69 @@ import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
def _compute_map_labels(map_2d) -> dict:
|
||||
"""
|
||||
从 2D 地图列表(或 numpy 数组)推算结构标签。
|
||||
当 JSON 数据缺少 roomCount / highDegBranchCount / outerWall 字段时调用。
|
||||
"""
|
||||
arr = np.array(map_2d, dtype=np.int64) # [H, W]
|
||||
H, W = arr.shape
|
||||
WALL, ENTRY = 1, 5
|
||||
|
||||
# outerWall:最外圈中 wall+entry 占比 > 90%
|
||||
border = np.concatenate([arr[0, :], arr[-1, :], arr[1:-1, 0], arr[1:-1, -1]])
|
||||
total_b = border.size
|
||||
outer_wall = int(total_b > 0 and np.sum((border == WALL) | (border == ENTRY)) / total_b > 0.9)
|
||||
|
||||
# roomCount:BFS 统计 floor(0)+resource(3) 连通区域,
|
||||
# 需满足:总格子 >= 4,外接矩形宽 >= 2 且高 >= 2
|
||||
FLOOR_SET = (0, 3)
|
||||
visited = np.zeros((H, W), dtype=bool)
|
||||
room_count = 0
|
||||
for sy in range(H):
|
||||
for sx in range(W):
|
||||
if arr[sy, sx] not in FLOOR_SET or visited[sy, sx]:
|
||||
continue
|
||||
queue = [(sy, sx)]
|
||||
visited[sy, sx] = True
|
||||
tiles_y, tiles_x = [sy], [sx]
|
||||
head = 0
|
||||
while head < len(queue):
|
||||
y, x = queue[head]; head += 1
|
||||
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
|
||||
ny, nx = y + dy, x + dx
|
||||
if 0 <= ny < H and 0 <= nx < W and not visited[ny, nx] and arr[ny, nx] in FLOOR_SET:
|
||||
visited[ny, nx] = True
|
||||
queue.append((ny, nx))
|
||||
tiles_y.append(ny); tiles_x.append(nx)
|
||||
if (len(tiles_y) >= 4
|
||||
and max(tiles_y) - min(tiles_y) >= 1
|
||||
and max(tiles_x) - min(tiles_x) >= 1):
|
||||
room_count += 1
|
||||
|
||||
# highDegBranchCount:非 wall 格子中,4 邻域非 wall 邻居 >= 3 的数量
|
||||
non_wall = (arr != WALL).astype(np.int32)
|
||||
padded = np.pad(non_wall, 1, mode='constant', constant_values=0)
|
||||
nbr_sum = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
|
||||
padded[1:-1, :-2] + padded[1:-1, 2:])
|
||||
high_deg = int(np.sum((non_wall == 1) & (nbr_sum >= 3)))
|
||||
|
||||
return {'outerWall': outer_wall, 'roomCount': room_count, 'highDegBranchCount': high_deg}
|
||||
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
data_list = []
|
||||
for value in data["data"].values():
|
||||
# 兼容旧版数据集(缺少结构标签字段)
|
||||
if 'roomCount' not in value:
|
||||
labels = _compute_map_labels(value['map'])
|
||||
value.update(labels)
|
||||
# symmetry 字段由 __getitem__ 在增强后重新计算,此处不需要从 JSON 读取
|
||||
data_list.append(value)
|
||||
|
||||
|
||||
return data_list
|
||||
|
||||
def _compute_symmetry(target_np: np.ndarray) -> tuple:
|
||||
@ -326,6 +381,322 @@ class GinkaSplitDataset(Dataset):
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GinkaStageDataset:三阶段级联训练专用数据集
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GinkaStageDataset(Dataset):
|
||||
"""
|
||||
三阶段级联生成训练专用 Dataset。
|
||||
|
||||
每个阶段只预测特定类别的 tile,后续阶段以前序阶段输出作为上下文。
|
||||
训练时统一使用 GT 作为前序上下文(teacher forcing),避免误差级联。
|
||||
|
||||
阶段划分:
|
||||
stage=1 结构骨架:预测 floor(0) + wall(1)
|
||||
stage=2 功能元素:预测 door(2) + monster(4) + entrance(5),以 floor/wall 为上下文
|
||||
stage=3 资源放置:预测 resource(3),以完整骨架为上下文
|
||||
|
||||
返回 dict:
|
||||
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
|
||||
vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片
|
||||
stage_input: LongTensor [H*W] MaskGIT 输入(含上下文 + MASK 位置)
|
||||
target_map: LongTensor [H*W] CE loss ground truth
|
||||
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
|
||||
subset: str 子集标识 A/B/C/D
|
||||
struct_cond: LongTensor [4] [sym, room, branch, outer]
|
||||
"""
|
||||
|
||||
FLOOR = 0
|
||||
WALL = 1
|
||||
DOOR = 2
|
||||
RESOURCE = 3
|
||||
MONSTER = 4
|
||||
ENTRANCE = 5
|
||||
MASK_ID = 6
|
||||
|
||||
STAGE1_TARGETS = frozenset({0, 1})
|
||||
STAGE2_TARGETS = frozenset({2, 4, 5})
|
||||
STAGE3_TARGETS = frozenset({3})
|
||||
|
||||
# VQ 切片集合:各阶段编码器只"看"与自身相关的 tile
|
||||
_VQ_KEEP = {
|
||||
1: frozenset({0, 1}),
|
||||
2: frozenset({0, 1, 2, 4, 5}),
|
||||
3: None, # 完整地图
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
stage: int,
|
||||
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||||
wall_mask_ratio: float = 0.3,
|
||||
room_thresholds: tuple = None,
|
||||
branch_thresholds: tuple = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
data_path: JSON 数据文件路径
|
||||
stage: 生成阶段 1/2/3
|
||||
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
||||
wall_mask_ratio: Subset C 中额外随机 mask 的 wall 比例上限
|
||||
room_thresholds: 等频分箱阈值(None 时自动计算)
|
||||
branch_thresholds: 等频分箱阈值(None 时自动计算)
|
||||
"""
|
||||
assert stage in (1, 2, 3), f"stage 必须是 1/2/3,收到 {stage}"
|
||||
self.stage = stage
|
||||
self.data = load_data(data_path)
|
||||
self.wall_mask_ratio = wall_mask_ratio
|
||||
|
||||
total_w = sum(subset_weights)
|
||||
normalized = [x / total_w for x in subset_weights]
|
||||
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
|
||||
|
||||
room_counts = [item['roomCount'] for item in self.data]
|
||||
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||||
|
||||
if room_thresholds is None:
|
||||
n = len(room_counts)
|
||||
rs = sorted(room_counts)
|
||||
bs = sorted(branch_counts)
|
||||
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
|
||||
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
|
||||
if th1_r == th2_r: th2_r = th1_r + 1
|
||||
if th1_b == th2_b: th2_b = th1_b + 1
|
||||
self.room_th = (th1_r, th2_r)
|
||||
self.branch_th = (th1_b, th2_b)
|
||||
else:
|
||||
self.room_th = room_thresholds
|
||||
self.branch_th = branch_thresholds
|
||||
|
||||
def to_level(v, th):
|
||||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||
|
||||
for item in self.data:
|
||||
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||||
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 掩码辅助(与 GinkaVQDataset 相同逻辑)
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
|
||||
r = np.random.beta(2, 2)
|
||||
return min_r + (max_r - min_r) * r
|
||||
|
||||
@staticmethod
|
||||
def _random_mask(h: int, w: int) -> np.ndarray:
|
||||
ratio = GinkaStageDataset._sample_mask_ratio()
|
||||
total = h * w
|
||||
idx = np.random.choice(total, int(total * ratio), replace=False)
|
||||
mask = np.zeros(total, dtype=bool)
|
||||
mask[idx] = True
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def _block_mask(h: int, w: int) -> np.ndarray:
|
||||
ratio = GinkaStageDataset._sample_mask_ratio()
|
||||
max_block = max(2, min(h, w) // 2)
|
||||
target = int(h * w * ratio)
|
||||
mask = np.zeros((h, w), dtype=bool)
|
||||
while mask.sum() < target:
|
||||
bh = np.random.randint(2, max_block + 1)
|
||||
bw = np.random.randint(2, max_block + 1)
|
||||
x = np.random.randint(0, max(1, h - bh + 1))
|
||||
y = np.random.randint(0, max(1, w - bw + 1))
|
||||
mask[x:x + bh, y:y + bw] = True
|
||||
return mask.reshape(-1)
|
||||
|
||||
def _std_mask(self, h: int, w: int) -> np.ndarray:
|
||||
return self._random_mask(h, w) if random.random() < 0.5 else self._block_mask(h, w)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 子集选择
|
||||
# ------------------------------------------------------------------
|
||||
def _choose_subset(self) -> str:
|
||||
r = random.random()
|
||||
if r < self.subset_cumw[0]: return 'A'
|
||||
if r < self.subset_cumw[1]: return 'B'
|
||||
if r < self.subset_cumw[2]: return 'C'
|
||||
return 'D'
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶段一:结构骨架(floor + wall)
|
||||
# ------------------------------------------------------------------
|
||||
def _make_stage1(self, raw_flat: np.ndarray, subset: str):
|
||||
"""
|
||||
阶段一:预测 floor/wall,所有非 floor/wall tile 在目标中重映射为 floor。
|
||||
子集决定向模型提供多少 wall 作为上下文条件。
|
||||
"""
|
||||
H = W = 13
|
||||
|
||||
# 目标:非 floor/wall → floor
|
||||
target = raw_flat.copy()
|
||||
target[~np.isin(target, [self.FLOOR, self.WALL])] = self.FLOOR
|
||||
|
||||
inp = target.copy()
|
||||
|
||||
if subset == 'A':
|
||||
# 标准随机 mask:随机遮盖部分 floor/wall
|
||||
mask = self._std_mask(H, W)
|
||||
inp[mask] = self.MASK_ID
|
||||
|
||||
elif subset == 'B':
|
||||
# 保留全部 wall,MASK floor
|
||||
inp[inp == self.FLOOR] = self.MASK_ID
|
||||
|
||||
elif subset == 'C':
|
||||
# 随机保留部分 wall,MASK 其余(含全部 floor)
|
||||
inp[inp == self.FLOOR] = self.MASK_ID
|
||||
wall_idx = np.where(inp == self.WALL)[0]
|
||||
if len(wall_idx) > 0:
|
||||
ratio = random.random() * self.wall_mask_ratio
|
||||
n = max(1, int(len(wall_idx) * ratio))
|
||||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||||
inp[chosen] = self.MASK_ID
|
||||
|
||||
else: # D:与 B 相同(阶段一无 entrance 维度)
|
||||
inp[inp == self.FLOOR] = self.MASK_ID
|
||||
|
||||
loss_mask = (inp == self.MASK_ID)
|
||||
return inp, target, loss_mask
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶段二:功能元素(door + monster + entrance)
|
||||
# ------------------------------------------------------------------
|
||||
def _make_stage2(self, raw_flat: np.ndarray, subset: str):
|
||||
"""
|
||||
阶段二:以 floor/wall 为上下文,预测 door/monster/entrance。
|
||||
resource 在输入与目标中均视为 floor(阶段二不负责资源)。
|
||||
子集决定 wall 上下文的完整程度与 door/monster/entrance 的掩码方式。
|
||||
"""
|
||||
# 目标:resource → floor
|
||||
target = raw_flat.copy()
|
||||
target[target == self.RESOURCE] = self.FLOOR
|
||||
|
||||
# 基础输入:resource → floor,功能元素先保留,再按子集处理
|
||||
inp = raw_flat.copy()
|
||||
inp[inp == self.RESOURCE] = self.FLOOR
|
||||
|
||||
if subset == 'A':
|
||||
# 随机遮盖部分 door/monster/entrance(部分上下文补全)
|
||||
func_idx = np.where(np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE]))[0]
|
||||
if len(func_idx) > 0:
|
||||
ratio = random.random() * 0.8 + 0.2 # 20%~100%
|
||||
n = max(1, int(len(func_idx) * ratio))
|
||||
chosen = np.random.choice(func_idx, n, replace=False)
|
||||
inp[chosen] = self.MASK_ID
|
||||
else:
|
||||
# B/C/D:全部 door/monster/entrance → MASK
|
||||
inp[np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE])] = self.MASK_ID
|
||||
|
||||
if subset == 'C':
|
||||
# 额外随机 mask 部分 wall(降低 wall 上下文质量)
|
||||
wall_idx = np.where(inp == self.WALL)[0]
|
||||
if len(wall_idx) > 0:
|
||||
ratio = random.random() * self.wall_mask_ratio
|
||||
n = max(1, int(len(wall_idx) * ratio))
|
||||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||||
inp[chosen] = self.MASK_ID
|
||||
|
||||
# loss_mask:阶段二只对 door/monster/entrance 原始位置计算损失,
|
||||
# 不对被额外 mask 的 wall 位置计算(它们在 target 中已知为 wall)
|
||||
loss_mask = np.isin(raw_flat, [self.DOOR, self.MONSTER, self.ENTRANCE])
|
||||
return inp, target, loss_mask
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶段三:资源放置(resource)
|
||||
# ------------------------------------------------------------------
|
||||
def _make_stage3(self, raw_flat: np.ndarray, subset: str):
|
||||
"""
|
||||
阶段三:以完整骨架为上下文,预测 resource 位置。
|
||||
所有 resource 位置在输入中替换为 MASK。
|
||||
子集 A 随机保留部分 resource 作为上下文(部分补全训练),
|
||||
其余子集始终 MASK 全部 resource。
|
||||
"""
|
||||
target = raw_flat.copy()
|
||||
inp = raw_flat.copy()
|
||||
|
||||
if subset == 'A':
|
||||
# 随机遮盖部分 resource(部分上下文补全)
|
||||
res_idx = np.where(inp == self.RESOURCE)[0]
|
||||
if len(res_idx) > 0:
|
||||
ratio = random.random() * 0.8 + 0.2 # 20%~100%
|
||||
n = max(1, int(len(res_idx) * ratio))
|
||||
chosen = np.random.choice(res_idx, n, replace=False)
|
||||
inp[chosen] = self.MASK_ID
|
||||
else:
|
||||
pass # 无 resource 时无需处理
|
||||
else:
|
||||
# B/C/D:全部 resource → MASK
|
||||
inp[inp == self.RESOURCE] = self.MASK_ID
|
||||
|
||||
loss_mask = (inp == self.MASK_ID)
|
||||
return inp, target, loss_mask
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# __getitem__
|
||||
# ------------------------------------------------------------------
|
||||
def _augment(self, arr: np.ndarray) -> np.ndarray:
|
||||
if np.random.rand() > 0.5:
|
||||
k = np.random.randint(1, 4)
|
||||
arr = np.rot90(arr, k).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.fliplr(arr).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.flipud(arr).copy()
|
||||
return arr
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
|
||||
raw_flat = raw_np.reshape(-1) # [H*W]
|
||||
subset = self._choose_subset()
|
||||
|
||||
if self.stage == 1:
|
||||
stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset)
|
||||
elif self.stage == 2:
|
||||
stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset)
|
||||
else:
|
||||
stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset)
|
||||
|
||||
# 若 loss_mask 全为 False(如地图中无 resource 时的 stage3),
|
||||
# 退回为全图损失,避免 NaN
|
||||
if not loss_mask_np.any():
|
||||
loss_mask_np = np.ones_like(loss_mask_np)
|
||||
|
||||
# VQ 切片:当前阶段编码器的输入(仅保留相关 tile)
|
||||
raw_t = torch.LongTensor(raw_flat)
|
||||
vq_keep = self._VQ_KEEP[self.stage]
|
||||
if vq_keep is None:
|
||||
vq_slice = raw_t.clone()
|
||||
else:
|
||||
vq_slice = make_slice(raw_t, vq_keep)
|
||||
|
||||
# 结构标签
|
||||
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
||||
cond_room = item['roomCountLevel']
|
||||
cond_branch = item['branchLevel']
|
||||
cond_outer = item['outerWall']
|
||||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
return {
|
||||
"raw_map": raw_t,
|
||||
"vq_slice": vq_slice,
|
||||
"stage_input": torch.LongTensor(stage_input_np),
|
||||
"target_map": torch.LongTensor(target_np),
|
||||
"loss_mask": torch.BoolTensor(loss_mask_np),
|
||||
"subset": subset,
|
||||
"struct_cond": struct_cond,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
|
||||
|
||||
ginka/maskGIT/model.py
@ -57,8 +57,12 @@ class GinkaMaskGIT(nn.Module):
|
||||
|
||||
# z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用
|
||||
self.z_proj = nn.Sequential(
|
||||
-    nn.Linear(d_z, d_model),
-    nn.LayerNorm(d_model),
+    nn.Linear(d_z, d_model * 2),
+    nn.LayerNorm(d_model * 2),
+    nn.GELU(),
+
+    nn.Linear(d_model * 2, d_model),
+    nn.LayerNorm(d_model)
|
||||
)
|
||||
|
||||
# 结构标签嵌入(编码到 d_z 维度)
|
||||
@ -69,8 +73,12 @@ class GinkaMaskGIT(nn.Module):
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
self.struct_proj = nn.Sequential(
|
||||
-    nn.Linear(d_z, d_model),
-    nn.LayerNorm(d_model),
+    nn.Linear(d_z, d_model * 2),
+    nn.LayerNorm(d_model * 2),
+    nn.GELU(),
+
+    nn.Linear(d_model * 2, d_model),
+    nn.LayerNorm(d_model)
|
||||
)
|
||||
|
||||
# Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention
|
||||
|
||||
654
ginka/train_stage.py
Normal file
@ -0,0 +1,654 @@
|
||||
"""
|
||||
三阶段级联训练脚本:各阶段独立训练,使用 GinkaStageDataset。
|
||||
|
||||
总损失 = L_CE(只对本阶段负责的 tile 位置计算)+ beta * L_commit + gamma * L_entropy
|
||||
|
||||
各阶段分工:
|
||||
stage=1 结构骨架:floor(0) + wall(1)
|
||||
stage=2 功能元素:door(2) + monster(4) + entrance(5)
|
||||
stage=3 资源放置:resource(3)
|
||||
|
||||
用法示例:
|
||||
python -m ginka.train_stage --stage 1
|
||||
python -m ginka.train_stage --stage 2
|
||||
python -m ginka.train_stage --stage 3
|
||||
python -m ginka.train_stage --stage 1 --resume True --state result/stage1/stage1-10.pth
|
||||
python -m ginka.train_stage --stage 2 --pretrain_vq result/joint/joint-50.pth
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .vqvae.model import GinkaVQVAE
|
||||
from .maskGIT.model import GinkaMaskGIT
|
||||
from .dataset import GinkaStageDataset
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 各阶段配置
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 共用 VQ-VAE 超参
|
||||
VQ_L = 2 # 码字序列长度
|
||||
VQ_K = 8 # codebook 大小
|
||||
VQ_D_Z = 64 # 码字维度
|
||||
VQ_BETA = 0.5 # commit loss 权重
|
||||
VQ_GAMMA = 0.0 # entropy loss 权重
|
||||
VQ_LAYERS = 3
|
||||
VQ_DIM_FF = 512
|
||||
VQ_D_MODEL = 64
|
||||
VQ_NHEAD = 8
|
||||
|
||||
# 各阶段 MaskGIT 超参(按任务复杂度差异化配置)
|
||||
STAGE_MG_CONFIGS = {
|
||||
1: dict(d_model=256, nhead=8, num_layers=6, dim_ff=2048), # 结构骨架,最重要
|
||||
2: dict(d_model=192, nhead=8, num_layers=4, dim_ff=1024), # 功能元素
|
||||
3: dict(d_model=128, nhead=8, num_layers=3, dim_ff=512), # 资源放置,最简单
|
||||
}
|
||||
|
||||
# 各阶段监控的 tile 集合(用于分类别召回率统计)
|
||||
STAGE_TILE_SETS = {
|
||||
1: {0: "floor", 1: "wall"},
|
||||
2: {2: "door", 4: "monster", 5: "entrance"},
|
||||
3: {3: "resource"},
|
||||
}
|
||||
|
||||
# 各阶段损失权重(可单独调节 CE 与 VQ 损失的平衡)
|
||||
# stage3 的 resource 极稀疏,大幅上调 ce_weight 以补偿类别不均衡
|
||||
STAGE_LOSS_CONFIG = {
|
||||
1: dict(ce_weight=1.0, vq_weight=1.0), # 结构骨架,标准权重
|
||||
2: dict(ce_weight=1.5, vq_weight=0.5), # 功能元素较稀疏,上调 CE
|
||||
3: dict(ce_weight=3.0, vq_weight=0.5), # resource 极稀疏,显著上调 CE
|
||||
}
|
||||
|
||||
NUM_CLASSES = 7
|
||||
MASK_TOKEN = 6
|
||||
MAP_SIZE = 13 * 13
|
||||
MAP_H = MAP_W = 13
|
||||
FOCAL_GAMMA = 2.0
|
||||
GENERATE_STEP = 18
|
||||
BATCH_SIZE = 64
|
||||
WALL_MASK_RATIO = 0.8
|
||||
|
||||
MG_Z_DROPOUT = 0.1
|
||||
MG_STRUCT_DROPOUT = 0.1
|
||||
|
||||
SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1)
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _str2bool(v):
|
||||
if isinstance(v, bool): return v
|
||||
if v.lower() in ('true', '1', 'yes'): return True
|
||||
if v.lower() in ('false', '0', 'no'): return False
|
||||
raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="三阶段级联训练")
|
||||
parser.add_argument("--stage", type=int, required=True, choices=[1, 2, 3])
|
||||
parser.add_argument("--resume", type=_str2bool, default=False)
|
||||
parser.add_argument(
|
||||
"--state", type=str, default="",
|
||||
help="续训时加载的检查点路径(自动推断 stage{N}/stage{N}-*.pth)",
|
||||
)
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--checkpoint", type=int, default=5)
|
||||
parser.add_argument("--load_optim", type=_str2bool, default=True)
|
||||
parser.add_argument(
|
||||
"--freeze_vq", type=_str2bool, default=False,
|
||||
help="冻结 VQ 编码器,只训练 MaskGIT(适合加载预训练编码器后热身)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrain_vq", type=str, default="",
|
||||
help="从 train_vq.py 的联合训练检查点中导入对应通道的 VQ 编码器权重",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Focal Loss(与 train_vq.py 一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def focal_loss(logits, targets, gamma=FOCAL_GAMMA, reduction='none'):
|
||||
ce = F.cross_entropy(logits, targets, reduction='none')
|
||||
pt = torch.exp(-ce)
|
||||
fl = (1.0 - pt) ** gamma * ce
|
||||
if reduction == 'mean': return fl.mean()
|
||||
if reduction == 'sum': return fl.sum()
|
||||
return fl
|
||||
|
||||
|
||||
def masked_focal_loss(logits, targets, loss_mask, gamma=FOCAL_GAMMA):
|
||||
"""
|
||||
只对 loss_mask 为 True 的位置计算 focal loss 均值。
|
||||
|
||||
Args:
|
||||
logits: [B, C, H*W]
|
||||
targets: [B, H*W]
|
||||
loss_mask: [B, H*W] bool
|
||||
"""
|
||||
per_token = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W]
|
||||
selected = per_token[loss_mask]
|
||||
if selected.numel() == 0:
|
||||
return per_token.mean()
|
||||
return selected.mean()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskGIT 推理(cosine schedule)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def maskgit_generate(
|
||||
model_mg: GinkaMaskGIT,
|
||||
z: torch.Tensor,
|
||||
steps: int = GENERATE_STEP,
|
||||
init_map: torch.Tensor = None,
|
||||
struct_cond: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
迭代生成地图(cosine schedule unmasking)。
|
||||
|
||||
Args:
|
||||
init_map: 可选初始地图;非 MASK 位置在生成中保持不变。
|
||||
|
||||
Returns:
|
||||
[B, MAP_SIZE] LongTensor
|
||||
"""
|
||||
B = z.shape[0]
|
||||
map_seq = (
|
||||
torch.full((B, MAP_SIZE), MASK_TOKEN, device=device)
|
||||
if init_map is None else init_map.clone().to(device)
|
||||
)
|
||||
|
||||
generatable = (map_seq == MASK_TOKEN)
|
||||
|
||||
for step in range(steps):
|
||||
if not generatable.any():
|
||||
break
|
||||
|
||||
logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
dist = torch.distributions.Categorical(probs)
|
||||
sampled = dist.sample()
|
||||
confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
|
||||
confidences = confidences.masked_fill(~generatable, float('inf'))
|
||||
|
||||
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
|
||||
new_map = map_seq.clone()
|
||||
|
||||
for b in range(B):
|
||||
n_gen = int(generatable[b].sum().item())
|
||||
n_keep = int(ratio * n_gen)
|
||||
if n_keep > 0:
|
||||
_, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False)
|
||||
pred_b = sampled[b].clone()
|
||||
pred_b[keep_idx] = MASK_TOKEN
|
||||
new_map[b] = torch.where(generatable[b], pred_b, map_seq[b])
|
||||
else:
|
||||
new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b])
|
||||
|
||||
map_seq = new_map
|
||||
|
||||
return map_seq
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 可视化工具(与 train_vq.py 保持一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_map_image(map_flat, tile_dict):
|
||||
arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
return matrix_to_image_cv(arr, tile_dict)
|
||||
|
||||
|
||||
def hstack_images(imgs, gap=4, color=(255, 255, 255)):
|
||||
max_h = max(img.shape[0] for img in imgs)
|
||||
|
||||
def _pad(img):
|
||||
dh = max_h - img.shape[0]
|
||||
return img if dh == 0 else np.concatenate(
|
||||
[img, np.full((dh, img.shape[1], 3), color, dtype=np.uint8)], axis=0)
|
||||
|
||||
vline = np.full((max_h, gap, 3), color, dtype=np.uint8)
|
||||
result = _pad(imgs[0])
|
||||
for img in imgs[1:]:
|
||||
result = np.concatenate([result, vline, _pad(img)], axis=1)
|
||||
return result
|
||||
|
||||
|
||||
def grid_images(imgs, gap=4, bg=(255, 255, 255)):
|
||||
n = len(imgs)
|
||||
if n == 0: return np.zeros((1, 1, 3), dtype=np.uint8)
|
||||
if n == 1: return imgs[0]
|
||||
mid = math.ceil(n / 2)
|
||||
top = hstack_images(imgs[:mid], gap, bg)
|
||||
bot_imgs = imgs[mid:]
|
||||
if not bot_imgs: return top
|
||||
bot = hstack_images(bot_imgs, gap, bg)
|
||||
tw, bw = top.shape[1], bot.shape[1]
|
||||
if tw > bw:
|
||||
bot = np.concatenate(
|
||||
[bot, np.full((bot.shape[0], tw - bw, 3), bg, dtype=np.uint8)], axis=1)
|
||||
elif bw > tw:
|
||||
top = np.concatenate(
|
||||
[top, np.full((top.shape[0], bw - tw, 3), bg, dtype=np.uint8)], axis=1)
|
||||
hline = np.full((gap, top.shape[1], 3), bg, dtype=np.uint8)
|
||||
return np.concatenate([top, hline, bot], axis=0)
|
||||
|
||||
|
||||
def label_image(img, text, font_scale=0.45):
|
||||
bar = np.full((16, img.shape[1], 3), (40, 40, 40), dtype=np.uint8)
|
||||
cv2.putText(
|
||||
bar, text, (2, 13), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
font_scale, (200, 200, 200), 1, cv2.LINE_AA,
|
||||
)
|
||||
return np.concatenate([bar, img], axis=0)
|
||||
|
||||
|
||||
def make_random_struct_cond():
|
||||
from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB
|
||||
return torch.tensor([[
|
||||
random.randint(0, SYM_VOCAB - 2),
|
||||
random.randint(0, ROOM_VOCAB - 2),
|
||||
random.randint(0, BRANCH_VOCAB - 2),
|
||||
random.randint(0, OUTER_VOCAB - 2),
|
||||
]], dtype=torch.long, device=device)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 按阶段构造推理初始地图
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
根据阶段构造 MaskGIT 的推理初始地图。
|
||||
|
||||
Stage 1: 全 MASK(或保留稀疏 wall 种子)
|
||||
Stage 2: 保留 floor/wall 上下文,其余 → MASK
|
||||
Stage 3: 保留完整上下文(floor/wall/door/monster/entrance),resource → MASK
|
||||
"""
|
||||
init = context_map.clone()
|
||||
|
||||
if stage == 1:
|
||||
# 全 MASK(不依赖上下文地图)
|
||||
init = torch.full_like(init, MASK_TOKEN)
|
||||
|
||||
elif stage == 2:
|
||||
# 保留 floor/wall,其余 → MASK
|
||||
mask = ~torch.isin(init, torch.tensor([0, 1], device=init.device))
|
||||
init[mask] = MASK_TOKEN
|
||||
|
||||
else: # stage == 3
|
||||
# 保留非 resource,resource → MASK
|
||||
init[init == 3] = MASK_TOKEN
|
||||
|
||||
return init
|
||||
|
||||
|
||||
def make_random_wall_seed(ratio_min=0.02, ratio_max=0.08):
|
||||
ratio = random.uniform(ratio_min, ratio_max)
|
||||
n_wall = max(2, int(MAP_SIZE * ratio))
|
||||
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
|
||||
idx = torch.randperm(MAP_SIZE)[:n_wall]
|
||||
seed[0, idx] = 1
|
||||
return seed
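
# ---------------------------------------------------------------------------
# 级联推理示意(训练不依赖此函数):依次串联三个阶段,采用设计文档中的
# "全 floor MASK"策略。仅为示意草稿,假设三个阶段的 (enc, model_mg) 已分别
# 加载,且 enc.sample / maskgit_generate 的接口与本文件一致。
# ---------------------------------------------------------------------------

@torch.no_grad()
def cascade_generate(stage_models, struct_cond=None):
    """
    stage_models: [(enc1, mg1), (enc2, mg2), (enc3, mg3)],按阶段顺序排列。
    返回 [1, MAP_SIZE] 的完整地图。
    """
    # 阶段一:全 MASK 起步,生成 floor/wall 骨架
    enc1, mg1 = stage_models[0]
    init = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    current = maskgit_generate(mg1, enc1.sample(1, device),
                               init_map=init, struct_cond=struct_cond)

    # 阶段二/三:保留已生成内容,仅把当前的 floor 位置重新置为 MASK 供本阶段填充
    for enc_s, mg_s in stage_models[1:]:
        init = current.clone()
        init[init == 0] = MASK_TOKEN
        current = maskgit_generate(mg_s, enc_s.sample(1, device),
                                   init_map=init, struct_cond=struct_cond)
    return current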
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 验证函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(
|
||||
stage: int,
|
||||
enc: GinkaVQVAE,
|
||||
model_mg: GinkaMaskGIT,
|
||||
dataloader_val: DataLoader,
|
||||
tile_dict: dict,
|
||||
epoch: int,
|
||||
n_rand: int = 3,
|
||||
):
|
||||
enc.eval()
|
||||
model_mg.eval()
|
||||
|
||||
epoch_dir = f"result/stage{stage}_img/e{epoch:04d}"
|
||||
os.makedirs(epoch_dir, exist_ok=True)
|
||||
|
||||
val_loss_total = 0.0
|
||||
val_steps = 0
|
||||
captured = {s: None for s in ('A', 'B', 'C', 'D')}
|
||||
|
||||
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
|
||||
raw_map = batch["raw_map"].to(device)
|
||||
vq_slice = batch["vq_slice"].to(device)
|
||||
stage_input = batch["stage_input"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
loss_mask = batch["loss_mask"].to(device)
|
||||
struct_cond = batch["struct_cond"].to(device)
|
||||
subsets = batch["subset"]
|
||||
|
||||
z_q, _, _, vq_loss, _, _ = enc(vq_slice)
|
||||
logits = model_mg(stage_input, z_q, struct_cond=struct_cond)
|
||||
|
||||
ce = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
|
||||
val_loss_total += (ce + vq_loss).item()
|
||||
val_steps += 1
|
||||
|
||||
for i in range(raw_map.shape[0]):
|
||||
s = subsets[i]
|
||||
if captured[s] is None:
|
||||
captured[s] = {
|
||||
"raw": raw_map[i:i+1].clone(),
|
||||
"stage_input": stage_input[i:i+1].clone(),
|
||||
"z_q": z_q[i:i+1].clone(),
|
||||
"struct_cond": struct_cond[i:i+1].clone(),
|
||||
}
|
||||
|
||||
if all(v is not None for v in captured.values()):
|
||||
break
|
||||
|
||||
# ---- 可视化:每个子集一张图 ----------------------------------------
|
||||
for sub, cap in captured.items():
|
||||
if cap is None:
|
||||
continue
|
||||
|
||||
raw_img = label_image(make_map_image(cap["raw"][0], tile_dict), "GT")
|
||||
inp_img = label_image(make_map_image(cap["stage_input"][0], tile_dict), f"stage{stage} input")
|
||||
|
||||
# 真实 z 的迭代生成
|
||||
init = make_stage_init(stage, cap["stage_input"][0].unsqueeze(0))
|
||||
gen = maskgit_generate(
|
||||
model_mg, cap["z_q"],
|
||||
init_map=init, struct_cond=cap["struct_cond"],
|
||||
)
|
||||
gen_img = label_image(make_map_image(gen[0], tile_dict), "z_real gen")
|
||||
|
||||
# 随机 z 的生成
|
||||
rand_imgs = []
|
||||
for i in range(n_rand):
|
||||
z_r = enc.sample(1, device)
|
||||
sc_r = make_random_struct_cond()
|
||||
init2 = make_stage_init(stage, cap["raw"][0].unsqueeze(0))
|
||||
gen_r = maskgit_generate(model_mg, z_r, init_map=init2, struct_cond=sc_r)
|
||||
rand_imgs.append(label_image(make_map_image(gen_r[0], tile_dict), f"z_rand_{i+1}"))
|
||||
|
||||
row = [raw_img, inp_img, gen_img] + rand_imgs
|
||||
cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row))
|
||||
|
||||
# ---- 场景:完全自主生成 -----------------------------------------------
|
||||
# stage1:从随机稀疏墙壁种子出发(完全不依赖 GT)
|
||||
# stage2:以验证集中采样的 floor/wall 结构为上下文,随机 z₂(模拟级联推理)
|
||||
# stage3:以验证集中采样的完整功能地图为上下文,随机 z₃(模拟级联推理)
|
||||
context_pool = [cap["raw"][0] for cap in captured.values() if cap is not None]
|
||||
|
||||
rand_free = []
|
||||
for i in range(n_rand + 1):
|
||||
z_r = enc.sample(1, device)
|
||||
sc_r = make_random_struct_cond()
|
||||
|
||||
if stage == 1:
|
||||
# 稀疏 wall 种子作为提示,模型自主补全 floor/wall
|
||||
init = make_random_wall_seed()
|
||||
else:
|
||||
# 从验证集上下文池中轮流取一张图作为前序阶段的输出
|
||||
ctx = context_pool[i % len(context_pool)].unsqueeze(0)
|
||||
# make_stage_init 会自动将本阶段负责的 tile 位置替换为 MASK
|
||||
init = make_stage_init(stage, ctx)
|
||||
|
||||
gen = maskgit_generate(model_mg, z_r, init_map=init, struct_cond=sc_r)
|
||||
rand_free.append(label_image(make_map_image(gen[0], tile_dict), f"free_{i+1}"))
|
||||
cv2.imwrite(f"{epoch_dir}/scene_free_random.png", grid_images(rand_free))
|
||||
|
||||
return val_loss_total / max(val_steps, 1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def train():
|
||||
print(f"Using device: {device}")
|
||||
args = parse_arguments()
|
||||
stage = args.stage
|
||||
|
||||
result_dir = f"result/stage{stage}"
|
||||
result_img_dir = f"result/stage{stage}_img"
|
||||
os.makedirs(result_dir, exist_ok=True)
|
||||
os.makedirs(result_img_dir, exist_ok=True)
|
||||
|
||||
# ---- VQ 编码器(单路)----
|
||||
mg_cfg = STAGE_MG_CONFIGS[stage]
|
||||
enc = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES,
|
||||
L=VQ_L,
|
||||
K=VQ_K,
|
||||
d_z=VQ_D_Z,
|
||||
d_model=VQ_D_MODEL,
|
||||
nhead=VQ_NHEAD,
|
||||
num_layers=VQ_LAYERS,
|
||||
dim_ff=VQ_DIM_FF,
|
||||
map_size=MAP_SIZE,
|
||||
beta=VQ_BETA,
|
||||
gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
|
||||
model_mg = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES,
|
||||
d_model=mg_cfg["d_model"],
|
||||
d_z=VQ_D_Z,
|
||||
dim_ff=mg_cfg["dim_ff"],
|
||||
nhead=mg_cfg["nhead"],
|
||||
num_layers=mg_cfg["num_layers"],
|
||||
map_size=MAP_SIZE,
|
||||
z_dropout=MG_Z_DROPOUT,
|
||||
struct_dropout=MG_STRUCT_DROPOUT,
|
||||
).to(device)
|
||||
|
||||
enc_params = sum(p.numel() for p in enc.parameters())
|
||||
mg_params = sum(p.numel() for p in model_mg.parameters())
|
||||
print(f"[Stage {stage}] VQ Encoder 参数量: {enc_params:,} ({enc_params/1e6:.3f}M)")
|
||||
print(f"[Stage {stage}] MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)")
|
||||
|
||||
# ---- 数据集 ----
|
||||
dataset_train = GinkaStageDataset(
|
||||
args.train,
|
||||
stage=stage,
|
||||
subset_weights=SUBSET_WEIGHTS,
|
||||
wall_mask_ratio=WALL_MASK_RATIO,
|
||||
)
|
||||
dataset_val = GinkaStageDataset(
|
||||
args.validate,
|
||||
stage=stage,
|
||||
subset_weights=SUBSET_WEIGHTS,
|
||||
room_thresholds=dataset_train.room_th,
|
||||
branch_thresholds=dataset_train.branch_th,
|
||||
wall_mask_ratio=WALL_MASK_RATIO,
|
||||
)
|
||||
dataloader_train = DataLoader(
|
||||
dataset_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
pin_memory=(device.type == "cuda"),
|
||||
)
|
||||
dataloader_val = DataLoader(
|
||||
dataset_val,
|
||||
batch_size=8,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
# ---- 优化器 ----
|
||||
all_params = list(enc.parameters()) + list(model_mg.parameters())
|
||||
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=args.epochs, eta_min=1e-6,
|
||||
)
|
||||
|
||||
# ---- 权重加载 ----
|
||||
start_epoch = 0
|
||||
|
||||
if args.pretrain_vq:
|
||||
# 从 train_vq.py 的联合训练检查点加载对应通道的 VQ 编码器
|
||||
ckpt = torch.load(args.pretrain_vq, map_location=device)
|
||||
enc_key = f"enc{stage}"
|
||||
if enc_key in ckpt:
|
||||
enc.load_state_dict(ckpt[enc_key], strict=False)
|
||||
print(f"已从 {args.pretrain_vq} 加载 {enc_key} 权重。")
|
||||
else:
|
||||
print(f"警告:检查点中未找到 {enc_key},跳过权重加载。")
|
||||
|
||||
if args.resume:
|
||||
state_path = args.state or f"{result_dir}/stage{stage}-latest.pth"
|
||||
ckpt = torch.load(state_path, map_location=device)
|
||||
enc.load_state_dict(ckpt["enc"], strict=False)
|
||||
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
|
||||
if args.load_optim and ckpt.get("optim_state") is not None:
|
||||
optimizer.load_state_dict(ckpt["optim_state"])
|
||||
start_epoch = ckpt.get("epoch", 0)
|
||||
print(f"从 epoch {start_epoch} 接续训练。")
|
||||
|
||||
# ---- tile 贴图 ----
|
||||
tile_dict = {}
|
||||
for f in os.listdir("tiles"):
|
||||
name = os.path.splitext(f)[0]
|
||||
img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED)
|
||||
if img is not None:
|
||||
tile_dict[name] = img
|
||||
|
||||
# ---- 冻结 VQ 编码器(可选)----
|
||||
if args.freeze_vq:
|
||||
for p in enc.parameters():
|
||||
p.requires_grad_(False)
|
||||
print(f"[Stage {stage}] VQ 编码器已冻结。")
|
||||
|
||||
# ---- 训练循环 ----
|
||||
for epoch in tqdm(
|
||||
range(start_epoch, start_epoch + args.epochs),
|
||||
desc=f"Stage{stage} Training",
|
||||
disable=disable_tqdm,
|
||||
):
|
||||
enc.train()
|
||||
model_mg.train()
|
||||
|
||||
loss_total = 0.0
|
||||
ce_total = 0.0
|
||||
vq_loss_total = 0.0
|
||||
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
|
||||
|
||||
# 按 tile 统计召回率(用于监控各类 tile 的预测准确性)
|
||||
tile_correct = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
|
||||
tile_total = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
|
||||
|
||||
for batch in tqdm(
|
||||
dataloader_train,
|
||||
leave=False,
|
||||
desc="Epoch Progress",
|
||||
disable=disable_tqdm,
|
||||
):
|
||||
raw_map = batch["raw_map"].to(device)
|
||||
vq_slice = batch["vq_slice"].to(device)
|
||||
stage_input = batch["stage_input"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
loss_mask = batch["loss_mask"].to(device)
|
||||
struct_cond = batch["struct_cond"].to(device)
|
||||
|
||||
for s in batch["subset"]:
|
||||
subset_stats[s] = subset_stats.get(s, 0) + 1
|
||||
|
||||
# ---- 前向传播 ----
|
||||
z_q, _, _, vq_loss, commit_loss, entropy_loss = enc(vq_slice)
|
||||
logits = model_mg(stage_input, z_q, struct_cond=struct_cond) # [B, S, C]
|
||||
|
||||
# ---- 仅对本阶段 tile 位置计算 focal loss ----
|
||||
ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
|
||||
loss_cfg = STAGE_LOSS_CONFIG[stage]
|
||||
loss = loss_cfg["ce_weight"] * ce_loss + loss_cfg["vq_weight"] * vq_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
|
||||
optimizer.step()
|
||||
|
||||
loss_total += loss.detach().item()
|
||||
ce_total += ce_loss.detach().item()
|
||||
vq_loss_total += vq_loss.detach().item()
|
||||
|
||||
# ---- 分 tile 召回率统计 ----
|
||||
with torch.no_grad():
|
||||
preds = logits.argmax(dim=-1) # [B, S]
|
||||
for tid in STAGE_TILE_SETS[stage]:
|
||||
gt_mask = (target_map == tid) & loss_mask
|
||||
tile_total[tid] += gt_mask.sum().item()
|
||||
tile_correct[tid] += (preds[gt_mask] == tid).sum().item()
|
||||
|
||||
scheduler.step()
|
||||
|
||||
n = len(dataloader_train)
|
||||
recall_str = " ".join(
|
||||
f"{STAGE_TILE_SETS[stage][tid]}={tile_correct[tid]/(tile_total[tid]+1e-6):.2%}"
|
||||
for tid in STAGE_TILE_SETS[stage]
|
||||
)
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"Epoch {epoch + 1:4d} | "
|
||||
f"Loss {loss_total/n:.5f} "
|
||||
f"Focal {ce_total/n:.5f} "
|
||||
f"VQ {vq_loss_total/n:.5f} | "
|
||||
f"Recall: {recall_str} | "
|
||||
f"LR {scheduler.get_last_lr()[0]:.6f} | "
|
||||
f"Subsets {subset_stats}"
|
||||
)
|
||||
|
||||
# ---- 检查点 + 验证 ----
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
ckpt_path = f"{result_dir}/stage{stage}-{epoch + 1}.pth"
|
||||
torch.save({
|
||||
"epoch": epoch + 1,
|
||||
"stage": stage,
|
||||
"enc": enc.state_dict(),
|
||||
"mg_state": model_mg.state_dict(),
|
||||
"optim_state": optimizer.state_dict(),
|
||||
}, ckpt_path)
|
||||
tqdm.write(f" 检查点已保存: {ckpt_path}")
|
||||
|
||||
val_loss = validate(stage, enc, model_mg, dataloader_val, tile_dict, epoch + 1)
|
||||
tqdm.write(f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}")
|
||||
|
||||
enc.train()
|
||||
model_mg.train()
|
||||
|
||||
# ---- 最终存档 ----
|
||||
torch.save({
|
||||
"epoch": start_epoch + args.epochs,
|
||||
"stage": stage,
|
||||
"enc": enc.state_dict(),
|
||||
"mg_state": model_mg.state_dict(),
|
||||
}, f"{result_dir}/stage{stage}_final.pth")
|
||||
print(f"[Stage {stage}] 训练结束。")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(4)
|
||||
train()
|
||||