From 98e17342480279c00ed978f202545db56eabfeef Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 16 May 2026 14:16:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AF=86=E5=BA=A6=E6=A0=87=E7=AD=BE?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=E5=8F=82=E6=95=B0=E9=87=8F=EF=BC=8C=E5=8E=BB?= =?UTF-8?q?=E9=99=A4=E6=88=BF=E9=97=B4=E6=95=B0=E5=92=8C=E5=88=86=E6=94=AF?= =?UTF-8?q?=E6=95=B0=E6=A0=87=E7=AD=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/cond-simplify-design.md | 311 +++++++++++++++++++++++++++++++++++ ginka/dataset.py | 71 ++++---- ginka/maskGIT/model.py | 64 +++---- ginka/train_seperated.py | 51 +++--- 4 files changed, 387 insertions(+), 110 deletions(-) create mode 100644 docs/cond-simplify-design.md diff --git a/docs/cond-simplify-design.md b/docs/cond-simplify-design.md new file mode 100644 index 0000000..80aa05e --- /dev/null +++ b/docs/cond-simplify-design.md @@ -0,0 +1,311 @@ +# 条件简化与密度连续化设计文档 + +## 背景 + +当前三阶段级联生成模型的条件系统存在以下问题: + +1. **结构条件中的房间数和分支数对生成指导意义有限**:这两个指标依赖数据集中预计算的离散分档,与实际生成质量的相关性较弱,且分档边界处噪声大,容易引入无效条件信号。 + +2. 
**实体密度条件(门/怪物/资源)的离散三档存在明显一对多问题**:三档划分过于粗糙,同一档内样本分布差异极大(例如 Medium 档中资源数可以从 2 到 8 不等),导致模型无法建立条件与生成结果之间的精确映射。连续值能够更精确地描述目标密度,避免档位内分布散乱导致的条件信号模糊。 + +## 改动总览 + +| 模块 | 改动类型 | 说明 | +| -------------------------- | -------- | -------------------------------------------------------------- | +| `ginka/dataset.py` | 修改 | 删除房间/分支分档;密度改为连续归一化;输出 FloatTensor | +| `ginka/maskGIT/model.py` | 修改 | 删除房间/分支嵌入;密度嵌入层改为线性投影;更新 cond_proj 维度 | +| `ginka/train_seperated.py` | 修改 | 更新 random_struct/random_density;更新 annotate_labels | + +--- + +## 一、条件向量格式变更 + +### 1.1 struct_inject + +**当前格式**(4 个离散整数): + +``` +[cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)] +``` + +**新格式**(2 个离散整数,删除 room 和 branch): + +``` +[cond_sym(0-7), cond_outer(0-1)] +``` + +`cond_sym` 的计算方式不变(水平/垂直/中心对称的三位二进制组合,0–7),`cond_outer` 不变。 + +### 1.2 density_inject + +**当前格式**(3 个离散整数,`LongTensor`): + +``` +[door_level(0-2), monster_level(0-2), resource_level(0-2)] +``` + +**新格式**(3 个连续浮点数,`FloatTensor`,值域 [0, 1]): + +``` +[door_norm, monster_norm, resource_norm] ∈ [0.0, 1.0]^3 +``` + +--- + +## 二、密度归一化方案 + +### 2.1 统计量定义 + +在训练集初始化阶段,对原始地图统计三类图块的实际数量(非密度,直接计数): + +- `door_count = 图块ID为2的数量` +- `monster_count = 图块ID为4的数量` +- `resource_count = 图块ID为3的数量` + +对每类分别求训练集内的 **最小值** 和 **最大值**: + +``` +door_min, door_max +monster_min, monster_max +resource_min, resource_max +``` + +### 2.2 归一化公式 + +对每个样本的 count,归一化为 [0, 1]: + +$$ +\text{norm}(x) = \frac{x - x_{\min}}{x_{\max} - x_{\min} + \epsilon} +$$ + +其中 $\epsilon = 1\text{e-}6$,防止分母为零(当所有样本计数相同时)。 + +结果裁剪到 [0, 1]:`norm = clamp(norm, 0.0, 1.0)`。 + +### 2.3 验证集复用训练集统计量 + +`GinkaSeperatedDataset` 新增参数: + +```python +def __init__( + self, + data_path: str, + subset_weights: tuple = (0.5, 0.3, 0.2), + density_stats: dict | None = None # 新增:外部传入统计量 +): +``` + +- 训练集:`density_stats=None`,自行计算并保存 `min/max` 到 `self.density_stats` +- 验证集:传入训练集的 `self.density_stats`,直接复用,保证归一化语义一致 + +`density_stats` 的结构: + +```python +{ + "door_min": float, "door_max": float, + 
"monster_min": float, "monster_max": float, + "resource_min": float, "resource_max": float, +} +``` + +### 2.4 输出字段变更 + +`__getitem__` 中 `density_inject` 由 `LongTensor` 改为 `FloatTensor`: + +```python +# 删除旧的离散分档逻辑 +density_inject = torch.FloatTensor([ + self.norm_density(count_door, "door"), + self.norm_density(count_monster, "monster"), + self.norm_density(count_resource, "resource"), +]) +``` + +删除以下字段(不再写入 item 也不再输出): + +- `doorDensityLevel`, `monsterDensityLevel`, `resourceDensityLevel` +- `roomCountLevel`, `branchLevel` + +删除以下实例变量: + +- `self.room_th`, `self.branch_th` +- `self.door_density_th`, `self.monster_density_th`, `self.resource_density_th` + +--- + +## 三、模型结构变更(`model.py`) + +### 3.1 删除房间/分支嵌入 + +删除: + +```python +self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) +self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z) +``` + +保留: + +```python +self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) # SYM_VOCAB = 8 +self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) # OUTER_VOCAB = 2 +``` + +删除的常量:`ROOM_VOCAB`, `BRANCH_VOCAB`,保留 `SYM_VOCAB`, `OUTER_VOCAB`。 + +### 3.2 密度嵌入层改为线性投影 + +删除: + +```python +self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z) +self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z) +self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z) +``` + +删除的常量:`DOOR_DENSITY_VOCAB`, `MONSTER_DENSITY_VOCAB`, `RESOURCE_DENSITY_VOCAB`。 + +新增: + +```python +# 连续密度投影:将 3 个归一化浮点数映射为 1 个 d_z 维 token +self.density_proj = nn.Linear(3, d_z) +``` + +### 3.3 cond_proj 维度更新 + +**当前 cond_seq 形状**:`[B, z_seq_len + 4_struct + 3_density, d_z]`,即 `[B, z_seq_len+7, d_z]`,展平后输入维度 `(z_seq_len+7) * d_z`。 + +**新 cond_seq 形状**:`[B, z_seq_len + 2_struct + 1_density, d_z]`,即 `[B, z_seq_len+3, d_z]`,展平后输入维度 `(z_seq_len+3) * d_z`。 + +```python +# 旧 +self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model) +# 新 +self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model) +``` + +### 3.4 forward 流程变更 + +```python +def forward( + self, + 
map: torch.Tensor, + z: torch.Tensor, + struct: torch.Tensor, # [B, 2] ← 由 [B, 4] 改为 [B, 2] + density: torch.Tensor # [B, 3] float ← 由 [B, 3] long 改为 float +) -> torch.Tensor: + + # 结构标签:sym + outer,各嵌入为 d_z 维 token + e_struct = torch.stack([ + self.sym_embed(struct[:, 0]), # [B, d_z] + self.outer_embed(struct[:, 1]), # [B, d_z] + ], dim=1) # [B, 2, d_z] + + # 密度:连续值投影为单个 d_z 维 token + e_density = self.density_proj(density).unsqueeze(1) # [B, 1, d_z] + + # z:逐 token 投影(不变) + z_proj = self.z_proj(z) # [B, z_seq_len, d_z] + + # 拼接 → [B, z_seq_len+3, d_z] → 展平 → 投影到 d_model + cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1) + c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1)) # [B, d_model] + + # 后续不变(tile embedding + Transformer + output_fc) +``` + +--- + +## 四、训练脚本变更(`train_seperated.py`) + +### 4.1 random_struct + +```python +def random_struct(device: torch.device) -> torch.Tensor: + # struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)] + cond_sym = random.randint(0, 7) # 地图对称类型 + cond_outer = random.randint(0, 1) # 是否有外围走廊 + return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device) +``` + +### 4.2 random_density + +```python +def random_density(device: torch.device) -> torch.Tensor: + # density_inject 格式:[door_norm, monster_norm, resource_norm] ∈ [0, 1] + return torch.rand(1, 3, device=device) # 均匀分布随机采样 +``` + +### 4.3 annotate_labels + +更新标注格式,删除 room/branch,密度显示为两位小数: + +```python +def annotate_labels( + img: np.ndarray, + struct: torch.Tensor, # [2] long + density: torch.Tensor # [3] float +) -> np.ndarray: + s = struct.tolist() + d = density.tolist() + line1 = f"sym:{s[0]} outer:{s[1]}" + line2 = f"door:{d[0]:.2f} enemy:{d[1]:.2f} res:{d[2]:.2f}" + img = img.copy() + for text, y in [(line1, 12), (line2, 24)]: + cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2) + cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1) + return img +``` + +### 4.4 训练集与验证集初始化 + +```python 
+train_dataset = GinkaSeperatedDataset(args.train, subset_weights=SUBSET_WEIGHTS) +val_dataset = GinkaSeperatedDataset( + args.validate, + subset_weights=SUBSET_WEIGHTS, + density_stats=train_dataset.density_stats # 复用训练集统计量 +) +``` + +### 4.5 DataLoader collate_fn(FloatTensor 适配) + +PyTorch 默认 collate 会自动将 FloatTensor 列表合并为 float 类型批张量,无需额外修改 DataLoader 配置。 + +### 4.6 验证阶段密度对照图(density_var) + +`visualize_density_var` 内对比不同密度条件时,改为使用 5 个等间隔的固定取值(非随机采样): + +```python +# 旧(三档枚举):density_levels = [0, 1, 2] +# 新(连续采样):5 个均匀间隔值 +density_values = [0.0, 0.25, 0.5, 0.75, 1.0] +for v in density_values: + d = torch.FloatTensor([[v, v, v]]).to(device) # 三类等密度扫描 + ... +``` + +--- + +## 五、不需要改动的部分 + +- `ginka/maskGIT/maskGIT.py`:AdaLN / CondTransformerLayer / Transformer 均不感知条件维度,无需修改 +- `ginka/vqvae/` 目录:VQ-VAE 部分与条件系统无关 +- `ginka/train_seperated.py` 中的 `maskgit_sample`、`full_generate_random_z`、`full_generate_specific_z`:接口签名不变(仍接受 struct/density 张量),内部无直接操作条件内容,无需修改 +- `data/` 目录的 TypeScript 数据处理脚本:数据文件格式不变,Python 端自行计算标签 + +--- + +## 六、旧 checkpoint 兼容性 + +由于 `cond_proj` 输入维度和嵌入层数量均发生变化,**旧 checkpoint 不兼容**,需从头训练。 + +--- + +## 七、实施顺序 + +1. 修改 `ginka/dataset.py`:删除 room/branch 分档,新增密度归一化和 `density_stats` 参数 +2. 修改 `ginka/maskGIT/model.py`:删除多余嵌入,新增 `density_proj`,更新 `cond_proj` 维度和 `forward` +3. 修改 `ginka/train_seperated.py`:更新 `random_struct`、`random_density`、`annotate_labels`、数据集初始化 +4.
运行小规模过拟合测试(单 batch 跑 50 步)验证前向通路无误 diff --git a/ginka/dataset.py b/ginka/dataset.py index fa74a20..dcc8859 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -33,49 +33,35 @@ class GinkaSeperatedDataset(Dataset): def __init__( self, data_path: str, - subset_weights: tuple = (0.5, 0.3, 0.2) + subset_weights: tuple = (0.5, 0.3, 0.2), + density_stats: dict | None = None ): self.data = load_data(data_path) total = sum(subset_weights) self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))] - n = len(self.data) - rs = sorted(item['roomCount'] for item in self.data) - bs = sorted(item['highDegBranchCount'] for item in self.data) - th1_r, th2_r = rs[n // 3], rs[2 * n // 3] - th1_b, th2_b = bs[n // 3], bs[2 * n // 3] - if th1_r == th2_r: th2_r = th1_r + 1 - if th1_b == th2_b: th2_b = th1_b + 1 - self.room_th = (th1_r, th2_r) - self.branch_th = (th1_b, th2_b) + # 实体密度连续归一化:统计门/怪物/资源的数量,用 min/max 归一化到 [0, 1] + # density_stats 为 None 时自行计算(训练集),否则复用外部传入的统计量(验证集) + if density_stats is None: + door_counts = [self.count_tile(item['map'], self.DOOR) for item in self.data] + monster_counts = [self.count_tile(item['map'], self.MONSTER) for item in self.data] + resource_counts = [self.count_tile(item['map'], self.RESOURCE) for item in self.data] + self.density_stats = { + "door_min": float(min(door_counts)), + "door_max": float(max(door_counts)), + "monster_min": float(min(monster_counts)), + "monster_max": float(max(monster_counts)), + "resource_min": float(min(resource_counts)), + "resource_max": float(max(resource_counts)), + } + else: + self.density_stats = density_stats - for item in self.data: - item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th) - item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th) - - # 实体密度等级:统计原始地图中门/怪物/资源的数量,等频三档 + def norm_density(self, count: int, key: str) -> float: eps = 1e-6 - door_counts = sorted(self.count_tile(item['map'], self.DOOR) for item in self.data) - 
monster_counts = sorted(self.count_tile(item['map'], self.MONSTER) for item in self.data) - resource_counts = sorted(self.count_tile(item['map'], self.RESOURCE) for item in self.data) - th1_d, th2_d = door_counts[n // 3], door_counts[2 * n // 3] - th1_m, th2_m = monster_counts[n // 3], monster_counts[2 * n // 3] - th1_rc, th2_rc = resource_counts[n // 3], resource_counts[2 * n // 3] - if th1_d == th2_d: th2_d = th1_d + eps - if th1_m == th2_m: th2_m = th1_m + eps - if th1_rc == th2_rc: th2_rc = th1_rc + eps - self.door_density_th = (th1_d, th2_d) - self.monster_density_th = (th1_m, th2_m) - self.resource_density_th = (th1_rc, th2_rc) - - for item in self.data: - m = item['map'] - item['doorDensityLevel'] = self.to_level(self.count_tile(m, self.DOOR), self.door_density_th) - item['monsterDensityLevel'] = self.to_level(self.count_tile(m, self.MONSTER), self.monster_density_th) - item['resourceDensityLevel'] = self.to_level(self.count_tile(m, self.RESOURCE), self.resource_density_th) - - def to_level(self, v, th): - return 0 if v < th[0] else (1 if v < th[1] else 2) + lo = self.density_stats[f"{key}_min"] + hi = self.density_stats[f"{key}_max"] + return float(min(max((count - lo) / (hi - lo + eps), 0.0), 1.0)) def count_tile(self, map_data: list, tile_id: int) -> int: return sum(cell == tile_id for row in map_data for cell in row) @@ -193,15 +179,14 @@ class GinkaSeperatedDataset(Dataset): sym_h, sym_v, sym_c = compute_symmetry(map_np) cond_sym = sym_h * 4 + sym_v * 2 + sym_c - cond_room = item['roomCountLevel'] - cond_branch = item['branchLevel'] cond_outer = item['outerWall'] - struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + struct_inject = torch.LongTensor([cond_sym, cond_outer]) - density_inject = torch.LongTensor([ - item['doorDensityLevel'], - item['monsterDensityLevel'], - item['resourceDensityLevel'] + m = item['map'] + density_inject = torch.FloatTensor([ + self.norm_density(self.count_tile(m, self.DOOR), "door"), + 
self.norm_density(self.count_tile(m, self.MONSTER), "monster"), + self.norm_density(self.count_tile(m, self.RESOURCE), "resource"), ]) return { diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 2f5764c..411a198 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -6,15 +6,8 @@ from .maskGIT import Transformer # 结构标签词表大小 SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-7 -ROOM_VOCAB = 3 # roomCountLevel 0-2 -BRANCH_VOCAB = 3 # branchLevel 0-2 OUTER_VOCAB = 2 # outerWall 0-1 -# 密度标签词表大小(Low/Medium/High 三档) -DOOR_DENSITY_VOCAB = 3 -MONSTER_DENSITY_VOCAB = 3 -RESOURCE_DENSITY_VOCAB = 3 - class GinkaMaskGIT(nn.Module): def __init__( self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512, @@ -30,23 +23,18 @@ class GinkaMaskGIT(nn.Module): self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02) self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02) - # 结构标签嵌入:各自独立嵌入到 d_z 维度,作为独立 token + # 结构标签嵌入:sym(0-7)和 outer(0-1),各自独立嵌入到 d_z 维度 self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) - self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) - self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z) self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) - # 密度标签嵌入:各自独立嵌入到 d_z 维度,作为独立 token - self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z) - self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z) - self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z) + # 密度连续投影:将 3 个归一化浮点数 [door, monster, resource] ∈ [0,1] 投影为 d_z 维 token + self.density_proj = nn.Linear(3, d_z) # z 投影:逐 token 线性变换,保持序列结构 self.z_proj = nn.Linear(d_z, d_z) - # 条件融合投影:将 (z_seq_len + 4 + 3) 个 d_z 维 token 拼接后降维到 d_model - # 拼接顺序:z_seq_len 个 z token + 4 个结构 token + 3 个密度 token - self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model) + # 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个密度 token + self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model) # 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层 self.transformer = 
Transformer( @@ -64,28 +52,22 @@ class GinkaMaskGIT(nn.Module): ) -> torch.Tensor: # map: [B, H * W] # z: [B, z_seq_len, d_z] - # struct: [B, 4] - # density: [B, 3] — [door_level, monster_level, resource_level] + # struct: [B, 2] — [cond_sym(0-7), cond_outer(0-1)] + # density: [B, 3] float — [door_norm, monster_norm, resource_norm] ∈ [0, 1] - # 结构标签:各自嵌入为独立 token,stack 成序列 [B, 4, d_z] + # 结构标签:sym + outer,各自嵌入为独立 token,stack 成序列 [B, 2, d_z] e_struct = torch.stack([ self.sym_embed(struct[:, 0]), - self.room_embed(struct[:, 1]), - self.branch_embed(struct[:, 2]), - self.outer_embed(struct[:, 3]) + self.outer_embed(struct[:, 1]) ], dim=1) - # 密度标签:各自嵌入为独立 token,stack 成序列 [B, 3, d_z] - e_density = torch.stack([ - self.door_density_embed(density[:, 0]), - self.monster_density_embed(density[:, 1]), - self.resource_density_embed(density[:, 2]) - ], dim=1) + # 密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z] + e_density = self.density_proj(density).unsqueeze(1) # z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z] z_proj = self.z_proj(z) - # 拼接所有条件 token → [B, z_seq_len+7, d_z],展平后投影到 d_model + # 拼接所有条件 token → [B, z_seq_len+3, d_z],展平后投影到 d_model cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1) c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1)) # [B, d_model] @@ -107,17 +89,17 @@ if __name__ == "__main__": map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [4, 169] z_input = torch.randn(4, 6, 64).to(device) # [4, L*3, 64] struct_input = torch.tensor([ - [3, 1, 0, 1], - [0, 2, 1, 0], - [5, 1, 2, 1], - [1, 0, 1, 0], - ], dtype=torch.long).to(device) # [4, 4] + [3, 1], + [0, 0], + [5, 1], + [1, 0], + ], dtype=torch.long).to(device) # [4, 2] — [cond_sym, cond_outer] density_input = torch.tensor([ - [0, 1, 2], - [2, 0, 1], - [1, 2, 0], - [0, 0, 1], - ], dtype=torch.long).to(device) # [4, 3] + [0.1, 0.5, 0.9], + [0.8, 0.2, 0.4], + [0.3, 0.7, 0.0], + [0.6, 0.1, 1.0], + ], dtype=torch.float).to(device) # [4, 3] — [door_norm, monster_norm, resource_norm] model = GinkaMaskGIT( 
num_classes=7, @@ -142,8 +124,8 @@ if __name__ == "__main__": print(f"推理耗时: {end - start:.4f}s") print(f"输出形状: logits={logits.shape}") print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") - print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}") print(f"Density Projection parameters: {sum(p.numel() for p in model.density_proj.parameters())}") + print(f"Cond Projection parameters: {sum(p.numel() for p in model.cond_proj.parameters())}") print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}") print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}") print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}") diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index 4231db2..e5f5827 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -174,20 +174,15 @@ def focal_loss(logits, target): def random_struct(device: torch.device) -> torch.Tensor: # 随机采样一组结构参量,用于无条件自由生成 - # struct_inject 格式:[cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)] - cond_sym = random.randint(0, 7) # 地图对称类型 - cond_room = random.randint(0, 2) # 房间数量档位 - cond_branch = random.randint(0, 2) # 分支复杂度档位 - cond_outer = random.randint(0, 1) # 是否有外围走廊 - return torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]).unsqueeze(0).to(device) + # struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)] + cond_sym = random.randint(0, 7) # 地图对称类型 + cond_outer = random.randint(0, 1) # 是否有外围走廊 + return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device) def random_density(device: torch.device) -> torch.Tensor: # 随机采样一组密度参量,用于自由生成 - # density_inject 格式:[door_level(0-2), monster_level(0-2), resource_level(0-2)] - door_lv = random.randint(0, 2) - monster_lv = random.randint(0, 2) - resource_lv = random.randint(0, 2) - return torch.LongTensor([door_lv, monster_lv,
resource_lv]).unsqueeze(0).to(device) + # density_inject 格式:[door_norm, monster_norm, resource_norm] ∈ [0, 1] + return torch.rand(1, 3, device=device) def maskgit_sample( model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor, @@ -361,12 +356,11 @@ def annotate_labels( struct: torch.Tensor, density: torch.Tensor ) -> np.ndarray: - # 两行标注:第一行结构标签,第二行密度标签 - lv = ['Low', 'Medium', 'High'] + # 两行标注:第一行结构标签,第二行连续密度值(两位小数) s = struct.tolist() d = density.tolist() - line1 = f"sym:{s[0]} room:{lv[s[1]]} branch:{lv[s[2]]} outer:{s[3]}" - line2 = f"door:{lv[d[0]]} enemy:{lv[d[1]]} res:{lv[d[2]]}" + line1 = f"sym:{s[0]} outer:{s[1]}" + line2 = f"door:{d[0]:.2f} enemy:{d[1]:.2f} res:{d[2]:.2f}" img = img.copy() for text, y in [(line1, 12), (line2, 24)]: cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2) @@ -601,11 +595,13 @@ def visualize_density_var(batch, z_q, models, device, tile_dict): struct_cpu = batch["struct_inject"][0] ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) gen_imgs = [] - for _ in range(5): - rnd_density = random_density(device) - density_cpu = rnd_density[0].cpu() + # 固定 z 和结构条件,展开 5 个均匀密度水平对比不同密度条件的生成效果 + density_values = [0.0, 0.25, 0.5, 0.75, 1.0] + for v in density_values: + fixed_density = torch.FloatTensor([[v, v, v]]).to(device) + density_cpu = fixed_density[0].cpu() _, _, merged123 = full_generate_specific_z( - inp1_t, z_q[0:1], struct_t, rnd_density, models, device + inp1_t, z_q[0:1], struct_t, fixed_density, models, device ) gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu)) row1 = [to_img(ref_np)] + gen_imgs[:2] @@ -632,8 +628,9 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc loss3_total = torch.Tensor([0]).to(device) commit_total = torch.Tensor([0]).to(device) - # 按档位(0/1/2)累计实体计数差(L1),用于诊断密度条件可控性 - # 结构:{tile_id: {level: [累计误差, 样本数]}} + # 按分桶(Low/Medium/High)累计实体计数差(L1),用于诊断密度条件可控性 + # 连续密度值按 [0,1/3)/[1/3,2/3)/[2/3,1] 分桶到三档 + #
结构:{tile_id: {bucket: [累计误差, 样本数]}} density_l1 = { 2: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # door 4: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # monster @@ -697,7 +694,8 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc true_map = true2_map[b] pred_count = float((pred_map == tile_id).sum().item()) true_count = float((true_map == tile_id).sum().item()) - lv = int(density_cpu[b, d_idx].item()) + # 连续密度按 [0,1/3)/[1/3,2/3)/[2/3,1] 分桶到三档 + lv = min(int(density_cpu[b, d_idx].item() * 3), 2) density_l1[tile_id][lv][0] += abs(pred_count - true_count) density_l1[tile_id][lv][1] += 1 @@ -705,15 +703,15 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx) idx += 1 - # 输出密度 L1 统计(各档位的平均实体计数,供诊断密度条件效果) - lv_names = ['Low', 'Medium', 'High'] + # 输出密度 L1 统计(各分桶内的平均实体计数差,供诊断密度条件效果) + bucket_names = ['Low', 'Medium', 'High'] tile_names = {2: 'door', 4: 'enemy', 3: 'resource'} for tile_id in [2, 4, 3]: parts = [] for lv in range(3): acc, cnt = density_l1[tile_id][lv] avg = acc / cnt if cnt > 0 else 0.0 - parts.append(f"{lv_names[lv]}={avg:.2f}") + parts.append(f"{bucket_names[lv]}={avg:.2f}") tqdm.write(f" density {tile_names[tile_id]}: {' '.join(parts)}") save_dir = f"result/seperated/e{epoch}" @@ -777,7 +775,8 @@ def train(device: torch.device): ) dataset_val = GinkaSeperatedDataset( - args.validate, subset_weights=SUBSET_WEIGHTS + args.validate, subset_weights=SUBSET_WEIGHTS, + density_stats=dataset.density_stats # 复用训练集统计量,保证归一化语义一致 ) dataloader_val = DataLoader( dataset_val, batch_size=min(BATCH_SIZE, len(dataset_val) // 8), shuffle=True