ginka-generator/docs/cond-simplify-design.md

9.8 KiB
Raw Blame History

条件简化与密度连续化设计文档

背景

当前三阶段级联生成模型的条件系统存在以下问题:

  1. 结构条件中的房间数和分支数对生成指导意义有限:这两个指标依赖数据集中预计算的离散分档,与实际生成质量的相关性较弱,且分档边界处噪声大,容易引入无效条件信号。

  2. 实体密度条件(门/怪物/资源)的离散三档存在明显一对多问题:三档划分过于粗糙,同一档内样本分布差异极大(例如 Medium 档中资源数可以从 2 到 8 不等),导致模型无法建立条件与生成结果之间的精确映射。连续值能够更精确地描述目标密度,避免档位内分布散乱导致的条件信号模糊。

改动总览

模块 改动类型 说明
ginka/dataset.py 修改 删除房间/分支分档;密度改为连续归一化;输出 FloatTensor
ginka/maskGIT/model.py 修改 删除房间/分支嵌入;密度嵌入层改为线性投影;更新 cond_proj 维度
ginka/train_seperated.py 修改 更新 random_struct/random_density更新 annotate_labels

一、条件向量格式变更

1.1 struct_inject

当前格式4 个离散整数):

[cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)]

新格式2 个离散整数,删除 room 和 branch

[cond_sym(0-7), cond_outer(0-1)]

cond_sym 的计算方式不变(水平/垂直/中心对称的三位二进制组合07cond_outer 不变。

1.2 density_inject

当前格式3 个离散整数,LongTensor

[door_level(0-2), monster_level(0-2), resource_level(0-2)]

新格式3 个连续浮点数,FloatTensor,值域 [0, 1]

[door_norm, monster_norm, resource_norm]  ∈ [0.0, 1.0]^3

二、密度归一化方案

2.1 统计量定义

在训练集初始化阶段,对原始地图统计三类图块的实际数量(非密度,直接计数):

  • door_count = 图块ID为2的数量
  • monster_count = 图块ID为4的数量
  • resource_count = 图块ID为3的数量

对每类分别求训练集内的 最小值最大值

door_min, door_max
monster_min, monster_max
resource_min, resource_max

2.2 归一化公式

对每个样本的 count归一化为 [0, 1]


\text{norm}(x) = \frac{x - x_{\min}}{x_{\max} - x_{\min} + \epsilon}

其中 $\epsilon = 1\text{e-}6$,防止分母为零(当所有样本计数相同时)。

结果裁剪到 [0, 1]norm = clamp(norm, 0.0, 1.0)

2.3 验证集复用训练集统计量

GinkaSeperatedDataset 新增参数:

def __init__(
    self,
    data_path: str,
    subset_weights: tuple = (0.5, 0.3, 0.2),
    density_stats: dict | None = None   # 新增:外部传入统计量
):
  • 训练集:density_stats=None,自行计算并保存 min/maxself.density_stats
  • 验证集:传入训练集的 self.density_stats,直接复用,保证归一化语义一致

density_stats 的结构:

{
    "door_min": float, "door_max": float,
    "monster_min": float, "monster_max": float,
    "resource_min": float, "resource_max": float,
}

2.4 输出字段变更

__getitem__density_injectLongTensor 改为 FloatTensor

# 删除旧的离散分档逻辑
density_inject = torch.FloatTensor([
    self.norm_density(count_door, "door"),
    self.norm_density(count_monster, "monster"),
    self.norm_density(count_resource, "resource"),
])

删除以下字段(不再写入 item 也不再输出):

  • doorDensityLevel, monsterDensityLevel, resourceDensityLevel
  • roomCountLevel, branchLevel

删除以下实例变量:

  • self.room_th, self.branch_th
  • self.door_density_th, self.monster_density_th, self.resource_density_th

三、模型结构变更(model.py

3.1 删除房间/分支嵌入

删除:

self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)

保留:

self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)   # SYM_VOCAB = 8
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) # OUTER_VOCAB = 2

删除的常量:ROOM_VOCAB, BRANCH_VOCAB,保留 SYM_VOCAB, OUTER_VOCAB

3.2 密度嵌入层改为线性投影

删除:

self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z)
self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z)
self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z)

删除的常量:DOOR_DENSITY_VOCAB, MONSTER_DENSITY_VOCAB, RESOURCE_DENSITY_VOCAB

新增:

# 连续密度投影:将 3 个归一化浮点数映射为 1 个 d_z 维 token
self.density_proj = nn.Linear(3, d_z)

3.3 cond_proj 维度更新

当前 cond_seq 形状[B, z_seq_len + 4_struct + 3_density, d_z],即 [B, z_seq_len+7, d_z],展平后输入维度 (z_seq_len+7) * d_z

