From b52bfdb78f357c4e76cdac7823dab5a96cb20d84 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 15 May 2026 16:11:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=B5=84=E6=BA=90=E3=80=81=E6=80=AA?= =?UTF-8?q?=E7=89=A9=E3=80=81=E9=97=A8=E6=95=B0=E9=87=8F=E6=A0=87=E7=AD=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/entity-density-labels-design.md | 178 +++++++++++++++++++++ ginka/dataset.py | 33 +++- ginka/maskGIT/maskGIT.py | 2 +- ginka/maskGIT/model.py | 39 ++++- ginka/train_seperated.py | 229 ++++++++++++++++++++++----- 5 files changed, 437 insertions(+), 44 deletions(-) create mode 100644 docs/entity-density-labels-design.md diff --git a/docs/entity-density-labels-design.md b/docs/entity-density-labels-design.md new file mode 100644 index 0000000..24e011a --- /dev/null +++ b/docs/entity-density-labels-design.md @@ -0,0 +1,178 @@ +# 实体密度标签设计文档 + +## 背景与问题 + +当前三阶段级联生成(stage1 骨架、stage2 功能实体、stage3 资源)在结构可行性上基本稳定,但存在明显的分布偏移: + +- 怪物数量偏多 +- 资源数量偏多 +- 门数量在部分样本上也偏高 + +已尝试在采样阶段通过“随机抛弃部分新揭开位并重新掩码”的方式抑制过密生成,但效果不稳定,核心原因是该策略属于推理期启发式约束,不能从训练目标层面改变模型对全局密度的先验。 + +因此需要引入显式条件:将每张地图中门、怪物、资源的密度离散为三档(低/中/高),并在训练和推理时作为条件输入,让模型学习“在指定密度档位下生成”。 + +## 目标 + +- 新增 3 个可控标签:`doorDensityLevel`、`monsterDensityLevel`、`resourceDensityLevel`,取值均为 `0 | 1 | 2`。 +- 标签计算与分档在 Python 端完成,保持与现有 `roomCountLevel`、`branchLevel` 一致的处理方式。 +- 标签注入模型后,支持在推理时显式控制三类实体密度。 +- 在不改动数据处理端(TypeScript)的前提下完成接入。 + +## 设计原则 + +- 统计口径稳定:密度分母采用固定地图面积(13x13),避免受随机掩码影响。 +- 分档可迁移:使用训练集等频分箱阈值;验证/推理复用同一阈值。 +- 最小侵入:优先扩展现有 Python 数据集与条件注入链路,不改变数据文件格式。 +- 可回溯:训练日志与可视化中输出目标密度档位与实际密度,便于诊断。 + +## 标签定义 + +### 1. 统计对象 + +基于原始地图 `item['map']`(未掩码、未降级)统计三类图块数量: + +- `doorCount`: 图块 ID = 2 +- `resourceCount`: 图块 ID = 3 +- `monsterCount`: 图块 ID = 4 + +### 2. 密度定义 + +设地图面积为 `MAP_SIZE = 13 * 13 = 169`,则: + +- `doorDensity = doorCount / 169` +- `monsterDensity = monsterCount / 169` +- `resourceDensity = resourceCount / 169` + +### 3. 分档定义 + +采用等频分箱(三档)并与现有 `to_level` 规则一致: + +- 训练集上收集某一密度指标的全量样本值,升序排序 +- 取 `n/3` 与 `2n/3` 位置作为阈值 `th1`、`th2` +- 分档规则: + - `< th1` -> `0`(Low) + - `>= th1 且 < th2` -> `1`(Medium) + - `>= th2` -> `2`(High) + +阈值退化处理(与现有实现一致): + +- 若 `th1 == th2`,将 `th2 = th1 + eps` +- 对密度值建议 `eps = 1e-6` + +## Python 端处理方案 + +### 1. 数据集初始化阶段 + +在 `GinkaSeperatedDataset.__init__` 中新增一次统计流程: + +- 从 `self.data` 中提取每张图的 `doorDensity`、`monsterDensity`、`resourceDensity` +- 分别计算三组阈值: + - `self.door_density_th` + - `self.monster_density_th` + - `self.resource_density_th` +- 回填每个样本: + - `item['doorDensityLevel']` + - `item['monsterDensityLevel']` + - `item['resourceDensityLevel']` + +### 2. 样本输出阶段 + +在 `__getitem__` 返回字典中新增条件向量(建议独立字段,避免影响旧逻辑): + +- `density_inject = LongTensor([doorLevel, monsterLevel, resourceLevel])` + +不建议直接复用旧 `struct_inject` 覆盖含义。推荐并行保留: + +- `struct_inject`:结构语义(对称/房间/分支/外墙) +- `density_inject`:实体密度语义(门/怪物/资源) + +## 模型接入方案 + +### 1. 条件输入组织 + +密度条件与结构条件在语义上完全不同(结构描述地图拓扑形态,密度描述实体数量先验),不复用 `struct_inject` 的处理路径。 + +设计:在 MaskGIT 内新增一个独立的**密度 MLP**: + +- 输入:3 个独立 embedding 表(每档取值 0/1/2)输出相加后的向量 + - `emb_door_density: Embedding(3, d_embed)` + - `emb_monster_density: Embedding(3, d_embed)` + - `emb_resource_density: Embedding(3, d_embed)` +- 三个 embedding 相加后送入 2 层 MLP(`d_embed -> d_model -> d_model`,激活函数 GELU),输出一个 `d_model` 维向量 +- 该向量作为独立条件 token 拼接到主序列头部(与 struct token 并列,不替换) + +结构条件(`struct_inject`)保留原有处理方式不变。 + +### 2. 训练与推理接口 + +- 训练前向:`mgX(inpX, z_q, struct_inject, density_inject)` +- 推理采样:允许显式指定密度档位;未指定时可随机采样档位或使用数据先验分布采样 + +### 3. 条件 Dropout + +对密度条件增加独立 dropout(例如 0.1): + +- 训练时随机置空部分密度条件,降低过拟合风险 +- 推理时可在“无密度条件”与“强密度条件”两种模式间切换 + +## 训练与验证改造 + +### 1. 日志指标 + +在验证阶段新增统计输出: + +- 按档位分组的密度 L1 误差:分别统计 door/monster/resource 三类实体在 Low/Medium/High 三档条件下,生成地图实际计数与档位中位期望值之间的 L1 距离(仅用于观察,不参与反向传播) + +无需额外输出目标档位分布或实际密度均值,档位 L1 已足够直观反映控制效果。 + +### 2. 可视化对照 + +在每张验证生成图上直接标注所有条件标签,分两行显示: + +- 第一行(结构标签):`sym=N room=L/M/H branch=L/M/H outer=0/1` +- 第二行(密度标签):`d=L/M/H m=L/M/H r=L/M/H` + +其中 `sym` 取 `cond_sym` 的原始整数值(0–7),`room`/`branch`/`d`/`m`/`r` 均以 `L`/`M`/`H` 表示三档。 + +标注位置:图像顶部左上角,两行叠加,与现有 `fix`/`free` 标注并列(可追加到同一 `annotate` 调用后)。 + +额外新增一类对照图:固定同一 `z` 和结构条件,仅扫遍密度档位(Low/Medium/High 三档),分别生成地图并排排列,用于直观验证"只改密度条件,生成实体数量随档位单调变化"。该对照图在每个 checkpoint 验证时生成一次,保存到 `result/seperated/eN/density_cmp.png`。 + +### 3. 验收标准 + +至少满足以下条件后再认为方案有效: + +- 同一结构条件下,密度档位从 Low -> High 时,三类实体计数总体单调上升 +- 验证集上各档位的目标-实际密度 MAE 明显低于未加标签版本 +- 地图可玩性不退化(入口可达、关键路径连通性不显著恶化) + +## 与现有流程的兼容性 + +- 数据源 JSON 无需新增字段。 +- 标签在 Python 读取后即时计算,不影响 `data/` 侧脚本。 +- 旧 checkpoint 不兼容新增输入维度,需要从旧权重迁移或重新训练。 + +## 实施步骤建议 + +1. 在数据集类中实现三类密度统计、分档和 `density_inject` 返回。 +2. 扩展 MaskGIT 条件嵌入与前向接口,打通三阶段训练调用。 +3. 更新训练/验证日志与可视化标注,增加按档位评估。 +4. 先做小规模过拟合与对照采样验证,再进入完整训练。 + +## 风险与应对 + +- 风险:档位边界样本噪声大,模型学习不稳定。 + - 应对:引入软标签邻域采样(可选)或在损失中增加密度一致性正则。 + +- 风险:实体密度受结构强约束,条件可控性受限。 + - 应对:在评估中按结构复杂度分组分析,必要时引入结构-密度联合条件建模。 + +- 风险:三阶段相互影响导致 stage2/stage3 条件冲突。 + - 应对:分别监控阶段内计数与最终合并计数,必要时增加阶段特异性权重。 + +## 后续可扩展方向 + +- 将三档扩展为五档,提升控制精度。 +- 在密度标签之外增加“功能实体聚集度/均匀度”标签。 +- 引入条件一致性判别器,进一步约束生成结果与目标档位一致。 diff --git a/ginka/dataset.py b/ginka/dataset.py index 3c9d9b9..fa74a20 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -53,9 +53,33 @@ class GinkaSeperatedDataset(Dataset): item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th) item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th) + # 实体密度等级:统计原始地图中门/怪物/资源的数量,等频三档 + eps = 1e-6 + door_counts = sorted(self.count_tile(item['map'], self.DOOR) for item in self.data) + monster_counts = sorted(self.count_tile(item['map'], self.MONSTER) for item in self.data) + resource_counts = sorted(self.count_tile(item['map'], self.RESOURCE) for item in self.data) + th1_d, th2_d = door_counts[n // 3], door_counts[2 * n // 3] + th1_m, th2_m = monster_counts[n // 3], monster_counts[2 * n // 3] + th1_rc, th2_rc = resource_counts[n // 3], resource_counts[2 * n // 3] + if th1_d == th2_d: th2_d = th1_d + eps + if th1_m == th2_m: th2_m = th1_m + eps + if th1_rc == th2_rc: th2_rc = th1_rc + eps + self.door_density_th = (th1_d, th2_d) + self.monster_density_th = (th1_m, th2_m) + self.resource_density_th = (th1_rc, th2_rc) + + for item in self.data: + m = item['map'] + item['doorDensityLevel'] = self.to_level(self.count_tile(m, self.DOOR), self.door_density_th) + item['monsterDensityLevel'] = self.to_level(self.count_tile(m, self.MONSTER), self.monster_density_th) + item['resourceDensityLevel'] = self.to_level(self.count_tile(m, self.RESOURCE), self.resource_density_th) + def to_level(self, v, th): return 0 if v < th[0] else (1 if v < th[1] else 2) + def count_tile(self, map_data: list, tile_id: int) -> int: + return sum(cell == tile_id for row in map_data for cell in row) + def __len__(self): return len(self.data) @@ -174,6 +198,12 @@ class GinkaSeperatedDataset(Dataset): cond_outer = item['outerWall'] struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + density_inject = torch.LongTensor([ + item['doorDensityLevel'], + item['monsterDensityLevel'], + item['resourceDensityLevel'] + ]) + return { "input_stage1": torch.LongTensor(out[0]), "target_stage1": torch.LongTensor(out[1]), @@ -184,5 +214,6 @@ class GinkaSeperatedDataset(Dataset): "input_stage3": torch.LongTensor(out[6]), "target_stage3": torch.LongTensor(out[7]), "encoder_stage3": torch.LongTensor(out[8]), - "struct_inject": struct_inject + "struct_inject": struct_inject, + "density_inject": density_inject } diff --git a/ginka/maskGIT/maskGIT.py b/ginka/maskGIT/maskGIT.py index d41d4fe..51cf5b8 100644 --- a/ginka/maskGIT/maskGIT.py +++ b/ginka/maskGIT/maskGIT.py @@ -16,7 +16,7 @@ class Transformer(nn.Module): ) def forward(self, x, memory=None): - # x: [B, S, d_model] 地图 token 序列 + # x: [B, S, d_model] 地图 token 序列 # memory: [B, L, d_model] 可选的 z 投影,用于 cross-attention # 若 memory 为 None,则退化为原始自编解码行为(向后兼容) enc_out = self.encoder(x) diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index b99ce12..bb2390f 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -10,10 +10,16 @@ ROOM_VOCAB = 3 # roomCountLevel 0-2 BRANCH_VOCAB = 3 # branchLevel 0-2 OUTER_VOCAB = 2 # outerWall 0-1 +# 密度标签词表大小(Low/Medium/High 三档) +DOOR_DENSITY_VOCAB = 3 +MONSTER_DENSITY_VOCAB = 3 +RESOURCE_DENSITY_VOCAB = 3 + class GinkaMaskGIT(nn.Module): def __init__( self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512, - nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13, d_z: int = 64 + nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13, + d_z: int = 64 ): super().__init__() self.map_h = map_h @@ -57,15 +63,31 @@ class GinkaMaskGIT(nn.Module): self.output_fc = nn.Linear(d_model, num_classes) + # 密度标签嵌入 + 独立 MLP(与结构路径完全分离) + # 三个密度 embedding 相加后经两层 MLP 映射为单个条件 token + self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z) + self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z) + self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z) + self.density_mlp = nn.Sequential( + nn.Linear(d_z, d_model * 2), + nn.LayerNorm(d_model * 2), + nn.GELU(), + + nn.Linear(d_model * 2, d_model), + nn.LayerNorm(d_model) + ) + def forward( self, map: torch.Tensor, z: torch.Tensor, - struct: torch.Tensor + struct: torch.Tensor, + density: torch.Tensor ) -> torch.Tensor: # map: [B, H * W] # z: [B, L * 3, d_z] - # struch: [B, 4] + # struct: [B, 4] + # density: [B, 3] — [door_level, monster_level, resource_level] sym_idx = struct[:, 0] room_idx = struct[:, 1] @@ -83,7 +105,16 @@ class GinkaMaskGIT(nn.Module): # VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接 z_mem_vq = self.z_proj(z) # [B, L, d_model] z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model] - z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L * 3 + 4, d_model] + + # 密度条件:三个 embedding 相加后经独立 MLP 得到单个条件 token + e_density = ( + self.door_density_embed(density[:, 0]) + + self.monster_density_embed(density[:, 1]) + + self.resource_density_embed(density[:, 2]) + ) # [B, d_z] + density_token = self.density_mlp(e_density).unsqueeze(1) # [B, 1, d_model] + + z_mem = torch.cat([z_mem_vq, z_mem_struct, density_token], dim=1) # [B, L*3+5, d_model] # tile embedding + 位置编码 row_idx = torch.arange(self.map_h, device=map.device).repeat_interleave(self.map_w) diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index eb6703e..6588756 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -44,24 +44,24 @@ VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook) VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用) VQ_LAYERS = 3 # VQ-VAE Transformer 层数 VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度 -VQ_D_MODEL = 64 # VQ-VAE Transformer 模型维度 -VQ_NHEAD = 8 # VQ-VAE 多头注意力头数 +VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度 +VQ_NHEAD = 4 # VQ-VAE 多头注意力头数 # 第一阶段 MaskGIT 超参 -STAGE1_MG_DMODEL = 192 -STAGE1_MG_NHEAD = 8 +STAGE1_MG_DMODEL = 256 +STAGE1_MG_NHEAD = 4 STAGE1_MG_NUM_LAYERS = 6 STAGE1_MG_DIM_FF = 1024 # 第二阶段 MaskGIT 超参 -STAGE2_MG_DMODEL = 192 -STAGE2_MG_NHEAD = 8 +STAGE2_MG_DMODEL = 256 +STAGE2_MG_NHEAD = 4 STAGE2_MG_NUM_LAYERS = 6 STAGE2_MG_DIM_FF = 1024 # 第三阶段 MaskGIT 超参 -STAGE3_MG_DMODEL = 192 -STAGE3_MG_NHEAD = 8 +STAGE3_MG_DMODEL = 256 +STAGE3_MG_NHEAD = 4 STAGE3_MG_NUM_LAYERS = 6 STAGE3_MG_DIM_FF = 1024 @@ -178,10 +178,18 @@ def random_struct(device: torch.device) -> torch.Tensor: cond_outer = random.randint(0, 1) # 是否有外围走廊 return torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]).unsqueeze(0).to(device) +def random_density(device: torch.device) -> torch.Tensor: + # 随机采样一组密度参量,用于自由生成 + # density_inject 格式:[door_level(0-2), monster_level(0-2), resource_level(0-2)] + door_lv = random.randint(0, 2) + monster_lv = random.randint(0, 2) + resource_lv = random.randint(0, 2) + return torch.LongTensor([door_lv, monster_lv, resource_lv]).unsqueeze(0).to(device) + def maskgit_sample( model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor, - struct: torch.Tensor, steps: int, target_tiles: list[int] | None = None, - keep_fixed: bool = True + struct: torch.Tensor, density: torch.Tensor, steps: int, + target_tiles: list[int] | None = None, keep_fixed: bool = True ) -> np.ndarray: # target_tiles: 本阶段负责生成的图块 ID 列表;None 表示接受所有类别(stage1) # keep_fixed=True:锁定输入中已有的非掩码/非空地位,使上一阶段结构保持不变 @@ -198,7 +206,7 @@ def maskgit_sample( # 迭代去掩码:每步根据置信度分数重新决定掩码位置 for step in range(steps): - logits = model(current, z, struct) + logits = model(current, z, struct, density) probs = F.softmax(logits, dim=-1) dist = torch.distributions.Categorical(probs) @@ -264,7 +272,7 @@ def maskgit_sample( # 目标模式下,未被填充的位置视为空地(不属于本阶段负责的图块) current[0, still_masked] = 0 else: - logits = model(current, z, struct) + logits = model(current, z, struct, density) current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1) return current[0].cpu().numpy().reshape(MAP_H, MAP_W) @@ -272,6 +280,7 @@ def maskgit_sample( def full_generate_random_z( input: torch.Tensor, struct: torch.Tensor, + density: torch.Tensor, models: list[torch.nn.Module], device: torch.device, keep_fixed: tuple[bool, bool, bool] = (True, True, True) @@ -282,13 +291,13 @@ def full_generate_random_z( z = quantizer.sample(1, VQ_L, device) # stage1:生成 floor/wall 骨架 - pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0]) + pred1_np = maskgit_sample(mg1, input.clone(), z, struct, density, GENERATE_STEP, keep_fixed=keep_fixed[0]) inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充 # stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并 pred2_np = maskgit_sample( - mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] + mg2, inp2, z, struct, density, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] ) merged12 = pred1_np.copy() merged12[pred2_np != 0] = pred2_np[pred2_np != 0] @@ -297,7 +306,7 @@ def full_generate_random_z( # stage3:填充 resource(3) pred3_np = maskgit_sample( - mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] + mg3, inp3, z, struct, density, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] ) merged123 = merged12.copy() merged123[pred3_np != 0] = pred3_np[pred3_np != 0] @@ -308,6 +317,7 @@ def full_generate_specific_z( input: torch.Tensor, z: torch.Tensor, struct: torch.Tensor, + density: torch.Tensor, models: list[torch.nn.Module], device: torch.device, keep_fixed: tuple[bool, bool, bool] = (True, True, True) @@ -316,12 +326,12 @@ def full_generate_specific_z( with torch.no_grad(): # 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z - pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0]) + pred1_np = maskgit_sample(mg1, input.clone(), z, struct, density, GENERATE_STEP, keep_fixed=keep_fixed[0]) inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp2[inp2 == 0] = MASK_TOKEN pred2_np = maskgit_sample( - mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] + mg2, inp2, z, struct, density, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] ) merged12 = pred1_np.copy() merged12[pred2_np != 0] = pred2_np[pred2_np != 0] @@ -329,7 +339,7 @@ def full_generate_specific_z( inp3[inp3 == 0] = MASK_TOKEN pred3_np = maskgit_sample( - mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] + mg3, inp3, z, struct, density, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] ) merged123 = merged12.copy() merged123[pred3_np != 0] = pred3_np[pred3_np != 0] @@ -343,6 +353,23 @@ def annotate(img: np.ndarray, text: str) -> np.ndarray: cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) return img +def annotate_labels( + img: np.ndarray, + struct: torch.Tensor, + density: torch.Tensor +) -> np.ndarray: + # 两行标注:第一行结构标签,第二行密度标签 + lv = ['Low', 'Medium', 'High'] + s = struct.tolist() + d = density.tolist() + line1 = f"sym:{s[0]} room:{lv[s[1]]} branch:{lv[s[2]]} outer:{s[3]}" + line2 = f"door:{lv[d[0]]} enemy:{lv[d[1]]} res:{lv[d[2]]}" + img = img.copy() + for text, y in [(line1, 12), (line2, 24)]: + cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2) + cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1) + return img + def rand_keep() -> tuple[bool, bool, bool]: b = random.choice([True, False]) return (b, b, b) @@ -404,23 +431,29 @@ def visualize_part2(batch, z_q, models, device, tile_dict): inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) struct_t = batch["struct_inject"][0:1].to(device) + density_t = batch["density_inject"][0:1].to(device) kf = rand_keep() auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z( - inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf + inp1_t, z_q[0:1], struct_t, density_t, models, device, keep_fixed=kf ) kf_label = 'fix' if kf[0] else 'free' - label1 = f"s1:{kf_label}" - label2 = f"s2:{kf_label}" - label3 = f"s3:{kf_label}" enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W) enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W) + struct_cpu = batch["struct_inject"][0] + density_cpu = batch["density_inject"][0] + rows = [ [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)], - [to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)], + [ + annotate(to_img(inp1_np), kf_label), + annotate_labels(to_img(auto_pred1_np), struct_cpu, density_cpu), + annotate_labels(to_img(auto_merged12), struct_cpu, density_cpu), + annotate_labels(to_img(auto_merged123), struct_cpu, density_cpu) + ], ] grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255 for r, row in enumerate(rows): @@ -442,19 +475,24 @@ def visualize_part3(batch, models, device, tile_dict): inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) struct_ref = batch["struct_inject"][0:1].to(device) + density_ref = batch["density_inject"][0:1].to(device) inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W) + struct_cpu = batch["struct_inject"][0] + density_cpu = batch["density_inject"][0] row1 = [to_img(inp1_np)] for _ in range(2): kf = rand_keep() - _, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf) - row1.append(annotate(to_img(merged123), keep_label(kf))) + _, _, merged123 = full_generate_random_z(inp1_t, struct_ref, density_ref, models, device, keep_fixed=kf) + row1.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu)) row2 = [] for _ in range(3): kf = rand_keep() - _, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf) - row2.append(annotate(to_img(merged123), keep_label(kf))) + rnd_struct = random_struct(device) + rnd_density = random_density(device) + _, _, merged123 = full_generate_random_z(inp1_t, rnd_struct, rnd_density, models, device, keep_fixed=kf) + row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_density[0].cpu())) rows = [row1, row2] grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 @@ -484,8 +522,10 @@ def visualize_part4(models, device, tile_dict): results = [] for _ in range(5): kf = rand_keep() - _, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf) - results.append(annotate(to_img(merged123), keep_label(kf))) + rnd_struct = random_struct(device) + rnd_density = random_density(device) + _, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_density, models, device, keep_fixed=kf) + results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_density[0].cpu())) row1 = [to_img(seed_np)] + results[:2] row2 = results[2:] @@ -507,6 +547,74 @@ def visualize_validate( cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict)) cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict)) cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict)) + cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict)) + +# 密度对照图:随机种子+随机结构,5 张随机密度生成,2×3 网格(左上角为种子图) +def visualize_density_cmp(models, device, tile_dict): + SEP = 3 + TILE_SIZE = 32 + img_h = MAP_H * TILE_SIZE + img_w = MAP_W * TILE_SIZE + + def to_img(mat): + return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) + + n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06)) + seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) + wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls] + seed[0, wall_pos] = 1 + seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W) + rnd_struct = random_struct(device) + struct_cpu = rnd_struct[0].cpu() + gen_imgs = [] + for _ in range(5): + rnd_density = random_density(device) + density_cpu = rnd_density[0].cpu() + _, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_density, models, device) + gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu)) + row1 = [to_img(seed_np)] + gen_imgs[:2] + row2 = gen_imgs[2:] + rows = [row1, row2] + grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 + for r, row in enumerate(rows): + for c, img in enumerate(row): + y = SEP + r * (img_h + SEP) + x = SEP + c * (img_w + SEP) + grid[y:y + img_h, x:x + img_w] = img + return grid + +# 固定 z 和结构条件,使用 5 个随机密度各生成一次,2×3 网格(左上角为参考地图) +def visualize_density_var(batch, z_q, models, device, tile_dict): + SEP = 3 + TILE_SIZE = 32 + img_h = MAP_H * TILE_SIZE + img_w = MAP_W * TILE_SIZE + + def to_img(mat): + return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) + + inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) + struct_t = batch["struct_inject"][0:1].to(device) + struct_cpu = batch["struct_inject"][0] + ref_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) + gen_imgs = [] + for _ in range(5): + rnd_density = random_density(device) + density_cpu = rnd_density[0].cpu() + _, _, merged123 = full_generate_specific_z( + inp1_t, z_q[0:1], struct_t, rnd_density, models, device + ) + gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu)) + row1 = [to_img(ref_np)] + gen_imgs[:2] + row2 = gen_imgs[2:] + rows = [row1, row2] + grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 + for r, row in enumerate(rows): + for c, img in enumerate(row): + y = SEP + r * (img_h + SEP) + x = SEP + c * (img_w + SEP) + grid[y:y + img_h, x:x + img_w] = img + return grid def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int): vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models @@ -521,10 +629,21 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc loss3_total = torch.Tensor([0]).to(device) commit_total = torch.Tensor([0]).to(device) + # 按档位(0/1/2)累计实体计数差(L1),用于诊断密度条件可控性 + # 结构:{tile_id: {level: [累计误差, 样本数]}} + density_l1 = { + 2: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # door + 4: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # monster + 3: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # resource + } + # 三类实体对应的 density_inject 索引 + tile_density_idx = {2: 0, 4: 1, 3: 2} + idx = 0 with torch.no_grad(): for batch in tqdm(dataloader, leave=False, desc="Validate Progress", disable=disable_tqdm): + # 三阶段各自的掩码输入、预测目标和 VQ 编码器输入 inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE) target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE) @@ -539,6 +658,7 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE) struct = batch["struct_inject"].to(device) + density = batch["density_inject"].to(device) # VQ 编码:各阶段独立编码后拼接、量化 z_e1 = vq1(enc1) # [B, L, d_z] @@ -548,24 +668,56 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z] z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z] - # 三阶段 MaskGIT 推理(均以完整 z_q 和 struct 为条件) - logits1 = mg1(inp1, z_q, struct) - logits2 = mg2(inp2, z_q, struct) - logits3 = mg3(inp3, z_q, struct) + # 三阶段 MaskGIT 推理(均以完整 z_q、struct 和 density 为条件) + logits1 = mg1(inp1, z_q, struct, density) + logits2 = mg2(inp2, z_q, struct, density) + logits3 = mg3(inp3, z_q, struct, density) loss1_total += focal_loss(logits1, target1) loss2_total += focal_loss(logits2, target2) loss3_total += focal_loss(logits3, target3) commit_total += commit_loss + # 计算 argmax 预测并统计各档位密度 L1(预测计数与真实计数之差的绝对值) + pred2_map = torch.argmax(logits2, dim=-1).cpu() # [B, MAP_SIZE] + pred3_map = torch.argmax(logits3, dim=-1).cpu() + true2_map = target2.cpu() # [B, MAP_SIZE] + true3_map = target3.cpu() + density_cpu = batch["density_inject"] # [B, 3] + for b in range(pred2_map.size(0)): + for tile_id, d_idx in tile_density_idx.items(): + if tile_id == 3: + pred_map = pred3_map[b] + true_map = true3_map[b] + else: + pred_map = pred2_map[b] + true_map = true2_map[b] + pred_count = float((pred_map == tile_id).sum().item()) + true_count = float((true_map == tile_id).sum().item()) + lv = int(density_cpu[b, d_idx].item()) + density_l1[tile_id][lv][0] += abs(pred_count - true_count) + density_l1[tile_id][lv][1] += 1 + # 每个 batch 生成三种可视化图(val/full/rand) visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx) idx += 1 - # 每个 epoch 额外生成一张无条件自由生成图(不依赖任何 batch 样本) + # 输出密度 L1 统计(各档位的平均实体计数,供诊断密度条件效果) + lv_names = ['Low', 'Medium', 'High'] + tile_names = {2: 'door', 4: 'enemy', 3: 'resource'} + for tile_id in [2, 4, 3]: + parts = [] + for lv in range(3): + acc, cnt = density_l1[tile_id][lv] + avg = acc / cnt if cnt > 0 else 0.0 + parts.append(f"{lv_names[lv]}={avg:.2f}") + tqdm.write(f" density {tile_names[tile_id]}: {' '.join(parts)}") + save_dir = f"result/seperated/e{epoch}" os.makedirs(save_dir, exist_ok=True) + # 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图 cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict)) + cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict)) # 恢复训练模式 for m in [vq1, vq2, vq3, mg1, mg2, mg3]: @@ -659,6 +811,7 @@ def train(device: torch.device): # 结构条件向量:[cond_sym, cond_room, cond_branch, cond_outer] struct = batch["struct_inject"].to(device) + density = batch["density_inject"].to(device) optimizer.zero_grad() @@ -671,10 +824,10 @@ def train(device: torch.device): z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z] z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z] - # 三阶段 MaskGIT 前向(均接收完整三阶段 z_q) - logits1 = mg1(inp1, z_q, struct) - logits2 = mg2(inp2, z_q, struct) - logits3 = mg3(inp3, z_q, struct) + # 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和 density 条件) + logits1 = mg1(inp1, z_q, struct, density) + logits2 = mg2(inp2, z_q, struct, density) + logits3 = mg3(inp3, z_q, struct, density) # 三阶段 Focal Loss + VQ commit loss 加权求和 loss1 = focal_loss(logits1, target1)