mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 16:41:10 +08:00
feat: 输入剩余密度
This commit is contained in:
parent
98e1734248
commit
306d585a28
@ -1,311 +1,93 @@
|
||||
# 条件简化与密度连续化设计文档
|
||||
# 条件简化与剩余密度条件设计文档
|
||||
|
||||
## 背景
|
||||
## 背景问题
|
||||
|
||||
当前三阶段级联生成模型的条件系统存在以下问题:
|
||||
当前三阶段级联生成模型的条件系统主要有五个问题:
|
||||
|
||||
1. **结构条件中的房间数和分支数对生成指导意义有限**:这两个指标依赖数据集中预计算的离散分档,与实际生成质量的相关性较弱,且分档边界处噪声大,容易引入无效条件信号。
|
||||
1. 结构条件中的房间数和分支数指导意义有限。这两个标签依赖离散分档,噪声大,和最终生成质量的关系不稳定。
|
||||
|
||||
2. **实体密度条件(门/怪物/资源)的离散三档存在明显一对多问题**:三档划分过于粗糙,同一档内样本分布差异极大(例如 Medium 档中资源数可以从 2 到 8 不等),导致模型无法建立条件与生成结果之间的精确映射。连续值能够更精确地描述目标密度,避免档位内分布散乱导致的条件信号模糊。
|
||||
2. 直接给模型输入整张图的最终密度仍然过于间接。模型拿到的是终态目标,但生成过程是逐步展开的,模型需要自己从当前地图里数出已经放了多少,再反推出还差多少,这对 Transformer 并不友好。
|
||||
|
||||
## 改动总览
|
||||
3. 第一阶段缺少足够的全局统计条件。实验上第二、三阶段可以较好过拟合,但第一阶段明显更难,说明骨架生成还缺少能够直接约束复杂度的全局量。墙壁密度正好可以承担这个角色。
|
||||
|
||||
| 模块 | 改动类型 | 说明 |
|
||||
| -------------------------- | -------- | -------------------------------------------------------------- |
|
||||
| `ginka/dataset.py` | 修改 | 删除房间/分支分档;密度改为连续归一化;输出 FloatTensor |
|
||||
| `ginka/maskGIT/model.py` | 修改 | 删除房间/分支嵌入;密度嵌入层改为线性投影;更新 cond_proj 维度 |
|
||||
| `ginka/train_seperated.py` | 修改 | 更新 random_struct/random_density;更新 annotate_labels |
|
||||
4. 第二阶段除了门和怪物,入口也存在生成失控的问题。既然入口本来就属于第二阶段负责的功能性实体,那么它也应该进入同一套密度条件,而不是继续裸奔。
|
||||
|
||||
---
|
||||
5. 不适合直接把真实计数作为条件,也不适合用相对目标量做归一化。前者会把条件值拉得过大,后者又会让非零目标在初始时刻统一退化成 1,容易诱发整图铺满某类图块的极端行为。
|
||||
|
||||
## 一、条件向量格式变更
|
||||
因此,条件系统应统一改成“真实剩余密度”方案,并覆盖三个阶段需要的全部关键对象。
|
||||
|
||||
### 1.1 struct_inject
|
||||
## 核心方案
|
||||
|
||||
**当前格式**(4 个离散整数):
|
||||
保留简化后的结构条件,只使用对称性和外围墙信息;删除房间数、分支数这类噪声较大的离散标签。
|
||||
|
||||
```
|
||||
[cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)]
|
||||
```
|
||||
|
||||
**新格式**(2 个离散整数,删除 room 和 branch):
|
||||
|
||||
```
|
||||
[cond_sym(0-7), cond_outer(0-1)]
|
||||
```
|
||||
|
||||
`cond_sym` 的计算方式不变(水平/垂直/中心对称的三位二进制组合,0–7),`cond_outer` 不变。
|
||||
|
||||
### 1.2 density_inject
|
||||
|
||||
**当前格式**(3 个离散整数,`LongTensor`):
|
||||
|
||||
```
|
||||
[door_level(0-2), monster_level(0-2), resource_level(0-2)]
|
||||
```
|
||||
|
||||
**新格式**(3 个连续浮点数,`FloatTensor`,值域 [0, 1]):
|
||||
|
||||
```
|
||||
[door_norm, monster_norm, resource_norm] ∈ [0.0, 1.0]^3
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 二、密度归一化方案
|
||||
|
||||
### 2.1 统计量定义
|
||||
|
||||
在训练集初始化阶段,对原始地图统计三类图块的实际数量(非密度,直接计数):
|
||||
|
||||
- `door_count = 图块ID为2的数量`
|
||||
- `monster_count = 图块ID为4的数量`
|
||||
- `resource_count = 图块ID为3的数量`
|
||||
|
||||
对每类分别求训练集内的 **最小值** 和 **最大值**:
|
||||
|
||||
```
|
||||
door_min, door_max
|
||||
monster_min, monster_max
|
||||
resource_min, resource_max
|
||||
```
|
||||
|
||||
### 2.2 归一化公式
|
||||
|
||||
对每个样本的 count,归一化为 [0, 1]:
|
||||
密度条件不再表示“最终应该有多少”,而是表示“当前还剩多少没放完”。所有密度都统一按固定地图面积计算:
|
||||
|
||||
$$
|
||||
\text{norm}(x) = \frac{x - x_{\min}}{x_{\max} - x_{\min} + \epsilon}
|
||||
d^{\text{remain}} = d^{\text{target}} - d^{\text{visible}}
|
||||
$$
|
||||
|
||||
其中 $\epsilon = 1\text{e-}6$,防止分母为零(当所有样本计数相同时)。
|
||||
其中:
|
||||
|
||||
结果裁剪到 [0, 1]:`norm = clamp(norm, 0.0, 1.0)`。
|
||||
- $d^{\text{target}}$ 表示目标地图中的真实密度
|
||||
- $d^{\text{visible}}$ 表示当前输入地图中已经可见的真实密度
|
||||
- 分母固定为地图总面积 $13 \times 13 = 169$
|
||||
|
||||
### 2.3 验证集复用训练集统计量
|
||||
这样做有两个直接好处:
|
||||
|
||||
`GinkaSeperatedDataset` 新增参数:
|
||||
- 条件值始终处于小范围浮点数区间,不会因为资源、墙壁或其他图块数量较大而把激活值拉爆。
|
||||
- 初始条件不再因为“目标非零”而统一退化成 1,模型能真正区分稀疏目标和稠密目标。
|
||||
|
||||
```python
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
subset_weights: tuple = (0.5, 0.3, 0.2),
|
||||
density_stats: dict | None = None # 新增:外部传入统计量
|
||||
):
|
||||
```
|
||||
## 条件内容
|
||||
|
||||
- 训练集:`density_stats=None`,自行计算并保存 `min/max` 到 `self.density_stats`
|
||||
- 验证集:传入训练集的 `self.density_stats`,直接复用,保证归一化语义一致
|
||||
结构条件继续采用两维:对称性和外围墙。
|
||||
|
||||
`density_stats` 的结构:
|
||||
密度条件扩展为五维,顺序为:墙壁、门、怪物、入口、资源。
|
||||
|
||||
```python
|
||||
{
|
||||
"door_min": float, "door_max": float,
|
||||
"monster_min": float, "monster_max": float,
|
||||
"resource_min": float, "resource_max": float,
|
||||
}
|
||||
```
|
||||
这样设计的原因是:
|
||||
|
||||
### 2.4 输出字段变更
|
||||
- 墙壁密度直接约束第一阶段骨架复杂度。
|
||||
- 门、怪物、入口都属于第二阶段的功能性实体,应该共享同一套“还剩多少要放”的控制思路。
|
||||
- 资源属于第三阶段,继续保留独立控制。
|
||||
|
||||
`__getitem__` 中 `density_inject` 由 `LongTensor` 改为 `FloatTensor`:
|
||||
不单独加入 floor 密度,因为在地图总面积固定时,floor 和 wall 基本互补,额外增加一维收益有限。
|
||||
|
||||
```python
|
||||
# 删除旧的离散分档逻辑
|
||||
density_inject = torch.FloatTensor([
|
||||
self.norm_density(count_door, "door"),
|
||||
self.norm_density(count_monster, "monster"),
|
||||
self.norm_density(count_resource, "resource"),
|
||||
])
|
||||
```
|
||||
## 分阶段使用方式
|
||||
|
||||
删除以下字段(不再写入 item 也不再输出):
|
||||
第一阶段只使用墙壁剩余密度,用来约束骨架稠密程度。
|
||||
|
||||
- `doorDensityLevel`, `monsterDensityLevel`, `resourceDensityLevel`
|
||||
- `roomCountLevel`, `branchLevel`
|
||||
第二阶段使用门、怪物、入口的剩余密度,用来约束功能性实体的总量,避免某一类实体明显过多。
|
||||
|
||||
删除以下实例变量:
|
||||
第三阶段只使用资源剩余密度,用来约束资源分布总量。
|
||||
|
||||
- `self.room_th`, `self.branch_th`
|
||||
- `self.door_density_th`, `self.monster_density_th`, `self.resource_density_th`
|
||||
虽然模型内部可以继续使用统一的条件接口,但语义上每个阶段只关注自己负责的那几维,其他维度在该阶段不提供有效信号。
|
||||
|
||||
---
|
||||
## 训练与采样原则
|
||||
|
||||
## 三、模型结构变更(`model.py`)
|
||||
密度条件不能在样本读取时一次性写死,而必须根据当前输入地图动态计算。也就是说,同一张目标地图在不同阶段、不同采样步、不同当前状态下,对应的剩余密度都可能不同。
|
||||
|
||||
### 3.1 删除房间/分支嵌入
|
||||
这件事在训练和推理里都一样重要:
|
||||
|
||||
删除:
|
||||
- 训练时,模型应看到“当前已经放了多少”之后对应的剩余密度。
|
||||
- 推理时,每一步去掩码之后都要重新计算剩余密度,否则模型不知道自己刚刚已经新增了多少墙壁、门、怪物、入口或资源。
|
||||
|
||||
```python
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||
```
|
||||
自由生成时,如果需要随机条件,建议从训练集真实分布中采样密度组合,而不是各维完全独立均匀采样。特别是墙壁密度和第二、三阶段实体密度往往存在耦合关系,独立采样容易产生不自然组合。
|
||||
|
||||
保留:
|
||||
## 验证重点
|
||||
|
||||
```python
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) # SYM_VOCAB = 8
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) # OUTER_VOCAB = 2
|
||||
```
|
||||
这套方案的目标不是单纯让损失更低,而是让条件真正可控。验证时应重点看四类现象:
|
||||
|
||||
删除的常量:`ROOM_VOCAB`, `BRANCH_VOCAB`,保留 `SYM_VOCAB`, `OUTER_VOCAB`。
|
||||
1. 第一阶段墙壁密度误差是否明显下降,墙壁目标密度从低到高时,stage1 输出的墙壁数量是否随之单调上升。
|
||||
|
||||
### 3.2 密度嵌入层改为线性投影
|
||||
2. 第二阶段门、怪物、入口的过量生成是否下降,尤其要观察入口是否还会出现明显离谱的数量。
|
||||
|
||||
删除:
|
||||
3. 第三阶段资源的过量生成是否下降。
|
||||
|
||||
```python
|
||||
self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z)
|
||||
self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z)
|
||||
self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z)
|
||||
```
|
||||
4. 固定结构和潜变量时,只改变某一维目标密度,最终对应图块数量是否发生方向一致的变化。
|
||||
|
||||
删除的常量:`DOOR_DENSITY_VOCAB`, `MONSTER_DENSITY_VOCAB`, `RESOURCE_DENSITY_VOCAB`。
|
||||
验证指标建议以真实密度误差为主,计数误差可以作为辅助观察项。
|
||||
|
||||
新增:
|
||||
## 兼容性结论
|
||||
|
||||
```python
|
||||
# 连续密度投影:将 3 个归一化浮点数映射为 1 个 d_z 维 token
|
||||
self.density_proj = nn.Linear(3, d_z)
|
||||
```
|
||||
这次改动不仅改变条件语义,也会改变密度条件的维度,因此旧 checkpoint 不再兼容。
|
||||
|
||||
### 3.3 cond_proj 维度更新
|
||||
|
||||
**当前 cond_seq 形状**:`[B, z_seq_len + 4_struct + 3_density, d_z]`,即 `[B, z_seq_len+7, d_z]`,展平后输入维度 `(z_seq_len+7) * d_z`。
|
||||
|
||||
**新 cond_seq 形状**:`[B, z_seq_len + 2_struct + 1_density, d_z]`,即 `[B, z_seq_len+3, d_z]`,展平后输入维度 `(z_seq_len+3) * d_z`。
|
||||
|
||||
```python
|
||||
# 旧
|
||||
self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model)
|
||||
# 新
|
||||
self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model)
|
||||
```
|
||||
|
||||
### 3.4 forward 流程变更
|
||||
|
||||
```python
|
||||
def forward(
|
||||
self,
|
||||
map: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct: torch.Tensor, # [B, 2] ← 由 [B, 4] 改为 [B, 2]
|
||||
density: torch.Tensor # [B, 3] float ← 由 [B, 3] long 改为 float
|
||||
) -> torch.Tensor:
|
||||
|
||||
# 结构标签:sym + outer,各嵌入为 d_z 维 token
|
||||
e_struct = torch.stack([
|
||||
self.sym_embed(struct[:, 0]), # [B, d_z]
|
||||
self.outer_embed(struct[:, 1]), # [B, d_z]
|
||||
], dim=1) # [B, 2, d_z]
|
||||
|
||||
# 密度:连续值投影为单个 d_z 维 token
|
||||
e_density = self.density_proj(density).unsqueeze(1) # [B, 1, d_z]
|
||||
|
||||
# z:逐 token 投影(不变)
|
||||
z_proj = self.z_proj(z) # [B, z_seq_len, d_z]
|
||||
|
||||
# 拼接 → [B, z_seq_len+3, d_z] → 展平 → 投影到 d_model
|
||||
cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1)
|
||||
c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1)) # [B, d_model]
|
||||
|
||||
# 后续不变(tile embedding + Transformer + output_fc)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 四、训练脚本变更(`train_seperated.py`)
|
||||
|
||||
### 4.1 random_struct
|
||||
|
||||
```python
|
||||
def random_struct(device: torch.device) -> torch.Tensor:
|
||||
# struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)]
|
||||
cond_sym = random.randint(0, 7) # 地图对称类型
|
||||
cond_outer = random.randint(0, 1) # 是否有外围走廊
|
||||
return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device)
|
||||
```
|
||||
|
||||
### 4.2 random_density
|
||||
|
||||
```python
|
||||
def random_density(device: torch.device) -> torch.Tensor:
|
||||
# density_inject 格式:[door_norm, monster_norm, resource_norm] ∈ [0, 1]
|
||||
return torch.rand(1, 3, device=device) # 均匀分布随机采样
|
||||
```
|
||||
|
||||
### 4.3 annotate_labels
|
||||
|
||||
更新标注格式,删除 room/branch,密度显示为两位小数:
|
||||
|
||||
```python
|
||||
def annotate_labels(
|
||||
img: np.ndarray,
|
||||
struct: torch.Tensor, # [2] long
|
||||
density: torch.Tensor # [3] float
|
||||
) -> np.ndarray:
|
||||
s = struct.tolist()
|
||||
d = density.tolist()
|
||||
line1 = f"sym:{s[0]} outer:{s[1]}"
|
||||
line2 = f"door:{d[0]:.2f} enemy:{d[1]:.2f} res:{d[2]:.2f}"
|
||||
img = img.copy()
|
||||
for text, y in [(line1, 12), (line2, 24)]:
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2)
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
|
||||
return img
|
||||
```
|
||||
|
||||
### 4.4 训练集与验证集初始化
|
||||
|
||||
```python
|
||||
train_dataset = GinkaSeperatedDataset(args.train, subset_weights=SUBSET_WEIGHTS)
|
||||
val_dataset = GinkaSeperatedDataset(
|
||||
args.validate,
|
||||
subset_weights=SUBSET_WEIGHTS,
|
||||
density_stats=train_dataset.density_stats # 复用训练集统计量
|
||||
)
|
||||
```
|
||||
|
||||
### 4.5 DataLoader collate_fn(FloatTensor 适配)
|
||||
|
||||
PyTorch 默认 collate 会自动将 FloatTensor 列表合并为 float 类型批张量,无需额外修改 DataLoader 配置。
|
||||
|
||||
### 4.6 验证阶段密度对照图(density_var)
|
||||
|
||||
`visualize_density_var` 内对比不同密度条件时,改为使用 5 个均匀分布采样点:
|
||||
|
||||
```python
|
||||
# 旧(三档枚举):density_levels = [0, 1, 2]
|
||||
# 新(连续采样):5 个均匀间隔值
|
||||
density_values = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for v in density_values:
|
||||
d = torch.FloatTensor([[v, v, v]]).to(device) # 三类等密度扫描
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 五、不需要改动的部分
|
||||
|
||||
- `ginka/maskGIT/maskGIT.py`:AdaLN / CondTransformerLayer / Transformer 均不感知条件维度,无需修改
|
||||
- `ginka/vqvae/` 目录:VQ-VAE 部分与条件系统无关
|
||||
- `ginka/train_seperated.py` 中的 `maskgit_sample`、`full_generate_random_z`、`full_generate_specific_z`:接口签名不变(仍接受 struct/density 张量),内部无直接操作条件内容,无需修改
|
||||
- `data/` 目录的 TypeScript 数据处理脚本:数据文件格式不变,Python 端自行计算标签
|
||||
|
||||
---
|
||||
|
||||
## 六、旧 checkpoint 兼容性
|
||||
|
||||
由于 `cond_proj` 输入维度和嵌入层数量均发生变化,**旧 checkpoint 不兼容**,需从头训练。
|
||||
|
||||
---
|
||||
|
||||
## 七、实施顺序
|
||||
|
||||
1. 修改 `ginka/dataset.py`:删除 room/branch 分档,新增密度归一化和 `density_stats` 参数
|
||||
2. 修改 `ginka/maskGIT/model.py`:删除多余嵌入,新增 `density_proj`,更新 `cond_proj` 维度和 `forward`
|
||||
3. 修改 `ginka/train_seperated.py`:更新 `random_struct`、`random_density`、`annotate_labels`、数据集初始化
|
||||
4. 运行小规模过拟合测试(单 batch 跑 50 步)验证前向通路无误
|
||||
旧模型学习的是整图最终密度,或相对目标量的剩余比例;新模型学习的是按总面积计的动态剩余密度,并且新增了第一阶段墙壁条件和第二阶段入口条件。两者不是同一套任务定义,需要从头训练。
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
# 实体密度标签设计文档
|
||||
|
||||
> 本文档描述的是早期“离散密度标签”方案,现已不再作为当前实现方向。
|
||||
>
|
||||
> 当前有效设计请以 `docs/cond-simplify-design.md` 为准。最新方案不再把整图最终密度或按目标量归一化的剩余比例直接作为模型输入,而是改为基于当前输入地图动态计算“按总面积计的真实剩余密度”。
|
||||
|
||||
## 背景与问题
|
||||
|
||||
当前三阶段级联生成(stage1 骨架、stage2 功能实体、stage3 资源)在结构可行性上基本稳定,但存在明显的分布偏移:
|
||||
|
||||
@ -29,6 +29,7 @@ class GinkaSeperatedDataset(Dataset):
|
||||
MONSTER = 4
|
||||
ENTRANCE = 5
|
||||
MASK_ID = 6
|
||||
MAP_SIZE = 13 * 13
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -40,32 +41,33 @@ class GinkaSeperatedDataset(Dataset):
|
||||
total = sum(subset_weights)
|
||||
self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))]
|
||||
|
||||
# 实体密度连续归一化:统计门/怪物/资源的数量,用 min/max 归一化到 [0, 1]
|
||||
# density_stats 为 None 时自行计算(训练集),否则复用外部传入的统计量(验证集)
|
||||
if density_stats is None:
|
||||
door_counts = [self.count_tile(item['map'], self.DOOR) for item in self.data]
|
||||
monster_counts = [self.count_tile(item['map'], self.MONSTER) for item in self.data]
|
||||
resource_counts = [self.count_tile(item['map'], self.RESOURCE) for item in self.data]
|
||||
self.density_stats = {
|
||||
"door_min": float(min(door_counts)),
|
||||
"door_max": float(max(door_counts)),
|
||||
"monster_min": float(min(monster_counts)),
|
||||
"monster_max": float(max(monster_counts)),
|
||||
"resource_min": float(min(resource_counts)),
|
||||
"resource_max": float(max(resource_counts)),
|
||||
}
|
||||
self.density_stats = self.compute_density_stats()
|
||||
else:
|
||||
self.density_stats = density_stats
|
||||
|
||||
def norm_density(self, count: int, key: str) -> float:
|
||||
eps = 1e-6
|
||||
lo = self.density_stats[f"{key}_min"]
|
||||
hi = self.density_stats[f"{key}_max"]
|
||||
return float(min(max((count - lo) / (hi - lo + eps), 0.0), 1.0))
|
||||
|
||||
def count_tile(self, map_data: list, tile_id: int) -> int:
|
||||
return sum(cell == tile_id for row in map_data for cell in row)
|
||||
|
||||
def compute_density_stats(self) -> dict:
|
||||
wall_densities = [self.count_tile(item['map'], self.WALL) / self.MAP_SIZE for item in self.data]
|
||||
door_densities = [self.count_tile(item['map'], self.DOOR) / self.MAP_SIZE for item in self.data]
|
||||
monster_densities = [self.count_tile(item['map'], self.MONSTER) / self.MAP_SIZE for item in self.data]
|
||||
entrance_densities = [self.count_tile(item['map'], self.ENTRANCE) / self.MAP_SIZE for item in self.data]
|
||||
resource_densities = [self.count_tile(item['map'], self.RESOURCE) / self.MAP_SIZE for item in self.data]
|
||||
return {
|
||||
"wall_min_density": float(min(wall_densities)),
|
||||
"wall_max_density": float(max(wall_densities)),
|
||||
"door_min_density": float(min(door_densities)),
|
||||
"door_max_density": float(max(door_densities)),
|
||||
"monster_min_density": float(min(monster_densities)),
|
||||
"monster_max_density": float(max(monster_densities)),
|
||||
"entrance_min_density": float(min(entrance_densities)),
|
||||
"entrance_max_density": float(max(entrance_densities)),
|
||||
"resource_min_density": float(min(resource_densities)),
|
||||
"resource_max_density": float(max(resource_densities)),
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
@ -94,9 +96,9 @@ class GinkaSeperatedDataset(Dataset):
|
||||
return mask
|
||||
|
||||
def create_degreaded(self, raw: np.ndarray):
|
||||
# 阶段一:生成墙壁和入口
|
||||
# 阶段一:仅生成墙壁骨架
|
||||
target1 = raw.copy()
|
||||
self.degrade_tile(target1, [self.DOOR, self.RESOURCE, self.MONSTER])
|
||||
self.degrade_tile(target1, [self.DOOR, self.RESOURCE, self.MONSTER, self.ENTRANCE])
|
||||
inp1 = target1.copy()
|
||||
|
||||
# 阶段二:生成怪物、门,同时也允许生成入口以适配结构
|
||||
@ -135,7 +137,7 @@ class GinkaSeperatedDataset(Dataset):
|
||||
return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
|
||||
|
||||
def apply_subset2(self, raw: np.ndarray):
|
||||
# 子集 2:掩码所有内容,墙壁随机掩码,不掩码入口
|
||||
# 子集 2:墙壁随机掩码,其它阶段内容由后续阶段补全
|
||||
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
|
||||
|
||||
enc1 = target1.copy()
|
||||
@ -183,10 +185,12 @@ class GinkaSeperatedDataset(Dataset):
|
||||
struct_inject = torch.LongTensor([cond_sym, cond_outer])
|
||||
|
||||
m = item['map']
|
||||
density_inject = torch.FloatTensor([
|
||||
self.norm_density(self.count_tile(m, self.DOOR), "door"),
|
||||
self.norm_density(self.count_tile(m, self.MONSTER), "monster"),
|
||||
self.norm_density(self.count_tile(m, self.RESOURCE), "resource"),
|
||||
target_density = torch.FloatTensor([
|
||||
self.count_tile(m, self.WALL) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.DOOR) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.MONSTER) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.ENTRANCE) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.RESOURCE) / self.MAP_SIZE,
|
||||
])
|
||||
|
||||
return {
|
||||
@ -200,5 +204,5 @@ class GinkaSeperatedDataset(Dataset):
|
||||
"target_stage3": torch.LongTensor(out[7]),
|
||||
"encoder_stage3": torch.LongTensor(out[8]),
|
||||
"struct_inject": struct_inject,
|
||||
"density_inject": density_inject
|
||||
"target_density": target_density
|
||||
}
|
||||
|
||||
@ -27,13 +27,13 @@ class GinkaMaskGIT(nn.Module):
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
# 密度连续投影:将 3 个归一化浮点数 [door, monster, resource] ∈ [0,1] 投影为 d_z 维 token
|
||||
self.density_proj = nn.Linear(3, d_z)
|
||||
# 剩余密度投影:将 5 个浮点数 [wall, door, monster, entrance, resource] 投影为 d_z 维 token
|
||||
self.remain_proj = nn.Linear(5, d_z)
|
||||
|
||||
# z 投影:逐 token 线性变换,保持序列结构
|
||||
self.z_proj = nn.Linear(d_z, d_z)
|
||||
|
||||
# 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个密度 token
|
||||
# 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个剩余密度 token
|
||||
self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model)
|
||||
|
||||
# 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层
|
||||
@ -48,12 +48,12 @@ class GinkaMaskGIT(nn.Module):
|
||||
map: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor
|
||||
remain: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# map: [B, H * W]
|
||||
# z: [B, z_seq_len, d_z]
|
||||
# struct: [B, 2] — [cond_sym(0-7), cond_outer(0-1)]
|
||||
# density: [B, 3] float — [door_norm, monster_norm, resource_norm] ∈ [0, 1]
|
||||
# remain: [B, 5] float — [wall, door, monster, entrance, resource] 剩余密度
|
||||
|
||||
# 结构标签:sym + outer,各自嵌入为独立 token,stack 成序列 [B, 2, d_z]
|
||||
e_struct = torch.stack([
|
||||
@ -61,14 +61,14 @@ class GinkaMaskGIT(nn.Module):
|
||||
self.outer_embed(struct[:, 1])
|
||||
], dim=1)
|
||||
|
||||
# 密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z]
|
||||
e_density = self.density_proj(density).unsqueeze(1)
|
||||
# 剩余密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z]
|
||||
e_remain = self.remain_proj(remain).unsqueeze(1)
|
||||
|
||||
# z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z]
|
||||
z_proj = self.z_proj(z)
|
||||
|
||||
# 拼接所有条件 token → [B, z_seq_len+3, d_z],展平后投影到 d_model
|
||||
cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1)
|
||||
cond_seq = torch.cat([z_proj, e_struct, e_remain], dim=1)
|
||||
c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1)) # [B, d_model]
|
||||
|
||||
# tile embedding + 位置编码
|
||||
@ -94,12 +94,12 @@ if __name__ == "__main__":
|
||||
[5, 1],
|
||||
[1, 0],
|
||||
], dtype=torch.long).to(device) # [4, 2] — [cond_sym, cond_outer]
|
||||
density_input = torch.tensor([
|
||||
[0.1, 0.5, 0.9],
|
||||
[0.8, 0.2, 0.4],
|
||||
[0.3, 0.7, 0.0],
|
||||
[0.6, 0.1, 1.0],
|
||||
], dtype=torch.float).to(device) # [4, 3] — [door_norm, monster_norm, resource_norm]
|
||||
remain_input = torch.tensor([
|
||||
[0.2, 0.1, 0.5, 0.1, 0.9],
|
||||
[0.4, 0.8, 0.2, 0.0, 0.4],
|
||||
[0.6, 0.3, 0.7, 0.1, 0.0],
|
||||
[0.5, 0.6, 0.1, 0.0, 1.0],
|
||||
], dtype=torch.float).to(device) # [4, 5] — [wall, door, monster, entrance, resource]
|
||||
|
||||
model = GinkaMaskGIT(
|
||||
num_classes=7,
|
||||
@ -116,7 +116,7 @@ if __name__ == "__main__":
|
||||
print_memory(device, "初始化后")
|
||||
|
||||
start = time.perf_counter()
|
||||
logits = model(map_input, z_input, struct_input, density_input)
|
||||
logits = model(map_input, z_input, struct_input, remain_input)
|
||||
end = time.perf_counter()
|
||||
|
||||
print_memory(device, "前向传播后")
|
||||
@ -124,7 +124,7 @@ if __name__ == "__main__":
|
||||
print(f"推理耗时: {end - start:.4f}s")
|
||||
print(f"输出形状: logits={logits.shape}")
|
||||
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
|
||||
print(f"Density Projection parameters: {sum(p.numel() for p in model.density_proj.parameters())}")
|
||||
print(f"Remain Projection parameters: {sum(p.numel() for p in model.remain_proj.parameters())}")
|
||||
print(f"Cond Projection parameters: {sum(p.numel() for p in model.cond_proj.parameters())}")
|
||||
print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
|
||||
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
|
||||
|
||||
@ -48,10 +48,10 @@ VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度
|
||||
VQ_NHEAD = 4 # VQ-VAE 多头注意力头数
|
||||
|
||||
# 第一阶段 MaskGIT 超参
|
||||
STAGE1_MG_DMODEL = 256
|
||||
STAGE1_MG_DMODEL = 512
|
||||
STAGE1_MG_NHEAD = 4
|
||||
STAGE1_MG_NUM_LAYERS = 6
|
||||
STAGE1_MG_DIM_FF = 1024
|
||||
STAGE1_MG_DIM_FF = 2048
|
||||
|
||||
# 第二阶段 MaskGIT 超参
|
||||
STAGE2_MG_DMODEL = 256
|
||||
@ -81,9 +81,16 @@ MASK_TOKEN = 6 # 掩码图块
|
||||
MAP_W = 13 # 地图宽度
|
||||
MAP_H = 13 # 地图高度
|
||||
MAP_SIZE = MAP_W * MAP_H # 地图大小
|
||||
DENSITY_DIM = 5 # [wall, door, monster, entrance, resource]
|
||||
GENERATE_STEP = 18 # MaskGIT 采样步数
|
||||
SUBSET_WEIGHTS = (0.5, 0.3, 0.2) # 每个子集的概率
|
||||
|
||||
WALL_DENSITY_IDX = 0
|
||||
DOOR_DENSITY_IDX = 1
|
||||
MONSTER_DENSITY_IDX = 2
|
||||
ENTRANCE_DENSITY_IDX = 3
|
||||
RESOURCE_DENSITY_IDX = 4
|
||||
|
||||
MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率
|
||||
MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率
|
||||
|
||||
@ -100,7 +107,7 @@ EPOCHS = 400 # 总训练轮数
|
||||
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
"cuda:0" if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
@ -179,14 +186,58 @@ def random_struct(device: torch.device) -> torch.Tensor:
|
||||
cond_outer = random.randint(0, 1) # 是否有外围走廈
|
||||
return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device)
|
||||
|
||||
def random_density(device: torch.device) -> torch.Tensor:
|
||||
# 随机采样一组密度参量,用于自由生成
|
||||
# density_inject 格式:[door_norm, monster_norm, resource_norm] ∈ [0, 1]
|
||||
return torch.rand(1, 3, device=device)
|
||||
def random_target_density(density_stats: dict, device: torch.device) -> torch.Tensor:
|
||||
# 从训练集真实密度范围中采样 wall / door / monster / entrance / resource 目标密度
|
||||
wall_density = random.uniform(density_stats["wall_min_density"], density_stats["wall_max_density"])
|
||||
door_density = random.uniform(density_stats["door_min_density"], density_stats["door_max_density"])
|
||||
monster_density = random.uniform(density_stats["monster_min_density"], density_stats["monster_max_density"])
|
||||
entrance_density = random.uniform(density_stats["entrance_min_density"], density_stats["entrance_max_density"])
|
||||
resource_density = random.uniform(density_stats["resource_min_density"], density_stats["resource_max_density"])
|
||||
return torch.FloatTensor([
|
||||
wall_density,
|
||||
door_density,
|
||||
monster_density,
|
||||
entrance_density,
|
||||
resource_density,
|
||||
]).unsqueeze(0).to(device)
|
||||
|
||||
def compute_remaining(
|
||||
current: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
stage: int
|
||||
) -> torch.Tensor:
|
||||
remain = torch.zeros(current.size(0), DENSITY_DIM, device=current.device)
|
||||
|
||||
visible_wall = (current == 1).sum(dim=1).float() / MAP_SIZE
|
||||
visible_door = (current == 2).sum(dim=1).float() / MAP_SIZE
|
||||
visible_monster = (current == 4).sum(dim=1).float() / MAP_SIZE
|
||||
visible_entrance = (current == 5).sum(dim=1).float() / MAP_SIZE
|
||||
visible_resource = (current == 3).sum(dim=1).float() / MAP_SIZE
|
||||
|
||||
if stage == 1:
|
||||
remain[:, WALL_DENSITY_IDX] = (
|
||||
target_density[:, WALL_DENSITY_IDX] - visible_wall
|
||||
).clamp(min=0.0, max=1.0)
|
||||
elif stage == 2:
|
||||
remain[:, DOOR_DENSITY_IDX] = (
|
||||
target_density[:, DOOR_DENSITY_IDX] - visible_door
|
||||
).clamp(min=0.0, max=1.0)
|
||||
remain[:, MONSTER_DENSITY_IDX] = (
|
||||
target_density[:, MONSTER_DENSITY_IDX] - visible_monster
|
||||
).clamp(min=0.0, max=1.0)
|
||||
remain[:, ENTRANCE_DENSITY_IDX] = (
|
||||
target_density[:, ENTRANCE_DENSITY_IDX] - visible_entrance
|
||||
).clamp(min=0.0, max=1.0)
|
||||
elif stage == 3:
|
||||
remain[:, RESOURCE_DENSITY_IDX] = (
|
||||
target_density[:, RESOURCE_DENSITY_IDX] - visible_resource
|
||||
).clamp(min=0.0, max=1.0)
|
||||
|
||||
return remain
|
||||
|
||||
def maskgit_sample(
|
||||
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
|
||||
struct: torch.Tensor, density: torch.Tensor, steps: int,
|
||||
struct: torch.Tensor, target_density: torch.Tensor, stage: int, steps: int,
|
||||
target_tiles: list[int] | None = None, keep_fixed: bool = True
|
||||
) -> np.ndarray:
|
||||
# target_tiles: 本阶段负责生成的图块 ID 列表;None 表示接受所有类别(stage1)
|
||||
@ -204,7 +255,8 @@ def maskgit_sample(
|
||||
|
||||
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
|
||||
for step in range(steps):
|
||||
logits = model(current, z, struct, density)
|
||||
remain = compute_remaining(current, target_density, stage)
|
||||
logits = model(current, z, struct, remain)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
dist = torch.distributions.Categorical(probs)
|
||||
@ -270,7 +322,8 @@ def maskgit_sample(
|
||||
# 目标模式下,未被填充的位置视为空地(不属于本阶段负责的图块)
|
||||
current[0, still_masked] = 0
|
||||
else:
|
||||
logits = model(current, z, struct, density)
|
||||
remain = compute_remaining(current, target_density, stage)
|
||||
logits = model(current, z, struct, remain)
|
||||
current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
|
||||
|
||||
return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
@ -278,7 +331,7 @@ def maskgit_sample(
|
||||
def full_generate_random_z(
|
||||
input: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
@ -288,14 +341,18 @@ def full_generate_random_z(
|
||||
with torch.no_grad():
|
||||
z = quantizer.sample(1, VQ_L, device)
|
||||
|
||||
# stage1:生成 floor/wall 骨架
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, density, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
# stage1:生成墙壁骨架
|
||||
pred1_np = maskgit_sample(
|
||||
mg1, input.clone(), z, struct, target_density, 1,
|
||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||
)
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
||||
|
||||
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z, struct, density, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
mg2, inp2, z, struct, target_density, 2,
|
||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
@ -304,7 +361,8 @@ def full_generate_random_z(
|
||||
|
||||
# stage3:填充 resource(3)
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z, struct, density, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
mg3, inp3, z, struct, target_density, 3,
|
||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
@ -315,7 +373,7 @@ def full_generate_specific_z(
|
||||
input: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
@ -324,12 +382,16 @@ def full_generate_specific_z(
|
||||
|
||||
with torch.no_grad():
|
||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, density, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
pred1_np = maskgit_sample(
|
||||
mg1, input.clone(), z, struct, target_density, 1,
|
||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||
)
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN
|
||||
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z, struct, density, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
mg2, inp2, z, struct, target_density, 2,
|
||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
@ -337,7 +399,8 @@ def full_generate_specific_z(
|
||||
inp3[inp3 == 0] = MASK_TOKEN
|
||||
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z, struct, density, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
mg3, inp3, z, struct, target_density, 3,
|
||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
@ -354,15 +417,16 @@ def annotate(img: np.ndarray, text: str) -> np.ndarray:
|
||||
def annotate_labels(
|
||||
img: np.ndarray,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor
|
||||
target_density: torch.Tensor
|
||||
) -> np.ndarray:
|
||||
# 两行标注:第一行结构标签,第二行密度连续小数
|
||||
# 三行标注:第一行结构标签,后两行显示五维目标密度
|
||||
s = struct.tolist()
|
||||
d = density.tolist()
|
||||
d = target_density.tolist()
|
||||
line1 = f"sym:{s[0]} outer:{s[1]}"
|
||||
line2 = f"door:{d[0]:.2f} enemy:{d[1]:.2f} res:{d[2]:.2f}"
|
||||
line2 = f"wall:{d[0]:.2f} door:{d[1]:.2f}"
|
||||
line3 = f"enemy:{d[2]:.2f} ent:{d[3]:.2f} res:{d[4]:.2f}"
|
||||
img = img.copy()
|
||||
for text, y in [(line1, 12), (line2, 24)]:
|
||||
for text, y in [(line1, 12), (line2, 24), (line3, 36)]:
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2)
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
|
||||
return img
|
||||
@ -428,10 +492,10 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
density_t = batch["density_inject"][0:1].to(device)
|
||||
target_density_t = batch["target_density"][0:1].to(device)
|
||||
kf = rand_keep()
|
||||
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, density_t, models, device, keep_fixed=kf
|
||||
inp1_t, z_q[0:1], struct_t, target_density_t, models, device, keep_fixed=kf
|
||||
)
|
||||
kf_label = 'fix' if kf[0] else 'free'
|
||||
|
||||
@ -441,15 +505,15 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
density_cpu = batch["density_inject"][0]
|
||||
target_density_cpu = batch["target_density"][0]
|
||||
|
||||
rows = [
|
||||
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
||||
[
|
||||
annotate(to_img(inp1_np), kf_label),
|
||||
annotate_labels(to_img(auto_pred1_np), struct_cpu, density_cpu),
|
||||
annotate_labels(to_img(auto_merged12), struct_cpu, density_cpu),
|
||||
annotate_labels(to_img(auto_merged123), struct_cpu, density_cpu)
|
||||
annotate_labels(to_img(auto_pred1_np), struct_cpu, target_density_cpu),
|
||||
annotate_labels(to_img(auto_merged12), struct_cpu, target_density_cpu),
|
||||
annotate_labels(to_img(auto_merged123), struct_cpu, target_density_cpu)
|
||||
],
|
||||
]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
|
||||
@ -461,7 +525,7 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
return grid
|
||||
|
||||
# 验证可视化 part3:2×3 网格;行1=参考输入+相同 struct 随机 z 生成,行2=随机 struct 生成
|
||||
def visualize_part3(batch, models, device, tile_dict):
|
||||
def visualize_part3(batch, models, device, tile_dict, density_stats: dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
@ -472,24 +536,28 @@ def visualize_part3(batch, models, device, tile_dict):
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_ref = batch["struct_inject"][0:1].to(device)
|
||||
density_ref = batch["density_inject"][0:1].to(device)
|
||||
target_density_ref = batch["target_density"][0:1].to(device)
|
||||
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
density_cpu = batch["density_inject"][0]
|
||||
target_density_cpu = batch["target_density"][0]
|
||||
|
||||
row1 = [to_img(inp1_np)]
|
||||
for _ in range(2):
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, density_ref, models, device, keep_fixed=kf)
|
||||
row1.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu))
|
||||
_, _, merged123 = full_generate_random_z(
|
||||
inp1_t, struct_ref, target_density_ref, models, device, keep_fixed=kf
|
||||
)
|
||||
row1.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
|
||||
row2 = []
|
||||
for _ in range(3):
|
||||
kf = rand_keep()
|
||||
rnd_struct = random_struct(device)
|
||||
rnd_density = random_density(device)
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, rnd_struct, rnd_density, models, device, keep_fixed=kf)
|
||||
row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_density[0].cpu()))
|
||||
rnd_target_density = random_target_density(density_stats, device)
|
||||
_, _, merged123 = full_generate_random_z(
|
||||
inp1_t, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
|
||||
)
|
||||
row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
|
||||
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
@ -501,7 +569,7 @@ def visualize_part3(batch, models, device, tile_dict):
|
||||
return grid
|
||||
|
||||
# 验证可视化 part4:2×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成
|
||||
def visualize_part4(models, device, tile_dict):
|
||||
def visualize_part4(models, device, tile_dict, density_stats: dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
@ -520,9 +588,11 @@ def visualize_part4(models, device, tile_dict):
|
||||
for _ in range(5):
|
||||
kf = rand_keep()
|
||||
rnd_struct = random_struct(device)
|
||||
rnd_density = random_density(device)
|
||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_density, models, device, keep_fixed=kf)
|
||||
results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_density[0].cpu()))
|
||||
rnd_target_density = random_target_density(density_stats, device)
|
||||
_, _, merged123 = full_generate_random_z(
|
||||
seed, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
|
||||
)
|
||||
results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
|
||||
|
||||
row1 = [to_img(seed_np)] + results[:2]
|
||||
row2 = results[2:]
|
||||
@ -537,17 +607,18 @@ def visualize_part4(models, device, tile_dict):
|
||||
|
||||
def visualize_validate(
|
||||
batch, logits1, logits2, logits3, z_q,
|
||||
models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int, batch_idx: int
|
||||
models: list[torch.nn.Module], device: torch.device, tile_dict,
|
||||
density_stats: dict, epoch: int, batch_idx: int
|
||||
):
|
||||
save_dir = f"result/seperated/e{epoch}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict, density_stats))
|
||||
cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict))
|
||||
|
||||
# 密度对照图:随机种子+随机结构,5 张随机密度生成,2×3 网格(左上角为种子图)
|
||||
def visualize_density_cmp(models, device, tile_dict):
|
||||
def visualize_density_cmp(models, device, tile_dict, density_stats: dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
@ -565,10 +636,10 @@ def visualize_density_cmp(models, device, tile_dict):
|
||||
struct_cpu = rnd_struct[0].cpu()
|
||||
gen_imgs = []
|
||||
for _ in range(5):
|
||||
rnd_density = random_density(device)
|
||||
density_cpu = rnd_density[0].cpu()
|
||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_density, models, device)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu))
|
||||
rnd_target_density = random_target_density(density_stats, device)
|
||||
target_density_cpu = rnd_target_density[0].cpu()
|
||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_target_density, models, device)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
row1 = [to_img(seed_np)] + gen_imgs[:2]
|
||||
row2 = gen_imgs[2:]
|
||||
rows = [row1, row2]
|
||||
@ -580,7 +651,7 @@ def visualize_density_cmp(models, device, tile_dict):
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
# 固定 z 和结构条件,使用 5 个随机密度各生成一次,2×3 网格(左上角为参考地图)
|
||||
# 固定 z 和结构条件,扫描 5 个不同墙壁目标密度,2×3 网格(左上角为参考地图)
|
||||
def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
@ -593,17 +664,18 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
base_target_density = batch["target_density"][0:1].to(device)
|
||||
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
gen_imgs = []
|
||||
# 固定 z 和结构条件,展开 5 个均匀密度水平对比不同密度条件的生成效果
|
||||
density_values = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for v in density_values:
|
||||
fixed_density = torch.FloatTensor([[v, v, v]]).to(device)
|
||||
density_cpu = fixed_density[0].cpu()
|
||||
wall_count_values = [20, 35, 50, 65, 80]
|
||||
for wall_count in wall_count_values:
|
||||
fixed_target_density = base_target_density.clone()
|
||||
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
|
||||
target_density_cpu = fixed_target_density[0].cpu()
|
||||
_, _, merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, fixed_density, models, device
|
||||
inp1_t, z_q[0:1], struct_t, fixed_target_density, models, device
|
||||
)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu))
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
||||
row2 = gen_imgs[2:]
|
||||
rows = [row1, row2]
|
||||
@ -615,7 +687,14 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int):
|
||||
def validate(
|
||||
dataloader: DataLoader,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
tile_dict,
|
||||
density_stats: dict,
|
||||
epoch: int
|
||||
):
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
|
||||
# 切换为推理模式(关闭 Dropout / BatchNorm 统计更新)
|
||||
@ -628,16 +707,13 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc
|
||||
loss3_total = torch.Tensor([0]).to(device)
|
||||
commit_total = torch.Tensor([0]).to(device)
|
||||
|
||||
# 按小模块(Low/Medium/High)累计实体计数差(L1),用于诊断密度条件可控性
|
||||
# 密度连续小数按 [0,1/3)/[1/3,2/3)/[2/3,1] 分框到三模块
|
||||
# 结构:{tile_id: {bucket: [累计误差, 样本数]}}
|
||||
density_l1 = {
|
||||
2: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # door
|
||||
4: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # monster
|
||||
3: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # resource
|
||||
density_metrics = {
|
||||
1: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||
2: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||
4: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||
5: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||
3: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||
}
|
||||
# 三类实体对应的 density_inject 索引
|
||||
tile_density_idx = {2: 0, 4: 1, 3: 2}
|
||||
|
||||
idx = 0
|
||||
|
||||
@ -658,7 +734,7 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc
|
||||
enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
|
||||
struct = batch["struct_inject"].to(device)
|
||||
density = batch["density_inject"].to(device)
|
||||
target_density = batch["target_density"].to(device)
|
||||
|
||||
# VQ 编码:各阶段独立编码后拼接、量化
|
||||
z_e1 = vq1(enc1) # [B, L, d_z]
|
||||
@ -668,57 +744,63 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
|
||||
# 三阶段 MaskGIT 推理(均以完整 z_q、struct 和 density 为条件)
|
||||
logits1 = mg1(inp1, z_q, struct, density)
|
||||
logits2 = mg2(inp2, z_q, struct, density)
|
||||
logits3 = mg3(inp3, z_q, struct, density)
|
||||
remain1 = compute_remaining(inp1, target_density, 1)
|
||||
remain2 = compute_remaining(inp2, target_density, 2)
|
||||
remain3 = compute_remaining(inp3, target_density, 3)
|
||||
|
||||
# 三阶段 MaskGIT 推理(均以完整 z_q、struct 和动态 remain 为条件)
|
||||
logits1 = mg1(inp1, z_q, struct, remain1)
|
||||
logits2 = mg2(inp2, z_q, struct, remain2)
|
||||
logits3 = mg3(inp3, z_q, struct, remain3)
|
||||
|
||||
loss1_total += focal_loss(logits1, target1)
|
||||
loss2_total += focal_loss(logits2, target2)
|
||||
loss3_total += focal_loss(logits3, target3)
|
||||
commit_total += commit_loss
|
||||
|
||||
# 计算 argmax 预测并统计各档位密度 L1(预测计数与真实计数之差的绝对值)
|
||||
# 计算各目标对象的真实密度误差与过量生成密度
|
||||
pred1_map = torch.argmax(logits1, dim=-1).cpu()
|
||||
pred2_map = torch.argmax(logits2, dim=-1).cpu() # [B, MAP_SIZE]
|
||||
pred3_map = torch.argmax(logits3, dim=-1).cpu()
|
||||
true1_map = target1.cpu() # [B, MAP_SIZE]
|
||||
true2_map = target2.cpu() # [B, MAP_SIZE]
|
||||
true3_map = target3.cpu()
|
||||
density_cpu = batch["density_inject"] # [B, 3]
|
||||
for b in range(pred2_map.size(0)):
|
||||
for tile_id, d_idx in tile_density_idx.items():
|
||||
if tile_id == 3:
|
||||
pred_map = pred3_map[b]
|
||||
true_map = true3_map[b]
|
||||
else:
|
||||
pred_map = pred2_map[b]
|
||||
true_map = true2_map[b]
|
||||
metric_sources = [
|
||||
(1, pred1_map, true1_map),
|
||||
(2, pred2_map, true2_map),
|
||||
(4, pred2_map, true2_map),
|
||||
(5, pred2_map, true2_map),
|
||||
(3, pred3_map, true3_map),
|
||||
]
|
||||
for tile_id, pred_map_batch, true_map_batch in metric_sources:
|
||||
for batch_idx in range(pred_map_batch.size(0)):
|
||||
pred_map = pred_map_batch[batch_idx]
|
||||
true_map = true_map_batch[batch_idx]
|
||||
pred_count = float((pred_map == tile_id).sum().item())
|
||||
true_count = float((true_map == tile_id).sum().item())
|
||||
# 连续密度按 [0,1/3)/[1/3,2/3)/[2/3,1] 分戆到三模块
|
||||
lv = min(int(density_cpu[b, d_idx].item() * 3), 2)
|
||||
density_l1[tile_id][lv][0] += abs(pred_count - true_count)
|
||||
density_l1[tile_id][lv][1] += 1
|
||||
density_metrics[tile_id]["mae"] += abs(pred_count - true_count) / MAP_SIZE
|
||||
density_metrics[tile_id]["over"] += max(pred_count - true_count, 0.0) / MAP_SIZE
|
||||
density_metrics[tile_id]["count"] += 1
|
||||
|
||||
# 每个 batch 生成三种可视化图(val/full/rand)
|
||||
visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx)
|
||||
visualize_validate(
|
||||
batch, logits1, logits2, logits3, z_q,
|
||||
models, device, tile_dict, density_stats, epoch, idx
|
||||
)
|
||||
idx += 1
|
||||
|
||||
# 输出密度 L1 统计(各小模块内的平均实体计数,供诊断密度条件效果)
|
||||
bucket_names = ['Low', 'Medium', 'High']
|
||||
tile_names = {2: 'door', 4: 'enemy', 3: 'resource'}
|
||||
for tile_id in [2, 4, 3]:
|
||||
parts = []
|
||||
for lv in range(3):
|
||||
acc, cnt = density_l1[tile_id][lv]
|
||||
avg = acc / cnt if cnt > 0 else 0.0
|
||||
parts.append(f"{bucket_names[lv]}={avg:.2f}")
|
||||
tqdm.write(f" density {tile_names[tile_id]}: {' '.join(parts)}")
|
||||
tile_names = {1: 'wall', 2: 'door', 4: 'enemy', 5: 'entrance', 3: 'resource'}
|
||||
for tile_id in [1, 2, 4, 5, 3]:
|
||||
count = density_metrics[tile_id]["count"]
|
||||
avg_mae = density_metrics[tile_id]["mae"] / count if count > 0 else 0.0
|
||||
avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0
|
||||
tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}")
|
||||
|
||||
save_dir = f"result/seperated/e{epoch}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图
|
||||
cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict, density_stats))
|
||||
cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict, density_stats))
|
||||
|
||||
# 恢复训练模式
|
||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||
@ -811,9 +893,9 @@ def train(device: torch.device):
|
||||
target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
|
||||
# 结构条件向量:[cond_sym, cond_room, cond_branch, cond_outer]
|
||||
# 结构条件向量:[cond_sym, cond_outer]
|
||||
struct = batch["struct_inject"].to(device)
|
||||
density = batch["density_inject"].to(device)
|
||||
target_density = batch["target_density"].to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
@ -826,10 +908,14 @@ def train(device: torch.device):
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
|
||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和 density 条件)
|
||||
logits1 = mg1(inp1, z_q, struct, density)
|
||||
logits2 = mg2(inp2, z_q, struct, density)
|
||||
logits3 = mg3(inp3, z_q, struct, density)
|
||||
remain1 = compute_remaining(inp1, target_density, 1)
|
||||
remain2 = compute_remaining(inp2, target_density, 2)
|
||||
remain3 = compute_remaining(inp3, target_density, 3)
|
||||
|
||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和动态 remain 条件)
|
||||
logits1 = mg1(inp1, z_q, struct, remain1)
|
||||
logits2 = mg2(inp2, z_q, struct, remain2)
|
||||
logits3 = mg3(inp3, z_q, struct, remain3)
|
||||
|
||||
# 三阶段 Focal Loss + VQ commit loss 加权求和
|
||||
loss1 = focal_loss(logits1, target1)
|
||||
@ -867,7 +953,7 @@ def train(device: torch.device):
|
||||
|
||||
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
||||
if (epoch + 1) % CHECKPOINT == 0:
|
||||
losses = validate(dataloader_val, models, device, tile_dict, epoch + 1)
|
||||
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
|
||||
loss1_total, loss2_total, loss3_total, commit_total = losses
|
||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
|
||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
|
||||
|
||||
Loading…
Reference in New Issue
Block a user