新 cond_seq 形状[B, z_seq_len + 2_struct + 1_density, d_z],即 [B, z_seq_len+3, d_z],展平后输入维度 (z_seq_len+3) * d_z

# 旧
self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model)
# 新
self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model)

3.4 forward 流程变更

def forward(
    self,
    map: torch.Tensor,
    z: torch.Tensor,
    struct: torch.Tensor,   # [B, 2]  ← 由 [B, 4] 改为 [B, 2]
    density: torch.Tensor   # [B, 3] float ← 由 [B, 3] long 改为 float
) -> torch.Tensor:

    # 结构标签sym + outer各嵌入为 d_z 维 token
    e_struct = torch.stack([
        self.sym_embed(struct[:, 0]),   # [B, d_z]
        self.outer_embed(struct[:, 1]), # [B, d_z]
    ], dim=1)  # [B, 2, d_z]

    # 密度:连续值投影为单个 d_z 维 token
    e_density = self.density_proj(density).unsqueeze(1)  # [B, 1, d_z]

    # z逐 token 投影(不变)
    z_proj = self.z_proj(z)  # [B, z_seq_len, d_z]

    # 拼接 → [B, z_seq_len+3, d_z] → 展平 → 投影到 d_model
    cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1)
    c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1))  # [B, d_model]

    # 后续不变tile embedding + Transformer + output_fc

四、训练脚本变更(train_seperated.py

4.1 random_struct

def random_struct(device: torch.device) -> torch.Tensor:
    # struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)]
    cond_sym = random.randint(0, 7)   # 地图对称类型
    cond_outer = random.randint(0, 1) # 是否有外围走廊
    return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device)

4.2 random_density

def random_density(device: torch.device) -> torch.Tensor:
    # density_inject 格式:[door_norm, monster_norm, resource_norm] ∈ [0, 1]
    return torch.rand(1, 3, device=device)  # 均匀分布随机采样

4.3 annotate_labels

更新标注格式,删除 room/branch密度显示为两位小数

def annotate_labels(
    img: np.ndarray,
    struct: torch.Tensor,   # [2] long
    density: torch.Tensor   # [3] float
) -> np.ndarray:
    s = struct.tolist()
    d = density.tolist()
    line1 = f"sym:{s[0]} outer:{s[1]}"
    line2 = f"door:{d[0]:.2f} enemy:{d[1]:.2f} res:{d[2]:.2f}"
    img = img.copy()
    for text, y in [(line1, 12), (line2, 24)]:
        cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2)
        cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
    return img

4.4 训练集与验证集初始化

train_dataset = GinkaSeperatedDataset(args.train, subset_weights=SUBSET_WEIGHTS)
val_dataset = GinkaSeperatedDataset(
    args.validate,
    subset_weights=SUBSET_WEIGHTS,
    density_stats=train_dataset.density_stats  # 复用训练集统计量
)

4.5 DataLoader collate_fnFloatTensor 适配)

PyTorch 默认 collate 会自动将 FloatTensor 列表合并为 float 类型批张量,无需额外修改 DataLoader 配置。

4.6 验证阶段密度对照图density_var

visualize_density_var 内对比不同密度条件时,改为使用 5 个均匀分布采样点:

# 旧三档枚举density_levels = [0, 1, 2]
# 新连续采样5 个均匀间隔值
density_values = [0.0, 0.25, 0.5, 0.75, 1.0]
for v in density_values:
    d = torch.FloatTensor([[v, v, v]]).to(device)  # 三类等密度扫描
    ...

五、不需要改动的部分

  • ginka/maskGIT/maskGIT.pyAdaLN / CondTransformerLayer / Transformer 均不感知条件维度,无需修改
  • ginka/vqvae/ 目录VQ-VAE 部分与条件系统无关
  • ginka/train_seperated.py 中的 maskgit_samplefull_generate_random_zfull_generate_specific_z:接口签名不变(仍接受 struct/density 张量),内部无直接操作条件内容,无需修改
  • data/ 目录的 TypeScript 数据处理脚本数据文件格式不变Python 端自行计算标签

六、旧 checkpoint 兼容性

由于 cond_proj 输入维度和嵌入层数量均发生变化,旧 checkpoint 不兼容,需从头训练。


七、实施顺序

  1. 修改 ginka/dataset.py:删除 room/branch 分档,新增密度归一化和 density_stats 参数
  2. 修改 ginka/maskGIT/model.py:删除多余嵌入,新增 density_proj,更新 cond_proj 维度和 forward
  3. 修改 ginka/train_seperated.py:更新 random_structrandom_densityannotate_labels、数据集初始化
  4. 运行小规模过拟合测试(单 batch 跑 50 步)验证前向通路无误