From abbad781abdeacd33fa47499c7db1b3c80181595 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 27 Apr 2026 14:56:21 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=B0=91=E6=95=B0?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E6=80=A7=E6=A0=87=E7=AD=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- data/src/auto.ts | 10 +- data/src/auto/info.ts | 195 +++++++++++- data/src/auto/types.ts | 14 + data/src/types.ts | 9 + docs/dataset-labels-design.md | 248 +++++++++++++++ docs/dataset-labels-impl.md | 563 ++++++++++++++++++++++++++++++++++ docs/vqvae-maskgit-design.md | 2 +- ginka/dataset.py | 69 ++++- ginka/maskGIT/model.py | 108 +++++-- ginka/train_vq.py | 44 +-- 10 files changed, 1212 insertions(+), 50 deletions(-) create mode 100644 docs/dataset-labels-design.md create mode 100644 docs/dataset-labels-impl.md diff --git a/data/src/auto.ts b/data/src/auto.ts index bae5855..515fe7c 100644 --- a/data/src/auto.ts +++ b/data/src/auto.ts @@ -361,7 +361,15 @@ const labelConfig: IAutoLabelConfig = { 0, 0, 0 - ] + ], + symmetry: [ + info.symmetryH ? 1 : 0, + info.symmetryV ? 1 : 0, + info.symmetryC ? 1 : 0 + ], + outerWall: info.outerWall ? 1 : 0, + roomCount: info.roomCount, + highDegBranchCount: info.highDegBranchCount }; dataset.data[id] = data; }); diff --git a/data/src/auto/info.ts b/data/src/auto/info.ts index 61c06ac..9f0796f 100644 --- a/data/src/auto/info.ts +++ b/data/src/auto/info.ts @@ -3,8 +3,10 @@ import { GraphNodeType, IAutoLabelConfig, IFloorInfo, + IMapGraph, IMapTileConverter, ITowerInfo, + MapGraphNode, TowerColor } from './types'; import { @@ -153,6 +155,180 @@ export function computeWallDensityStd( return Math.sqrt(variance); } +/** + * 计算地图的三种对称性(基于 convertedMap,完全匹配才标记为 true) + */ +function computeSymmetry(map: number[][]): { + symmetryH: boolean; + symmetryV: boolean; + symmetryC: boolean; +} { + const H = map.length; + const W = H > 0 ? map[0].length : 0; + let symmetryH = true; + let symmetryV = true; + let symmetryC = true; + + outer: for (let y = 0; y < H; y++) { + for (let x = 0; x < W; x++) { + const tile = map[y][x]; + if (symmetryH && tile !== map[y][W - 1 - x]) symmetryH = false; + if (symmetryV && tile !== map[H - 1 - y][x]) symmetryV = false; + if (symmetryC && tile !== map[H - 1 - y][W - 1 - x]) + symmetryC = false; + if (!symmetryH && !symmetryV && !symmetryC) break outer; + } + } + return { symmetryH, symmetryV, symmetryC }; +} + +/** + * 检测地图最外圈中墙壁 + 入口的占比是否超过 90% + * @param map convertedMap + * @param wall 墙壁图块编号 + * @param entry 入口图块编号 + */ +function computeOuterWall( + map: number[][], + wall: number, + entry: number +): boolean { + const H = map.length; + const W = H > 0 ? map[0].length : 0; + if (H < 2 || W < 2) return false; + + let borderCount = 0; + let wallOrEntry = 0; + + const check = (tile: number) => { + borderCount++; + if (tile === wall || tile === entry) wallOrEntry++; + }; + + for (let x = 0; x < W; x++) { + check(map[0][x]); + check(map[H - 1][x]); + } + for (let y = 1; y < H - 1; y++) { + check(map[y][0]); + check(map[y][W - 1]); + } + + return borderCount > 0 && wallOrEntry / borderCount > 0.9; +} + +/** + * 统计拓扑图中符合"房间"定义的连通区域数量。 + * + * 算法: + * 1. 以 Empty / Resource 节点为顶点,在它们之间 BFS, + * 得到若干"候选区域"(Branch 节点作为边界,不被合并)。 + * 2. 对每个候选区域检查三个条件: + * a. 区域内至少一个节点有 Branch 类型邻居 + * b. 区域内所有格子总数 >= 4 + * c. 所有格子的外接矩形宽 > 1 且高 > 1 + * + * @param graph 拓扑图 + * @param width 地图宽度(用于平坦坐标解码) + */ +function computeRoomCount(graph: IMapGraph, width: number): number { + const allEmptyResource = new Set(); + for (const node of graph.nodeMap.values()) { + if ( + node.type === GraphNodeType.Empty || + node.type === GraphNodeType.Resource + ) { + allEmptyResource.add(node); + } + } + + let roomCount = 0; + const visited = new Set(); + + for (const startNode of allEmptyResource) { + if (visited.has(startNode)) continue; + + const regionNodes = new Set(); + const queue: MapGraphNode[] = [startNode]; + visited.add(startNode); + + while (queue.length > 0) { + const current = queue.shift()!; + regionNodes.add(current); + for (const nb of current.neighbors) { + if ( + !visited.has(nb) && + (nb.type === GraphNodeType.Empty || + nb.type === GraphNodeType.Resource) + ) { + visited.add(nb); + queue.push(nb); + } + } + } + + // 条件 a:区域内任一节点有 Branch 邻居 + let hasBranch = false; + outer: for (const node of regionNodes) { + for (const nb of node.neighbors) { + if (nb.type === GraphNodeType.Branch) { + hasBranch = true; + break outer; + } + } + } + if (!hasBranch) continue; + + // 收集区域内所有格子,计算总数和外接矩形 + let totalTiles = 0; + let minX = Infinity, + maxX = -Infinity, + minY = Infinity, + maxY = -Infinity; + + for (const node of regionNodes) { + totalTiles += node.tiles.size; + for (const t of node.tiles) { + const x = t % width; + const y = (t - x) / width; + if (x < minX) minX = x; + if (x > maxX) maxX = x; + if (y < minY) minY = y; + if (y > maxY) maxY = y; + } + } + + // 条件 b:总格子数 >= 4 + if (totalTiles < 4) continue; + + // 条件 c:外接矩形宽高均 > 1 + if (maxX - minX < 1 || maxY - minY < 1) continue; + + roomCount++; + } + + return roomCount; +} + +/** + * 统计邻居数 >= 3 的分支节点数量(高连接度分支节点) + * @param graph 拓扑图 + */ +function computeHighDegBranchCount(graph: IMapGraph): number { + let count = 0; + const visited = new Set(); + + for (const node of graph.nodeMap.values()) { + if (visited.has(node)) continue; + visited.add(node); + + if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) { + count++; + } + } + return count; +} + /** * 根据地图矩阵解析出地图数据 * @param tower 地图所属塔信息 @@ -177,6 +353,17 @@ export function parseFloorInfo( ); const flattened = map.flat(); const area = flattened.length; + const width = map[0]?.length ?? 0; + + // ── 结构标签计算 ───────────────────────────────── + const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map); + const outerWall = computeOuterWall( + map, + config.classes.wall, + config.classes.entry + ); + const roomCount = computeRoomCount(topo.graph, width); + const highDegBranchCount = computeHighDegBranchCount(topo.graph); let hasUselessBranch = false; @@ -283,7 +470,13 @@ export function parseFloorInfo( doorHeatmap: gaussainHeatmap( generateHeatmap(map, doorTiles, config.heatmapKernel), config.guassainRadius - ) + ), + symmetryH, + symmetryV, + symmetryC, + outerWall, + roomCount, + highDegBranchCount }; return floorInfo; diff --git a/data/src/auto/types.ts b/data/src/auto/types.ts index 4ec1a32..57137f5 100644 --- a/data/src/auto/types.ts +++ b/data/src/auto/types.ts @@ -112,6 +112,20 @@ export interface IFloorInfo { readonly entryHeatmap: number[][]; /** 门热力图 */ readonly doorHeatmap: number[][]; + + // ── 结构标签(新增)────────────────────────────── + /** 左右对称(基于 convertedMap 完全匹配) */ + readonly symmetryH: boolean; + /** 上下对称 */ + readonly symmetryV: boolean; + /** 中心对称 */ + readonly symmetryC: boolean; + /** 是否外包围墙壁(最外圈墙壁+入口占比 > 90%) */ + readonly outerWall: boolean; + /** 房间数量原始值(供 Python 两趟扫描使用) */ + readonly roomCount: number; + /** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */ + readonly highDegBranchCount: number; } export interface IMapBlockConfig { diff --git a/data/src/types.ts b/data/src/types.ts index 9d11c97..9cac4ef 100644 --- a/data/src/types.ts +++ b/data/src/types.ts @@ -48,6 +48,15 @@ export interface GinkaTrainData { map: number[][]; size: [number, number]; heatmap?: number[][][]; + // ── 结构标签(新增)────────────────────────────── + /** 对称性:[symmetryH, symmetryV, symmetryC],0 或 1 */ + symmetry: [number, number, number]; + /** 是否外包围墙壁,0 或 1 */ + outerWall: number; + /** 房间数量原始值 */ + roomCount: number; + /** 高连接度分支节点数量原始值 */ + highDegBranchCount: number; } export interface GinkaDataset { diff --git a/docs/dataset-labels-design.md b/docs/dataset-labels-design.md new file mode 100644 index 0000000..979b79c --- /dev/null +++ b/docs/dataset-labels-design.md @@ -0,0 +1,248 @@ +# 数据集标签设计文档 + +## 背景 + +在当前 VQ-VAE + MaskGIT 的联合训练方案中,VQ-VAE 的 codebook 承担着地图风格与多样性的控制职能,但缺少可用户感知的语义维度。为提升生成的可控性,计划在数据集中添加一组**可程序化标注的结构标签**,作为额外条件输入训练,从而使模型能够接受来自用户的高层语义约束(如"生成一个左右对称的地图"、"高房间数量"等)。 + +所有标签均可在数据集构建阶段(`info.ts` / `parseFloorInfo`)或训练前的两趟扫描中自动计算,无需人工标注。 + +--- + +## 标签一览 + +| 标签名 | 类型 | 取值 | 标注时机 | +| ---------------- | --------- | ------------------- | -------- | +| `symmetryH` | `boolean` | 左右对称 | 单张地图 | +| `symmetryV` | `boolean` | 上下对称 | 单张地图 | +| `symmetryC` | `boolean` | 中心对称 | 单张地图 | +| `roomCountLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 | +| `branchLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 | +| `outerWall` | `boolean` | 是否外包围墙壁 | 单张地图 | + +--- + +## 标签一:对称性 + +### 定义 + +针对经过转换后的地图(`convertedMap`)矩阵逐格比较,仅当**完全满足条件**时才标记为对应对称类型。三种对称相互独立,可同时成立。 + +| 对称类型 | 条件(对所有 `x ∈ [0, W)`, `y ∈ [0, H)` 成立) | +| -------- | ---------------------------------------------- | +| 左右对称 | `map[y][x] === map[y][W - 1 - x]` | +| 上下对称 | `map[y][x] === map[H - 1 - y][x]` | +| 中心对称 | `map[y][x] === map[H - 1 - y][W - 1 - x]` | + +### 实现要点 + +- 比较使用 `convertedMap`(标签化图块编号),而非原始 `originMap`,使不同塔的同类图块具有可比性。 +- 中心对称与左右、上下对称在数学上不蕴含关系(非充分也非必要),需独立计算。 +- 对于奇数尺寸的地图(如 13×13),中心行/列与自身比较必然成立,无需特殊处理。 + +### 字段扩展(`IFloorInfo`) + +```typescript +/** 左右对称 */ +readonly symmetryH: boolean; +/** 上下对称 */ +readonly symmetryV: boolean; +/** 中心对称 */ +readonly symmetryC: boolean; +``` + +--- + +## 标签二:房间数量等级 + +### 定义 + +**房间(Room)** 是地图中由"空白节点"或"资源节点"组成的区域,同时满足以下三个条件: + +1. **位置条件**:该节点在拓扑图中至少与 **1 个分支节点(Branch)** 相邻; +2. **面积条件**:节点所包含的地图格子数(`tiles.size`)**≥ 4**; +3. **形状条件**:节点所有格子的**外接矩形的宽和高都大于 1**(避免把单行/单列走廊计入房间)。 + +> **为什么这样定义房间** +> +> 在拓扑图中,空白/资源节点天然是游戏空间的"腔体",而分支节点(门/怪物)是进入腔体的关卡节点。只要与至少一个分支节点相邻,就说明这片空间是需要"先过关才能进入/离开"的区域——包括怪物守着宝箱这类单入口房间,同样是典型的房间结构。面积和形状约束则过滤掉通道和死胡同。 + +### 等级划分 + +等级为 **三档**:Low(0)/ Medium(1)/ High(2),通过以下两趟扫描确定: + +1. **第一趟**:遍历整个训练集,计算每张地图的房间数量 `roomCount`,收集为数组; +2. **第二趟**:对数组升序排序,取 1/3 和 2/3 分位数作为阈值 `[th1, th2]`: + - `roomCount < th1` → Low(0) + - `th1 ≤ roomCount < th2` → Medium(1) + - `roomCount ≥ th2` → High(2) + +等级划分力求三档样本数量均等(等频分箱),而非等距分箱。 + +### 外接矩形计算 + +给定节点的所有格子坐标集合(`tiles`,存储 `y * width + x` 的平坦坐标),还原为 `(x, y)` 后: + +$$ +x_{\min} = \min_{t \in tiles}(t \bmod W), \quad x_{\max} = \max_{t \in tiles}(t \bmod W) +$$ + +$$ +y_{\min} = \min_{t \in tiles}(\lfloor t / W \rfloor), \quad y_{\max} = \max_{t \in tiles}(\lfloor t / W \rfloor) +$$ + +外接矩形宽 = $x_{\max} - x_{\min} + 1$,高 = $y_{\max} - y_{\min} + 1$,两者均需 $> 1$。 + +### 字段扩展 + +```typescript +/** 房间数量(原始统计值,供两趟扫描使用) */ +readonly roomCount: number; +/** 房间数量等级:0=Low, 1=Medium, 2=High(需两趟扫描后赋值) */ +roomCountLevel: 0 | 1 | 2; +``` + +--- + +## 标签三:分支数量等级 + +### 定义 + +**高连接度分支节点**:拓扑图中 `type === Branch` 的节点,其**非墙邻居节点**总数(`neighbors.size`)**≥ 3**。由于当前拓扑图中已不含墙节点,`neighbors.size` 即等于非墙邻居数,两者在实现上等价。 + +这类节点是地图中的"交叉口"——一个门或怪物后方至少有三条不同路线,是地图分叉度和策略深度的指征。 + +等级划分方式与房间数量等级相同:先统计每张地图中高连接度分支节点的数量,再等频分箱为 Low(0)/ Medium(1)/ High(2)。 + +### 与房间数量的区别 + +| 维度 | 房间数量等级 | 分支数量等级 | +| -------- | -------------------------- | ---------------------------- | +| 度量对象 | 空白/资源节点区域的封闭性 | 分支节点的路径分叉度 | +| 反映特征 | 地图内封闭房间的数量与密度 | 关键路口/多分支门怪的复杂度 | +| 典型高值 | 多房间迷宫风格地图 | 高度分叉、策略选择丰富的地图 | + +### 字段扩展 + +```typescript +/** 高连接度分支节点数量(原始统计值) */ +readonly highDegBranchCount: number; +/** 分支数量等级:0=Low, 1=Medium, 2=High(需两趟扫描后赋值) */ +branchLevel: 0 | 1 | 2; +``` + +--- + +## 标签四:外包围墙壁 + +### 定义 + +地图**最外圈**(最外一圈格子)的格子中,**墙壁格子与入口格子**之和占外圈总格子数的比例 > **90%**,则标记为 `outerWall = true`。 + +$$ +\text{outerWall} = \frac{|\{(x,y) \in \text{border} : \text{isWall}(x,y) \lor \text{isEntry}(x,y)\}|}{|\text{border}|} > 0.9 +$$ + +### 最外圈定义 + +对于 $H \times W$ 的地图,最外圈为所有满足下列条件之一的格子: + +$$ +x = 0 \; \lor \; x = W - 1 \; \lor \; y = 0 \; \lor \; y = H - 1 +$$ + +最外圈格子总数为 $2(H + W) - 4$(对于 13×13,共 48 格)。 + +### 为什么入口也算"通过" + +入口格子在游戏中是楼梯/传送点,不属于可通行的空地,在视觉和结构上等价于边界开口,属于外圈围合结构的合理组成部分,不应被视为"破坏围合"的元素。 + +### 实现要点 + +- 使用 `convertedMap` 判断墙壁(`tile === config.wall`); +- 使用 `originMap` + `converter.isEntry()` 或直接对 `convertedMap` 判断入口(`tile === config.entry`)判断入口; +- 两项合取计入分子。 + +### 字段扩展 + +```typescript +/** 是否外包围墙壁 */ +readonly outerWall: boolean; +``` + +--- + +## 实现方案 + +### 单张地图可直接计算的标签 + +以下标签可在 `parseFloorInfo` 中直接计算,加入 `IFloorInfo`: + +- `symmetryH`、`symmetryV`、`symmetryC` +- `outerWall` +- `roomCount`(原始值) +- `highDegBranchCount`(原始值) + +### 需要两趟扫描的等级标签 + +`roomCountLevel` 和 `branchLevel` 依赖全局分位数,须在数据集构建完成后进行二次处理: + +``` +第一趟:构建所有楼层的 IFloorInfo,写入 roomCount / highDegBranchCount +第二趟:收集所有楼层的原始值 → 计算 1/3, 2/3 分位 → 回填等级 +``` + +在 Python 训练侧,推荐方式: + +```python +# 在 Dataset.__init__ 中完成两趟计算 +counts = [item['roomCount'] for item in raw_data] +counts_sorted = sorted(counts) +th1 = counts_sorted[len(counts_sorted) // 3] +th2 = counts_sorted[2 * len(counts_sorted) // 3] + +for item in raw_data: + c = item['roomCount'] + item['roomCountLevel'] = 0 if c < th1 else (1 if c < th2 else 2) +``` + +同理处理 `branchLevel`。 + +> **注意**:分位数阈值应仅基于**训练集**统计,验证集 / 测试集使用相同的阈值映射,避免数据泄露。 + +--- + +## 训练集成 + +### 条件嵌入 + +将上述标签作为离散条件与 VQ-VAE 的 z 一同注入 MaskGIT: + +```python +# 对称性:三个独立布尔值,可合并为 0~7 的整数 cond_sym +cond_sym = symmetryH * 4 + symmetryV * 2 + symmetryC * 1 # [0, 7] + +# 房间等级:0 / 1 / 2 +cond_room = roomCountLevel + +# 分支等级:0 / 1 / 2 +cond_branch = branchLevel + +# 外包围墙壁:0 / 1 +cond_outer = int(outerWall) +``` + +每个条件通过独立的 `nn.Embedding` 映射为固定维度向量,与 VQ-VAE 的 z 序列沿序列维度拼接后,一同经 Cross-Attention 注入 MaskGIT。 + +### 条件 Dropout + +与 z dropout 类似,训练时以一定概率(如 10~20%)将部分或全部结构标签替换为"无条件"(null embedding),使模型在推理时支持条件缺省(CFG 风格)。 + +--- + +## 待细化事项 + +- [x] 90% 阈值是否合适?——**保持 90%**,如后续数据分布分析发现问题再调整。 +- [x] 房间定义中分支节点邻居数量——**改为至少 1 个**,覆盖怪物守宝箱的单入口房间场景。 +- [x] 分支等级邻居计数口径——**使用非墙邻居**(当前图结构中无墙节点,与 `neighbors.size` 等价)。 +- [x] 是否新增通道数量/路径长度标签——**暂不考虑**。 +- [x] 条件嵌入维度对齐——**各标签 Embedding 与 z 序列拼接后统一经 Cross-Attention 注入**。 diff --git a/docs/dataset-labels-impl.md b/docs/dataset-labels-impl.md new file mode 100644 index 0000000..fdb359e --- /dev/null +++ b/docs/dataset-labels-impl.md @@ -0,0 +1,563 @@ +# 数据集标签实现指南 + +## 总览 + +本文档描述将[标签设计文档](./dataset-labels-design.md)中定义的四类结构标签落地到代码中的具体实现步骤,涉及以下文件: + +| 文件 | 变更性质 | +| ------------------------ | ------------------------------------------- | +| `data/src/auto/types.ts` | 扩展 `IFloorInfo` 接口 | +| `data/src/types.ts` | 扩展 `GinkaTrainData` 接口 | +| `data/src/auto/info.ts` | 新增四个 helper 函数,修改 `parseFloorInfo` | +| `data/src/auto.ts` | 在数据集序列化时写入新字段 | +| `ginka/dataset.py` | 两趟扫描 + `__getitem__` 读取新标签 | + +--- + +## 第一步:扩展 TypeScript 类型 + +### `data/src/auto/types.ts` — 扩展 `IFloorInfo` + +在 `IFloorInfo` 接口的 `doorHeatmap` 字段之后追加以下字段: + +```typescript +// ── 结构标签(新增)────────────────────────────── +/** 左右对称(基于 convertedMap 完全匹配) */ +readonly symmetryH: boolean; +/** 上下对称 */ +readonly symmetryV: boolean; +/** 中心对称 */ +readonly symmetryC: boolean; +/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90%) */ +readonly outerWall: boolean; +/** 房间数量原始值(供 Python 两趟扫描使用) */ +readonly roomCount: number; +/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */ +readonly highDegBranchCount: number; +``` + +### `data/src/types.ts` — 扩展 `GinkaTrainData` + +在 `GinkaTrainData` 接口中追加序列化后写入 JSON 的字段: + +```typescript +export interface GinkaTrainData { + tag?: number[]; + val: number[]; + map: number[][]; + size: [number, number]; + heatmap?: number[][][]; + // ── 结构标签(新增)────────────────────────────── + /** 对称性:[symmetryH, symmetryV, symmetryC],0 或 1 */ + symmetry: [number, number, number]; + /** 是否外包围墙壁,0 或 1 */ + outerWall: number; + /** 房间数量原始值 */ + roomCount: number; + /** 高连接度分支节点数量原始值 */ + highDegBranchCount: number; +} +``` + +--- + +## 第二步:实现 Helper 函数(`data/src/auto/info.ts`) + +以下四个函数添加到 `computeWallDensityStd` 之后、`parseFloorInfo` 之前。 + +### 2.1 对称性计算 + +```typescript +/** + * 计算地图的三种对称性(基于 convertedMap,完全匹配才标记为 true) + */ +function computeSymmetry(map: number[][]): { + symmetryH: boolean; + symmetryV: boolean; + symmetryC: boolean; +} { + const H = map.length; + const W = H > 0 ? map[0].length : 0; + let symmetryH = true; + let symmetryV = true; + let symmetryC = true; + + outer: for (let y = 0; y < H; y++) { + for (let x = 0; x < W; x++) { + const tile = map[y][x]; + if (symmetryH && tile !== map[y][W - 1 - x]) symmetryH = false; + if (symmetryV && tile !== map[H - 1 - y][x]) symmetryV = false; + if (symmetryC && tile !== map[H - 1 - y][W - 1 - x]) + symmetryC = false; + // 三种对称均已排除,提前退出 + if (!symmetryH && !symmetryV && !symmetryC) break outer; + } + } + return { symmetryH, symmetryV, symmetryC }; +} +``` + +**复杂度**:O(H × W),最坏情况遍历全图一次,实际有短路优化。 + +**注意**:对于奇数尺寸地图(如 13×13),中心行/列的格子与自身比较恒为 true,不影响结果。 + +### 2.2 外包围墙壁检测 + +```typescript +/** + * 检测地图最外圈中墙壁 + 入口的占比是否超过 90% + * @param map convertedMap + * @param wall 墙壁图块编号 + * @param entry 入口图块编号 + */ +function computeOuterWall( + map: number[][], + wall: number, + entry: number +): boolean { + const H = map.length; + const W = H > 0 ? map[0].length : 0; + if (H < 2 || W < 2) return false; + + let borderCount = 0; + let wallOrEntry = 0; + + const check = (tile: number) => { + borderCount++; + if (tile === wall || tile === entry) wallOrEntry++; + }; + + // 顶行 + 底行 + for (let x = 0; x < W; x++) { + check(map[0][x]); + check(map[H - 1][x]); + } + // 左列 + 右列(排除角格,已由上面计入) + for (let y = 1; y < H - 1; y++) { + check(map[y][0]); + check(map[y][W - 1]); + } + + // 13×13 地图: borderCount = 2*(13+13)-4 = 48 + return borderCount > 0 && wallOrEntry / borderCount > 0.9; +} +``` + +### 2.3 房间数量统计 + +**核心思路**:拓扑图中 Empty 节点和 Resource 节点在物理上可能彼此相邻,共同构成一个连续的游戏空间(如 2×3 的房间内放了一个宝物)。不能逐节点独立判断,而应先通过 BFS 将相邻的 Empty/Resource 节点合并为一个"候选区域",再对整个区域判断三个条件。Branch 节点作为边界,不会被合并进区域。 + +```typescript +/** + * 统计拓扑图中符合"房间"定义的连通区域数量。 + * + * 算法: + * 1. 以 Empty / Resource 节点为顶点,在它们之间 BFS, + * 得到若干"候选区域"(Branch 节点作为边界,不被合并)。 + * 2. 对每个候选区域检查三个条件: + * a. 区域内至少一个节点有 Branch 类型邻居 + * b. 区域内所有格子总数 >= 4 + * c. 所有格子的外接矩形宽 > 1 且高 > 1 + * + * @param graph 拓扑图 + * @param width 地图宽度(用于平坦坐标解码) + */ +function computeRoomCount(graph: IMapGraph, width: number): number { + // Step 1: 收集所有 Empty 和 Resource 节点(去重) + const allEmptyResource = new Set(); + for (const node of graph.nodeMap.values()) { + if ( + node.type === GraphNodeType.Empty || + node.type === GraphNodeType.Resource + ) { + allEmptyResource.add(node); + } + } + + // Step 2: BFS 将相邻的 Empty/Resource 节点合并为连通区域 + let roomCount = 0; + const visited = new Set(); + + for (const startNode of allEmptyResource) { + if (visited.has(startNode)) continue; + + // BFS:只在 Empty/Resource 节点间传播,Branch 节点阻断 + const regionNodes = new Set(); + const queue: MapGraphNode[] = [startNode]; + visited.add(startNode); + + while (queue.length > 0) { + const current = queue.shift()!; + regionNodes.add(current); + for (const nb of current.neighbors) { + if ( + !visited.has(nb) && + (nb.type === GraphNodeType.Empty || + nb.type === GraphNodeType.Resource) + ) { + visited.add(nb); + queue.push(nb); + } + } + } + + // Step 3: 对合并后的区域检查三个条件 + + // 条件 a:区域内任一节点有 Branch 邻居 + let hasBranch = false; + outer: for (const node of regionNodes) { + for (const nb of node.neighbors) { + if (nb.type === GraphNodeType.Branch) { + hasBranch = true; + break outer; + } + } + } + if (!hasBranch) continue; + + // 收集区域内所有格子,计算总数和外接矩形 + let totalTiles = 0; + let minX = Infinity, + maxX = -Infinity, + minY = Infinity, + maxY = -Infinity; + + for (const node of regionNodes) { + totalTiles += node.tiles.size; + for (const t of node.tiles) { + const x = t % width; + const y = (t - x) / width; + if (x < minX) minX = x; + if (x > maxX) maxX = x; + if (y < minY) minY = y; + if (y > maxY) maxY = y; + } + } + + // 条件 b:总格子数 >= 4 + if (totalTiles < 4) continue; + + // 条件 c:外接矩形宽高均 > 1(即 maxX - minX >= 1 && maxY - minY >= 1) + if (maxX - minX < 1 || maxY - minY < 1) continue; + + roomCount++; + } + + return roomCount; +} +``` + +**边界情况讨论**: + +- **混合节点房间**:2×3 房间(5 个空地 + 1 个资源)→ 1 个 Empty 节点 + 1 个 Resource 节点,BFS 合并后总格子数=6,外接矩形 2×3,**计入房间**。 +- **纯条形走廊**:1×4 的 Empty 节点 → 外接矩形高=1,**不计入房间**。 +- **孤立资源死角**:Resource 节点的邻居全是 Entry 而无 Branch → 条件 a 不满足,**不计入房间**。 +- **两个相邻的 Empty 区域被 Branch 隔开**:BFS 不会跨越 Branch,各自单独判断。 + +### 2.4 高连接度分支节点统计 + +```typescript +/** + * 统计邻居数 >= 3 的分支节点数量 + * + * 由于拓扑图中已不含墙节点,neighbors.size 即等于非墙邻居数。 + * + * @param graph 拓扑图 + */ +function computeHighDegBranchCount(graph: IMapGraph): number { + let count = 0; + const visited = new Set(); + + for (const node of graph.nodeMap.values()) { + if (visited.has(node)) continue; + visited.add(node); + + if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) { + count++; + } + } + return count; +} +``` + +--- + +## 第三步:在 `parseFloorInfo` 中调用 + +在 `parseFloorInfo` 函数内,`topo` 构建完毕、`floorInfo` 对象构造之前,调用上述四个函数: + +```typescript +// ── 结构标签计算 ───────────────────────────────── +const width = map[0]?.length ?? 0; +const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map); +const outerWall = computeOuterWall( + map, + config.classes.wall, + config.classes.entry +); +const roomCount = computeRoomCount(topo.graph, width); +const highDegBranchCount = computeHighDegBranchCount(topo.graph); +``` + +然后在 `floorInfo` 字面量中追加这些字段: + +```typescript +const floorInfo: IFloorInfo = { + // ...(原有字段保持不变) + symmetryH, + symmetryV, + symmetryC, + outerWall, + roomCount, + highDegBranchCount +}; +``` + +--- + +## 第四步:在 `data/src/auto.ts` 中序列化新字段 + +在构建 `GinkaTrainData` 的对象字面量中追加: + +```typescript +const data: GinkaTrainData = { + map: floor.data.map, + size: [width, height], + heatmap: [ + /* ...原有热力图通道... */ + ], + val: [ + /* ...原有标量... */ + ], + // ── 新增结构标签 ────────────────────────────── + symmetry: [ + info.symmetryH ? 1 : 0, + info.symmetryV ? 1 : 0, + info.symmetryC ? 1 : 0 + ], + outerWall: info.outerWall ? 1 : 0, + roomCount: info.roomCount, + highDegBranchCount: info.highDegBranchCount +}; +``` + +布尔值存为 `0/1` 整数,便于 Python 侧直接读取,不需要类型转换。 + +--- + +## 第五步:Python 侧两趟扫描(`ginka/dataset.py`) + +### 5.1 等频分箱函数 + +在 `dataset.py` 顶部添加通用的等频分箱 helper: + +```python +def assign_level(values: list[int]) -> list[int]: + """ + 将整数列表按等频分箱映射为 0/1/2 三档等级。 + 分位数阈值基于当前列表计算(训练集与验证集应共用训练集的阈值)。 + + Args: + values: 原始统计值列表,顺序与数据集一一对应 + + Returns: + 与输入等长的等级列表,每项为 0 / 1 / 2 + """ + n = len(values) + if n == 0: + return [] + sorted_vals = sorted(values) + th1 = sorted_vals[n // 3] + th2 = sorted_vals[2 * n // 3] + return [ + 0 if v < th1 else (1 if v < th2 else 2) + for v in values + ] +``` + +### 5.2 修改 `GinkaVQDataset.__init__` + +在加载数据之后立即进行两趟扫描,并将等级结果回填到各 item 中: + +```python +class GinkaVQDataset(Dataset): + def __init__( + self, + data_path: str, + subset_weights: tuple = (0.5, 0.2, 0.2, 0.1), + wall_mask_min: float = 0.0, + wall_mask_max: float = 0.5, + # 以下两个参数仅在 train 模式下使用,验证集传入训练集的阈值即可 + room_thresholds: tuple[int, int] | None = None, + branch_thresholds: tuple[int, int] | None = None, + ): + self.data = load_data(data_path) + # ...(原有初始化逻辑) + + # ── 两趟扫描:计算等频分箱阈值 ────────────────────────────── + room_counts = [item['roomCount'] for item in self.data] + branch_counts = [item['highDegBranchCount'] for item in self.data] + + if room_thresholds is None: + # 训练集:自行计算阈值 + n = len(room_counts) + rs = sorted(room_counts) + bs = sorted(branch_counts) + self.room_th = (rs[n // 3], rs[2 * n // 3]) + self.branch_th = (bs[n // 3], bs[2 * n // 3]) + else: + # 验证集/测试集:直接使用训练集的阈值,避免数据泄露 + self.room_th = room_thresholds + self.branch_th = branch_thresholds + + def to_level(v: int, th: tuple[int, int]) -> int: + return 0 if v < th[0] else (1 if v < th[1] else 2) + + # 回填等级字段 + for item in self.data: + item['roomCountLevel'] = to_level(item['roomCount'], self.room_th) + item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th) +``` + +**调用方式**(`train_vq.py` 中): + +```python +dataset_train = GinkaVQDataset(args.train) +dataset_val = GinkaVQDataset( + args.validate, + room_thresholds=dataset_train.room_th, + branch_thresholds=dataset_train.branch_th +) +``` + +### 5.3 `__getitem__` 中读取结构标签 + +对称性标签在数据增强(旋转/翻转)后需要**重新从增强后的地图中计算**,因为 `rot90(k=1/3)` 会交换 `symmetryH` 和 `symmetryV`。其他三个标签在旋转/翻转下保持不变,可直接读取。 + +```python +def _compute_symmetry(target_np: np.ndarray) -> tuple[int, int, int]: + """从 numpy 地图矩阵中直接计算三种对称性,O(H*W)""" + H, W = target_np.shape + sym_h = bool(np.all(target_np == target_np[:, ::-1])) + sym_v = bool(np.all(target_np == target_np[::-1, :])) + sym_c = bool(np.all(target_np == target_np[::-1, ::-1])) + return int(sym_h), int(sym_v), int(sym_c) +``` + +在 `__getitem__` 数据增强完成后,读取所有标签: + +```python +def __getitem__(self, idx): + # ...(原有增强逻辑,target_np 已经过 rot90 / flip) + + # 对称性:在增强后重新计算 + sym_h, sym_v, sym_c = _compute_symmetry(target_np) + cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7] + + # 其余标签:增强不改变拓扑结构,直接读取 + item = self.data[idx] + cond_room = item['roomCountLevel'] # 0/1/2 + cond_branch = item['branchLevel'] # 0/1/2 + cond_outer = item['outerWall'] # 0/1 + + # 封装为 tensor + struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + + return { + "raw_map": ..., + "masked_map": ..., + "target_map": ..., + "subset": ..., + "struct_cond": struct_cond # [4],供模型 Embedding 查表 + } +``` + +--- + +## 第六步:模型侧条件注入(`ginka/maskGIT/model.py`) + +`struct_cond` 的四个维度分别对应不同词表大小的 Embedding: + +```python +# 词表大小 +SYM_VOCAB = 8 # cond_sym: [0, 7] +ROOM_VOCAB = 4 # cond_room: [0, 2] + 1 个 null(dropout 用) +BRANCH_VOCAB = 4 # cond_branch: [0, 2] + 1 个 null +OUTER_VOCAB = 3 # cond_outer: [0, 1] + 1 个 null + +# 在 GinkaMaskGIT.__init__ 中 +self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) +self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) +self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z) +self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) +``` + +在 `forward` 中,将四个 Embedding 结果与 VQ-VAE 的 z 序列拼接,作为 Cross-Attention 的 memory: + +```python +def forward(self, map_tokens, z, struct_cond, dropout_struct=False): + # z: [B, L, d_z] + # struct_cond: [B, 4] — [sym, room, branch, outer] + + B = z.size(0) + + if dropout_struct or (self.training and torch.rand(1) < self.struct_dropout_prob): + # 条件 dropout:全部替换为 null index(各词表最后一个 index) + e_sym = self.sym_embed(torch.full((B,), SYM_VOCAB - 1, device=z.device)) + e_room = self.room_embed(torch.full((B,), ROOM_VOCAB - 1, device=z.device)) + e_branch = self.branch_embed(torch.full((B,), BRANCH_VOCAB - 1, device=z.device)) + e_outer = self.outer_embed(torch.full((B,), OUTER_VOCAB - 1, device=z.device)) + else: + e_sym = self.sym_embed(struct_cond[:, 0]) # [B, d_z] + e_room = self.room_embed(struct_cond[:, 1]) + e_branch = self.branch_embed(struct_cond[:, 2]) + e_outer = self.outer_embed(struct_cond[:, 3]) + + # 将四个结构标签嵌入拼接为序列,与 z 合并 + # 每个 e_* 形状为 [B, d_z],unsqueeze 后变为 [B, 1, d_z] + struct_seq = torch.stack( + [e_sym, e_room, e_branch, e_outer], dim=1 + ) # [B, 4, d_z] + + memory = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z] + + # 后续 Cross-Attention 正常进行 + # query = map_token_embeddings,key/value = memory + ... +``` + +**null index 规则**:各 Embedding 的最后一个 index 保留为"无条件"占位符,词表大小因此比有效类别数多 1。 + +--- + +## 边界情况与注意事项 + +### 关于两趟扫描的时机 + +两趟扫描在 `Dataset.__init__` 中完成,整个过程仅需遍历两次 Python 列表,耗时可忽略不计(数千条数据 < 1ms)。不建议延迟到 `__getitem__` 中逐条计算。 + +### 关于对称性在数据增强下的变化 + +| 增强操作 | `symmetryH` | `symmetryV` | `symmetryC` | +| ----------------- | ----------- | ----------- | ----------- | +| `fliplr` | 不变 | 不变 | 不变 | +| `flipud` | 不变 | 不变 | 不变 | +| `rot90(k=1 or 3)` | 与 V 交换 | 与 H 交换 | 不变 | +| `rot90(k=2)` | 不变 | 不变 | 不变 | + +因此在 `__getitem__` 中对增强后的地图**重新计算**对称性是最简洁正确的方案,无需记录增强历史。 + +### 关于极端分布 + +若训练集中某一标签的样本极度不均匀(如 90% 的地图无对称性),可以在条件 Dropout 中对不常见的条件值适当降低 dropout 概率,以确保模型充分学习该条件。初始阶段统一使用相同 dropout 概率即可,后续根据生成效果调整。 + +### 关于 `roomCount = 0` 和 `highDegBranchCount = 0` 的地图 + +这类地图在等频分箱后会进入 Low(0)等级。如果训练集中大量地图的值为 0,`th1` 可能也为 0,导致 Low 等级极少。可以在分箱前加一步检查:若 `th1 == th2`,则手动将 `th2 = th1 + 1` 以避免 Medium 等级为空。 + +```python +th1 = sorted_vals[n // 3] +th2 = sorted_vals[2 * n // 3] +if th1 == th2: + th2 = th1 + 1 # 防止 Medium 等级为空 +``` diff --git a/docs/vqvae-maskgit-design.md b/docs/vqvae-maskgit-design.md index a2b505e..49cb91c 100644 --- a/docs/vqvae-maskgit-design.md +++ b/docs/vqvae-maskgit-design.md @@ -376,7 +376,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm ## 待探索事项 -- 合适的 K、L 取值(建议从 K=16, L=2 开始实验) +- 合适的 K、L 取值(建议从 K=16, L=2 开始实验),K=1, L=64 可能比较合适。 - z dropout 的最优概率 - 若后续 codebook collapse:引入 EMA 更新 + 重置机制 - 若后续需要更细粒度控制:加入标量 cond(需对推理侧标量做随机扰动处理) diff --git a/ginka/dataset.py b/ginka/dataset.py index f79debf..d5fe027 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -232,6 +232,14 @@ class GinkaJointDataset(Dataset): } +def _compute_symmetry(target_np: np.ndarray) -> tuple: + """从 numpy 地图矩阵中直接计算三种对称性,O(H*W)""" + sym_h = bool(np.all(target_np == target_np[:, ::-1])) + sym_v = bool(np.all(target_np == target_np[::-1, :])) + sym_c = bool(np.all(target_np == target_np[::-1, ::-1])) + return int(sym_h), int(sym_v), int(sym_c) + + class GinkaVQDataset(Dataset): """ 用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。 @@ -259,13 +267,17 @@ class GinkaVQDataset(Dataset): data_path: str, subset_weights: tuple = (0.5, 0.2, 0.2, 0.1), wall_mask_ratio: float = 0.3, + room_thresholds: tuple = None, + branch_thresholds: tuple = None, ): """ Args: - data_path: JSON 数据文件路径 - subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化 - wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限 - (每次从 [0, wall_mask_ratio] 均匀采样实际比例) + data_path: JSON 数据文件路径 + subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化 + wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限 + (每次从 [0, wall_mask_ratio] 均匀采样实际比例) + room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集) + branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集) """ self.data = load_data(data_path) self.wall_mask_ratio = wall_mask_ratio @@ -275,6 +287,35 @@ class GinkaVQDataset(Dataset): normalized = [x / total_w for x in subset_weights] self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))] + # ── 两趟扫描:计算等频分箱阈值 ────────────────────────────── + room_counts = [item['roomCount'] for item in self.data] + branch_counts = [item['highDegBranchCount'] for item in self.data] + + if room_thresholds is None: + n = len(room_counts) + rs = sorted(room_counts) + bs = sorted(branch_counts) + th1_r, th2_r = rs[n // 3], rs[2 * n // 3] + th1_b, th2_b = bs[n // 3], bs[2 * n // 3] + # 防止 Medium 等级为空 + if th1_r == th2_r: + th2_r = th1_r + 1 + if th1_b == th2_b: + th2_b = th1_b + 1 + self.room_th = (th1_r, th2_r) + self.branch_th = (th1_b, th2_b) + else: + self.room_th = room_thresholds + self.branch_th = branch_thresholds + + def to_level(v: int, th: tuple) -> int: + return 0 if v < th[0] else (1 if v < th[1] else 2) + + # 回填等级字段 + for item in self.data: + item['roomCountLevel'] = to_level(item['roomCount'], self.room_th) + item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th) + def __len__(self): return len(self.data) @@ -407,11 +448,23 @@ class GinkaVQDataset(Dataset): masked_np = self._apply_subset(raw_np, subset) # [H*W] raw_flat = raw_np.reshape(-1) # [H*W] + # 对称性:在增强后重新计算 + sym_h, sym_v, sym_c = _compute_symmetry(raw_np) + cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7] + + # 其余结构标签:增强不改变拓扑结构,直接读取 + cond_room = item['roomCountLevel'] # 0/1/2 + cond_branch = item['branchLevel'] # 0/1/2 + cond_outer = item['outerWall'] # 0/1 + + struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + return { - "raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入 - "masked_map": torch.LongTensor(masked_np), # MaskGIT 输入 - "target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth - "subset": subset, # 调试/统计用 + "raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入 + "masked_map": torch.LongTensor(masked_np), # MaskGIT 输入 + "target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth + "subset": subset, # 调试/统计用 + "struct_cond": struct_cond, # [4],供模型 Embedding 查表 } diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 67b9364..3d8f2f8 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -4,6 +4,12 @@ import torch.nn as nn from ..utils import print_memory from .maskGIT import Transformer +# 结构标签词表大小(最后一个索引为无条件占位符 null) +SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-6,7 = null +ROOM_VOCAB = 4 # roomCountLevel 0-2,3 = null +BRANCH_VOCAB = 4 # branchLevel 0-2,3 = null +OUTER_VOCAB = 3 # outerWall 0-1,2 = null + class GinkaMaskGIT(nn.Module): """ @@ -26,20 +32,24 @@ class GinkaMaskGIT(nn.Module): num_layers: int = 4, map_size: int = 13 * 13, z_dropout: float = 0.1, + struct_dropout: float = 0.15, ): """ Args: - num_classes: tile 类别数(含 MASK token=15) - d_model: Transformer 内部维度 - d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致 - dim_ff: 前馈网络隐层维度 - nhead: 注意力头数 - num_layers: Transformer 层数 - map_size: 地图 token 总数(H * W) - z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性) + num_classes: tile 类别数(含 MASK token=15) + d_model: Transformer 内部维度 + d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致 + dim_ff: 前馈网络隐层维度 + nhead: 注意力头数 + num_layers: Transformer 层数 + map_size: 地图 token 总数(H * W) + z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性) + struct_dropout: 训练时以此概率将结构标签替换为 null(无条件占位), + 实现 classifier-free guidance 兼容训练 """ super().__init__() self.z_dropout = z_dropout + self.struct_dropout_prob = struct_dropout # Tile 嵌入 + 位置编码 self.tile_embedding = nn.Embedding(num_classes, d_model) @@ -51,6 +61,12 @@ class GinkaMaskGIT(nn.Module): nn.LayerNorm(d_model), ) + # 结构标签嵌入(编码到 d_z 维度,与 z 拼接后统一投影到 d_model) + self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) + self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) + self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z) + self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) + # Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention self.transformer = Transformer( d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers @@ -58,33 +74,75 @@ class GinkaMaskGIT(nn.Module): self.output_fc = nn.Linear(d_model, num_classes) - def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + def forward( + self, + map: torch.Tensor, + z: torch.Tensor, + struct_cond: torch.Tensor | None = None, + dropout_struct: bool = False, + ) -> torch.Tensor: """ Args: - map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15) - z: [B, L, d_z] VQ-VAE 量化后的离散隐变量 + map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15) + z: [B, L, d_z] VQ-VAE 量化后的离散隐变量 + struct_cond: [B, 4] 结构标签 LongTensor,顺序为 + [cond_sym, cond_room, cond_branch, cond_outer]; + 为 None 时等价于全 null(无条件模式) + dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成) Returns: logits: [B, H*W, num_classes] """ + B = z.shape[0] + # z dropout:训练时以一定概率将 z 替换为随机均匀噪声, # 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义 if self.training and self.z_dropout > 0: - mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout + mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout rand_z = torch.randn_like(z) z = torch.where(mask, rand_z, z) - # 投影 z 到 d_model 维度 - z_mem = self.z_proj(z) # [B, L, d_model] + # 结构标签嵌入 + # struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引 + if struct_cond is None or dropout_struct: + sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device) + room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device) + branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device) + outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device) + else: + sc = struct_cond.to(z.device) + sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3] + + # 训练时对各标签独立做 struct dropout + if self.training and self.struct_dropout_prob > 0: + def _drop(idx, null_val): + drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob + return torch.where(drop_mask, torch.full_like(idx, null_val), idx) + sym_idx = _drop(sym_idx, SYM_VOCAB - 1) + room_idx = _drop(room_idx, ROOM_VOCAB - 1) + branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1) + outer_idx = _drop(outer_idx, OUTER_VOCAB - 1) + + # 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾 + e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z] + e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z] + e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z] + e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z] + + struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z] + z_ext = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z] + + # 统一投影到 d_model 维度 + z_mem = self.z_proj(z_ext) # [B, L+4, d_model] # tile embedding + 位置编码 - x = self.tile_embedding(map) # [B, H*W, d_model] - x = x + self.pos_embedding # [B, H*W, d_model] + x = self.tile_embedding(map) # [B, H*W, d_model] + x = x + self.pos_embedding # [B, H*W, d_model] - # Transformer:encoder 做 map 自注意力,decoder cross-attend z + # Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct x = self.transformer(x, memory=z_mem) # [B, H*W, d_model] - logits = self.output_fc(x) # [B, H*W, num_classes] + logits = self.output_fc(x) # [B, H*W, num_classes] return logits @@ -107,11 +165,19 @@ if __name__ == "__main__": n = sum(p.numel() for p in module.parameters()) print(f" {name}: {n:,}") - map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169] - z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64] + map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169] + z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64] + struct_input = torch.tensor([[3, 1, 0, 1], + [0, 2, 1, 0], + [7, 3, 3, 2], + [1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4] model.train() - logits = model(map_input, z_input) + logits = model(map_input, z_input, struct_cond=struct_input) print(f"\nlogits shape: {logits.shape}") # [4, 169, 16] + # 无条件模式测试 + logits_uncond = model(map_input, z_input, struct_cond=None) + print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16] + print_memory(device, "前向传播后") diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 65eda30..2c4feb6 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -59,7 +59,8 @@ MG_D_MODEL = 192 MG_NHEAD = 8 MG_LAYERS = 4 MG_DIM_FF = 512 -MG_Z_DROPOUT= 0.15 # 训练时以此概率把 z 替换为随机噪声 +MG_Z_DROPOUT = 0.15 # 训练时以此概率把 z 替换为随机噪声 +MG_STRUCT_DROPOUT= 0.15 # 训练时以此概率将结构标签替换为 null(无条件占位) # 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z) N_Z_SAMPLES = 3 @@ -106,6 +107,7 @@ def maskgit_generate( z: torch.Tensor, steps: int = GENERATE_STEP, init_map: torch.Tensor = None, + struct_cond: torch.Tensor | None = None, ) -> torch.Tensor: """ 迭代生成地图(cosine schedule unmasking)。 @@ -133,7 +135,7 @@ def maskgit_generate( if not generatable.any(): break - logits = model_mg(map_seq, z) # [B, S, C] + logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C] probs = F.softmax(logits, dim=-1) dist = torch.distributions.Categorical(probs) sampled = dist.sample() # [B, S] @@ -279,7 +281,8 @@ def validate( B = raw_map.shape[0] z_q, _, vq_loss, _, _ = model_vq(raw_map) - logits = model_mg(masked_map, z_q) + struct_cond_b = batch["struct_cond"].to(device) # [B, 4] + logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b) mask = (masked_map == MASK_TOKEN) ce_loss = F.cross_entropy( @@ -294,9 +297,10 @@ def validate( s = subsets[i] if captured[s] is None: captured[s] = { - "raw": raw_map[i:i+1].clone(), - "masked": masked_map[i:i+1].clone(), - "z_q": z_q[i:i+1].clone(), + "raw": raw_map[i:i+1].clone(), + "masked": masked_map[i:i+1].clone(), + "z_q": z_q[i:i+1].clone(), + "struct_cond": struct_cond_b[i:i+1].clone(), } if all(v is not None for v in captured.values()): @@ -307,24 +311,24 @@ def validate( imgs = [] for i in range(n): z_r = model_vq.sample(1, device) - gen = maskgit_generate(model_mg, z_r, init_map=cond_map) + gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None 无条件 imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")) return imgs # ── 场景1:标准掩码补全(子集 A)───────────────────────────────────────── if captured['A'] is not None: cap = captured['A'] - raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input") # 单步 argmax 预测(观察模型对掩码位置的瞬时判断) - pred = model_mg(cond, z_q).argmax(dim=-1)[0] + pred = model_mg(cond, z_q, struct_cond=sc).argmax(dim=-1)[0] pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred") # 迭代生成(从掩码输入出发,真实 z) - gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) @@ -333,11 +337,11 @@ def validate( # ── 场景2:墙壁辅助生成(子集 B)───────────────────────────────────────── if captured['B'] is not None: cap = captured['B'] - raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input") - gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) @@ -346,11 +350,11 @@ def validate( # ── 场景3:稀疏墙壁条件生成(子集 C)──────────────────────────────────── if captured['C'] is not None: cap = captured['C'] - raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input") - gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) @@ -359,11 +363,11 @@ def validate( # ── 场景4:墙壁+入口条件生成(子集 D)─────────────────────────────────── if captured['D'] is not None: cap = captured['D'] - raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input") - gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) @@ -403,6 +407,7 @@ def train(): num_layers=MG_LAYERS, map_size=MAP_SIZE, z_dropout=MG_Z_DROPOUT, + struct_dropout=MG_STRUCT_DROPOUT, ).to(device) vq_params = sum(p.numel() for p in model_vq.parameters()) @@ -419,6 +424,8 @@ def train(): dataset_val = GinkaVQDataset( args.validate, subset_weights=SUBSET_WEIGHTS, + room_thresholds=dataset_train.room_th, + branch_thresholds=dataset_train.branch_th, ) dataloader_train = DataLoader( dataset_train, batch_size=BATCH_SIZE, shuffle=True, @@ -481,8 +488,9 @@ def train(): # 1. VQ-VAE 编码真实地图 → z_q z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z] - # 2. MaskGIT 以掩码地图 + z 预测原始 tile - logits = model_mg(masked_map, z_q) # [B, 169, C] + # 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile + struct_cond = batch["struct_cond"].to(device) # [B, 4] + logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C] # 3. 只对被 mask 的位置计算 CE loss mask = (masked_map == MASK_TOKEN) # [B, 169] bool