feat: 添加少数结构性标签

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-27 14:56:21 +08:00
parent 21315a6cb0
commit abbad781ab
10 changed files with 1212 additions and 50 deletions

View File

@ -361,7 +361,15 @@ const labelConfig: IAutoLabelConfig = {
0,
0,
0
]
],
symmetry: [
info.symmetryH ? 1 : 0,
info.symmetryV ? 1 : 0,
info.symmetryC ? 1 : 0
],
outerWall: info.outerWall ? 1 : 0,
roomCount: info.roomCount,
highDegBranchCount: info.highDegBranchCount
};
dataset.data[id] = data;
});

View File

@ -3,8 +3,10 @@ import {
GraphNodeType,
IAutoLabelConfig,
IFloorInfo,
IMapGraph,
IMapTileConverter,
ITowerInfo,
MapGraphNode,
TowerColor
} from './types';
import {
@ -153,6 +155,180 @@ export function computeWallDensityStd(
return Math.sqrt(variance);
}
/**
* convertedMap true
*/
function computeSymmetry(map: number[][]): {
symmetryH: boolean;
symmetryV: boolean;
symmetryC: boolean;
} {
const H = map.length;
const W = H > 0 ? map[0].length : 0;
let symmetryH = true;
let symmetryV = true;
let symmetryC = true;
outer: for (let y = 0; y < H; y++) {
for (let x = 0; x < W; x++) {
const tile = map[y][x];
if (symmetryH && tile !== map[y][W - 1 - x]) symmetryH = false;
if (symmetryV && tile !== map[H - 1 - y][x]) symmetryV = false;
if (symmetryC && tile !== map[H - 1 - y][W - 1 - x])
symmetryC = false;
if (!symmetryH && !symmetryV && !symmetryC) break outer;
}
}
return { symmetryH, symmetryV, symmetryC };
}
/**
* + 90%
* @param map convertedMap
* @param wall
* @param entry
*/
function computeOuterWall(
map: number[][],
wall: number,
entry: number
): boolean {
const H = map.length;
const W = H > 0 ? map[0].length : 0;
if (H < 2 || W < 2) return false;
let borderCount = 0;
let wallOrEntry = 0;
const check = (tile: number) => {
borderCount++;
if (tile === wall || tile === entry) wallOrEntry++;
};
for (let x = 0; x < W; x++) {
check(map[0][x]);
check(map[H - 1][x]);
}
for (let y = 1; y < H - 1; y++) {
check(map[y][0]);
check(map[y][W - 1]);
}
return borderCount > 0 && wallOrEntry / borderCount > 0.9;
}
/**
* "房间"
*
*
* 1. Empty / Resource BFS
* "候选区域"Branch
* 2.
* a. Branch
* b. >= 4
* c. > 1 > 1
*
* @param graph
* @param width
*/
function computeRoomCount(graph: IMapGraph, width: number): number {
const allEmptyResource = new Set<MapGraphNode>();
for (const node of graph.nodeMap.values()) {
if (
node.type === GraphNodeType.Empty ||
node.type === GraphNodeType.Resource
) {
allEmptyResource.add(node);
}
}
let roomCount = 0;
const visited = new Set<MapGraphNode>();
for (const startNode of allEmptyResource) {
if (visited.has(startNode)) continue;
const regionNodes = new Set<MapGraphNode>();
const queue: MapGraphNode[] = [startNode];
visited.add(startNode);
while (queue.length > 0) {
const current = queue.shift()!;
regionNodes.add(current);
for (const nb of current.neighbors) {
if (
!visited.has(nb) &&
(nb.type === GraphNodeType.Empty ||
nb.type === GraphNodeType.Resource)
) {
visited.add(nb);
queue.push(nb);
}
}
}
// 条件 a区域内任一节点有 Branch 邻居
let hasBranch = false;
outer: for (const node of regionNodes) {
for (const nb of node.neighbors) {
if (nb.type === GraphNodeType.Branch) {
hasBranch = true;
break outer;
}
}
}
if (!hasBranch) continue;
// 收集区域内所有格子,计算总数和外接矩形
let totalTiles = 0;
let minX = Infinity,
maxX = -Infinity,
minY = Infinity,
maxY = -Infinity;
for (const node of regionNodes) {
totalTiles += node.tiles.size;
for (const t of node.tiles) {
const x = t % width;
const y = (t - x) / width;
if (x < minX) minX = x;
if (x > maxX) maxX = x;
if (y < minY) minY = y;
if (y > maxY) maxY = y;
}
}
// 条件 b总格子数 >= 4
if (totalTiles < 4) continue;
// 条件 c外接矩形宽高均 > 1
if (maxX - minX < 1 || maxY - minY < 1) continue;
roomCount++;
}
return roomCount;
}
/**
* >= 3
* @param graph
*/
function computeHighDegBranchCount(graph: IMapGraph): number {
let count = 0;
const visited = new Set<MapGraphNode>();
for (const node of graph.nodeMap.values()) {
if (visited.has(node)) continue;
visited.add(node);
if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) {
count++;
}
}
return count;
}
/**
*
* @param tower
@ -177,6 +353,17 @@ export function parseFloorInfo(
);
const flattened = map.flat();
const area = flattened.length;
const width = map[0]?.length ?? 0;
// ── 结构标签计算 ─────────────────────────────────
const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map);
const outerWall = computeOuterWall(
map,
config.classes.wall,
config.classes.entry
);
const roomCount = computeRoomCount(topo.graph, width);
const highDegBranchCount = computeHighDegBranchCount(topo.graph);
let hasUselessBranch = false;
@ -283,7 +470,13 @@ export function parseFloorInfo(
doorHeatmap: gaussainHeatmap(
generateHeatmap(map, doorTiles, config.heatmapKernel),
config.guassainRadius
)
),
symmetryH,
symmetryV,
symmetryC,
outerWall,
roomCount,
highDegBranchCount
};
return floorInfo;

View File

@ -112,6 +112,20 @@ export interface IFloorInfo {
readonly entryHeatmap: number[][];
/** 门热力图 */
readonly doorHeatmap: number[][];
// ── 结构标签(新增)──────────────────────────────
/** 左右对称(基于 convertedMap 完全匹配) */
readonly symmetryH: boolean;
/** 上下对称 */
readonly symmetryV: boolean;
/** 中心对称 */
readonly symmetryC: boolean;
/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90% */
readonly outerWall: boolean;
/** 房间数量原始值(供 Python 两趟扫描使用) */
readonly roomCount: number;
/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */
readonly highDegBranchCount: number;
}
export interface IMapBlockConfig {

View File

@ -48,6 +48,15 @@ export interface GinkaTrainData {
map: number[][];
size: [number, number];
heatmap?: number[][][];
// ── 结构标签(新增)──────────────────────────────
/** 对称性:[symmetryH, symmetryV, symmetryC]0 或 1 */
symmetry: [number, number, number];
/** 是否外包围墙壁0 或 1 */
outerWall: number;
/** 房间数量原始值 */
roomCount: number;
/** 高连接度分支节点数量原始值 */
highDegBranchCount: number;
}
export interface GinkaDataset {

View File

@ -0,0 +1,248 @@
# 数据集标签设计文档
## 背景
在当前 VQ-VAE + MaskGIT 的联合训练方案中VQ-VAE 的 codebook 承担着地图风格与多样性的控制职能,但缺少可用户感知的语义维度。为提升生成的可控性,计划在数据集中添加一组**可程序化标注的结构标签**,作为额外条件输入训练,从而使模型能够接受来自用户的高层语义约束(如"生成一个左右对称的地图"、"高房间数量"等)。
所有标签均可在数据集构建阶段(`info.ts` / `parseFloorInfo`)或训练前的两趟扫描中自动计算,无需人工标注。
---
## 标签一览
| 标签名 | 类型 | 取值 | 标注时机 |
| ---------------- | --------- | ------------------- | -------- |
| `symmetryH` | `boolean` | 左右对称 | 单张地图 |
| `symmetryV` | `boolean` | 上下对称 | 单张地图 |
| `symmetryC` | `boolean` | 中心对称 | 单张地图 |
| `roomCountLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 |
| `branchLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 |
| `outerWall` | `boolean` | 是否外包围墙壁 | 单张地图 |
---
## 标签一:对称性
### 定义
针对经过转换后的地图(`convertedMap`)矩阵逐格比较,仅当**完全满足条件**时才标记为对应对称类型。三种对称相互独立,可同时成立。
| 对称类型 | 条件(对所有 `x ∈ [0, W)`, `y ∈ [0, H)` 成立) |
| -------- | ---------------------------------------------- |
| 左右对称 | `map[y][x] === map[y][W - 1 - x]` |
| 上下对称 | `map[y][x] === map[H - 1 - y][x]` |
| 中心对称 | `map[y][x] === map[H - 1 - y][W - 1 - x]` |
### 实现要点
- 比较使用 `convertedMap`(标签化图块编号),而非原始 `originMap`,使不同塔的同类图块具有可比性。
- 中心对称与左右、上下对称在数学上不蕴含关系(非充分也非必要),需独立计算。
- 对于奇数尺寸的地图(如 13×13中心行/列与自身比较必然成立,无需特殊处理。
### 字段扩展(`IFloorInfo`
```typescript
/** 左右对称 */
readonly symmetryH: boolean;
/** 上下对称 */
readonly symmetryV: boolean;
/** 中心对称 */
readonly symmetryC: boolean;
```
---
## 标签二:房间数量等级
### 定义
**房间Room** 是地图中由"空白节点"或"资源节点"组成的区域,同时满足以下三个条件:
1. **位置条件**:该节点在拓扑图中至少与 **1 个分支节点Branch** 相邻;
2. **面积条件**:节点所包含的地图格子数(`tiles.size`**≥ 4**
3. **形状条件**:节点所有格子的**外接矩形的宽和高都大于 1**(避免把单行/单列走廊计入房间)。
> **为什么这样定义房间**
>
> 在拓扑图中,空白/资源节点天然是游戏空间的"腔体",而分支节点(门/怪物)是进入腔体的关卡节点。只要与至少一个分支节点相邻,就说明这片空间是需要"先过关才能进入/离开"的区域——包括怪物守着宝箱这类单入口房间,同样是典型的房间结构。面积和形状约束则过滤掉通道和死胡同。
### 等级划分
等级为 **三档**Low0/ Medium1/ High2通过以下两趟扫描确定
1. **第一趟**:遍历整个训练集,计算每张地图的房间数量 `roomCount`,收集为数组;
2. **第二趟**:对数组升序排序,取 1/3 和 2/3 分位数作为阈值 `[th1, th2]`
- `roomCount < th1` → Low0
- `th1 ≤ roomCount < th2` → Medium1
- `roomCount ≥ th2` → High2
等级划分力求三档样本数量均等(等频分箱),而非等距分箱。
### 外接矩形计算
给定节点的所有格子坐标集合(`tiles`,存储 `y * width + x` 的平坦坐标),还原为 `(x, y)` 后:
$$
x_{\min} = \min_{t \in tiles}(t \bmod W), \quad x_{\max} = \max_{t \in tiles}(t \bmod W)
$$
$$
y_{\min} = \min_{t \in tiles}(\lfloor t / W \rfloor), \quad y_{\max} = \max_{t \in tiles}(\lfloor t / W \rfloor)
$$
外接矩形宽 = $x_{\max} - x_{\min} + 1$,高 = $y_{\max} - y_{\min} + 1$,两者均需 $> 1$。
### 字段扩展
```typescript
/** 房间数量(原始统计值,供两趟扫描使用) */
readonly roomCount: number;
/** 房间数量等级0=Low, 1=Medium, 2=High需两趟扫描后赋值 */
roomCountLevel: 0 | 1 | 2;
```
---
## 标签三:分支数量等级
### 定义
**高连接度分支节点**:拓扑图中 `type === Branch` 的节点,其**非墙邻居节点**总数(`neighbors.size`**≥ 3**。由于当前拓扑图中已不含墙节点,`neighbors.size` 即等于非墙邻居数,两者在实现上等价。
这类节点是地图中的"交叉口"——一个门或怪物后方至少有三条不同路线,是地图分叉度和策略深度的指征。
等级划分方式与房间数量等级相同:先统计每张地图中高连接度分支节点的数量,再等频分箱为 Low0/ Medium1/ High2
### 与房间数量的区别
| 维度 | 房间数量等级 | 分支数量等级 |
| -------- | -------------------------- | ---------------------------- |
| 度量对象 | 空白/资源节点区域的封闭性 | 分支节点的路径分叉度 |
| 反映特征 | 地图内封闭房间的数量与密度 | 关键路口/多分支门怪的复杂度 |
| 典型高值 | 多房间迷宫风格地图 | 高度分叉、策略选择丰富的地图 |
### 字段扩展
```typescript
/** 高连接度分支节点数量(原始统计值) */
readonly highDegBranchCount: number;
/** 分支数量等级0=Low, 1=Medium, 2=High需两趟扫描后赋值 */
branchLevel: 0 | 1 | 2;
```
---
## 标签四:外包围墙壁
### 定义
地图**最外圈**(最外一圈格子)的格子中,**墙壁格子与入口格子**之和占外圈总格子数的比例 > **90%**,则标记为 `outerWall = true`
$$
\text{outerWall} = \frac{|\{(x,y) \in \text{border} : \text{isWall}(x,y) \lor \text{isEntry}(x,y)\}|}{|\text{border}|} > 0.9
$$
### 最外圈定义
对于 $H \times W$ 的地图,最外圈为所有满足下列条件之一的格子:
$$
x = 0 \; \lor \; x = W - 1 \; \lor \; y = 0 \; \lor \; y = H - 1
$$
最外圈格子总数为 $2(H + W) - 4$(对于 13×13共 48 格)。
### 为什么入口也算"通过"
入口格子在游戏中是楼梯/传送点,不属于可通行的空地,在视觉和结构上等价于边界开口,属于外圈围合结构的合理组成部分,不应被视为"破坏围合"的元素。
### 实现要点
- 使用 `convertedMap` 判断墙壁(`tile === config.wall`
- 使用 `originMap` + `converter.isEntry()` 或直接对 `convertedMap` 判断入口(`tile === config.entry`)判断入口;
- 两项合取计入分子。
### 字段扩展
```typescript
/** 是否外包围墙壁 */
readonly outerWall: boolean;
```
---
## 实现方案
### 单张地图可直接计算的标签
以下标签可在 `parseFloorInfo` 中直接计算,加入 `IFloorInfo`
- `symmetryH`、`symmetryV`、`symmetryC`
- `outerWall`
- `roomCount`(原始值)
- `highDegBranchCount`(原始值)
### 需要两趟扫描的等级标签
`roomCountLevel``branchLevel` 依赖全局分位数,须在数据集构建完成后进行二次处理:
```
第一趟:构建所有楼层的 IFloorInfo写入 roomCount / highDegBranchCount
第二趟:收集所有楼层的原始值 → 计算 1/3, 2/3 分位 → 回填等级
```
在 Python 训练侧,推荐方式:
```python
# 在 Dataset.__init__ 中完成两趟计算
counts = [item['roomCount'] for item in raw_data]
counts_sorted = sorted(counts)
th1 = counts_sorted[len(counts_sorted) // 3]
th2 = counts_sorted[2 * len(counts_sorted) // 3]
for item in raw_data:
c = item['roomCount']
item['roomCountLevel'] = 0 if c < th1 else (1 if c < th2 else 2)
```
同理处理 `branchLevel`
> **注意**:分位数阈值应仅基于**训练集**统计,验证集 / 测试集使用相同的阈值映射,避免数据泄露。
---
## 训练集成
### 条件嵌入
将上述标签作为离散条件与 VQ-VAE 的 z 一同注入 MaskGIT
```python
# 对称性:三个独立布尔值,可合并为 0~7 的整数 cond_sym
cond_sym = symmetryH * 4 + symmetryV * 2 + symmetryC * 1 # [0, 7]
# 房间等级0 / 1 / 2
cond_room = roomCountLevel
# 分支等级0 / 1 / 2
cond_branch = branchLevel
# 外包围墙壁0 / 1
cond_outer = int(outerWall)
```
每个条件通过独立的 `nn.Embedding` 映射为固定维度向量,与 VQ-VAE 的 z 序列沿序列维度拼接后,一同经 Cross-Attention 注入 MaskGIT。
### 条件 Dropout
与 z dropout 类似,训练时以一定概率(如 10~20%)将部分或全部结构标签替换为"无条件"null embedding使模型在推理时支持条件缺省CFG 风格)。
---
## 待细化事项
- [x] 90% 阈值是否合适?——**保持 90%**,如后续数据分布分析发现问题再调整。
- [x] 房间定义中分支节点邻居数量——**改为至少 1 个**,覆盖怪物守宝箱的单入口房间场景。
- [x] 分支等级邻居计数口径——**使用非墙邻居**(当前图结构中无墙节点,与 `neighbors.size` 等价)。
- [x] 是否新增通道数量/路径长度标签——**暂不考虑**。
- [x] 条件嵌入维度对齐——**各标签 Embedding 与 z 序列拼接后统一经 Cross-Attention 注入**。

563
docs/dataset-labels-impl.md Normal file
View File

@ -0,0 +1,563 @@
# 数据集标签实现指南
## 总览
本文档描述将[标签设计文档](./dataset-labels-design.md)中定义的四类结构标签落地到代码中的具体实现步骤,涉及以下文件:
| 文件 | 变更性质 |
| ------------------------ | ------------------------------------------- |
| `data/src/auto/types.ts` | 扩展 `IFloorInfo` 接口 |
| `data/src/types.ts` | 扩展 `GinkaTrainData` 接口 |
| `data/src/auto/info.ts` | 新增四个 helper 函数,修改 `parseFloorInfo` |
| `data/src/auto.ts` | 在数据集序列化时写入新字段 |
| `ginka/dataset.py` | 两趟扫描 + `__getitem__` 读取新标签 |
---
## 第一步:扩展 TypeScript 类型
### `data/src/auto/types.ts` — 扩展 `IFloorInfo`
`IFloorInfo` 接口的 `doorHeatmap` 字段之后追加以下字段:
```typescript
// ── 结构标签(新增)──────────────────────────────
/** 左右对称(基于 convertedMap 完全匹配) */
readonly symmetryH: boolean;
/** 上下对称 */
readonly symmetryV: boolean;
/** 中心对称 */
readonly symmetryC: boolean;
/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90% */
readonly outerWall: boolean;
/** 房间数量原始值(供 Python 两趟扫描使用) */
readonly roomCount: number;
/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */
readonly highDegBranchCount: number;
```
### `data/src/types.ts` — 扩展 `GinkaTrainData`
`GinkaTrainData` 接口中追加序列化后写入 JSON 的字段:
```typescript
export interface GinkaTrainData {
tag?: number[];
val: number[];
map: number[][];
size: [number, number];
heatmap?: number[][][];
// ── 结构标签(新增)──────────────────────────────
/** 对称性:[symmetryH, symmetryV, symmetryC]0 或 1 */
symmetry: [number, number, number];
/** 是否外包围墙壁0 或 1 */
outerWall: number;
/** 房间数量原始值 */
roomCount: number;
/** 高连接度分支节点数量原始值 */
highDegBranchCount: number;
}
```
---
## 第二步:实现 Helper 函数(`data/src/auto/info.ts`
以下四个函数添加到 `computeWallDensityStd` 之后、`parseFloorInfo` 之前。
### 2.1 对称性计算
```typescript
/**
* 计算地图的三种对称性(基于 convertedMap完全匹配才标记为 true
*/
function computeSymmetry(map: number[][]): {
symmetryH: boolean;
symmetryV: boolean;
symmetryC: boolean;
} {
const H = map.length;
const W = H > 0 ? map[0].length : 0;
let symmetryH = true;
let symmetryV = true;
let symmetryC = true;
outer: for (let y = 0; y < H; y++) {
for (let x = 0; x < W; x++) {
const tile = map[y][x];
if (symmetryH && tile !== map[y][W - 1 - x]) symmetryH = false;
if (symmetryV && tile !== map[H - 1 - y][x]) symmetryV = false;
if (symmetryC && tile !== map[H - 1 - y][W - 1 - x])
symmetryC = false;
// 三种对称均已排除,提前退出
if (!symmetryH && !symmetryV && !symmetryC) break outer;
}
}
return { symmetryH, symmetryV, symmetryC };
}
```
**复杂度**O(H × W),最坏情况遍历全图一次,实际有短路优化。
**注意**:对于奇数尺寸地图(如 13×13中心行/列的格子与自身比较恒为 true不影响结果。
### 2.2 外包围墙壁检测
```typescript
/**
* 检测地图最外圈中墙壁 + 入口的占比是否超过 90%
* @param map convertedMap
* @param wall 墙壁图块编号
* @param entry 入口图块编号
*/
function computeOuterWall(
map: number[][],
wall: number,
entry: number
): boolean {
const H = map.length;
const W = H > 0 ? map[0].length : 0;
if (H < 2 || W < 2) return false;
let borderCount = 0;
let wallOrEntry = 0;
const check = (tile: number) => {
borderCount++;
if (tile === wall || tile === entry) wallOrEntry++;
};
// 顶行 + 底行
for (let x = 0; x < W; x++) {
check(map[0][x]);
check(map[H - 1][x]);
}
// 左列 + 右列(排除角格,已由上面计入)
for (let y = 1; y < H - 1; y++) {
check(map[y][0]);
check(map[y][W - 1]);
}
// 13×13 地图: borderCount = 2*(13+13)-4 = 48
return borderCount > 0 && wallOrEntry / borderCount > 0.9;
}
```
### 2.3 房间数量统计
**核心思路**:拓扑图中 Empty 节点和 Resource 节点在物理上可能彼此相邻,共同构成一个连续的游戏空间(如 2×3 的房间内放了一个宝物)。不能逐节点独立判断,而应先通过 BFS 将相邻的 Empty/Resource 节点合并为一个"候选区域"再对整个区域判断三个条件。Branch 节点作为边界,不会被合并进区域。
```typescript
/**
* 统计拓扑图中符合"房间"定义的连通区域数量。
*
* 算法:
* 1. 以 Empty / Resource 节点为顶点,在它们之间 BFS
* 得到若干"候选区域"Branch 节点作为边界,不被合并)。
* 2. 对每个候选区域检查三个条件:
* a. 区域内至少一个节点有 Branch 类型邻居
* b. 区域内所有格子总数 >= 4
* c. 所有格子的外接矩形宽 > 1 且高 > 1
*
* @param graph 拓扑图
* @param width 地图宽度(用于平坦坐标解码)
*/
function computeRoomCount(graph: IMapGraph, width: number): number {
// Step 1: 收集所有 Empty 和 Resource 节点(去重)
const allEmptyResource = new Set<MapGraphNode>();
for (const node of graph.nodeMap.values()) {
if (
node.type === GraphNodeType.Empty ||
node.type === GraphNodeType.Resource
) {
allEmptyResource.add(node);
}
}
// Step 2: BFS 将相邻的 Empty/Resource 节点合并为连通区域
let roomCount = 0;
const visited = new Set<MapGraphNode>();
for (const startNode of allEmptyResource) {
if (visited.has(startNode)) continue;
// BFS只在 Empty/Resource 节点间传播Branch 节点阻断
const regionNodes = new Set<MapGraphNode>();
const queue: MapGraphNode[] = [startNode];
visited.add(startNode);
while (queue.length > 0) {
const current = queue.shift()!;
regionNodes.add(current);
for (const nb of current.neighbors) {
if (
!visited.has(nb) &&
(nb.type === GraphNodeType.Empty ||
nb.type === GraphNodeType.Resource)
) {
visited.add(nb);
queue.push(nb);
}
}
}
// Step 3: 对合并后的区域检查三个条件
// 条件 a区域内任一节点有 Branch 邻居
let hasBranch = false;
outer: for (const node of regionNodes) {
for (const nb of node.neighbors) {
if (nb.type === GraphNodeType.Branch) {
hasBranch = true;
break outer;
}
}
}
if (!hasBranch) continue;
// 收集区域内所有格子,计算总数和外接矩形
let totalTiles = 0;
let minX = Infinity,
maxX = -Infinity,
minY = Infinity,
maxY = -Infinity;
for (const node of regionNodes) {
totalTiles += node.tiles.size;
for (const t of node.tiles) {
const x = t % width;
const y = (t - x) / width;
if (x < minX) minX = x;
if (x > maxX) maxX = x;
if (y < minY) minY = y;
if (y > maxY) maxY = y;
}
}
// 条件 b总格子数 >= 4
if (totalTiles < 4) continue;
// 条件 c外接矩形宽高均 > 1即 maxX - minX >= 1 && maxY - minY >= 1
if (maxX - minX < 1 || maxY - minY < 1) continue;
roomCount++;
}
return roomCount;
}
```
**边界情况讨论**
- **混合节点房间**2×3 房间5 个空地 + 1 个资源)→ 1 个 Empty 节点 + 1 个 Resource 节点BFS 合并后总格子数=6外接矩形 2×3**计入房间**。
- **纯条形走廊**1×4 的 Empty 节点 → 外接矩形高=1**不计入房间**。
- **孤立资源死角**Resource 节点的邻居全是 Entry 而无 Branch → 条件 a 不满足,**不计入房间**。
- **两个相邻的 Empty 区域被 Branch 隔开**BFS 不会跨越 Branch各自单独判断。
### 2.4 高连接度分支节点统计
```typescript
/**
* 统计邻居数 >= 3 的分支节点数量
*
* 由于拓扑图中已不含墙节点neighbors.size 即等于非墙邻居数。
*
* @param graph 拓扑图
*/
function computeHighDegBranchCount(graph: IMapGraph): number {
let count = 0;
const visited = new Set<MapGraphNode>();
for (const node of graph.nodeMap.values()) {
if (visited.has(node)) continue;
visited.add(node);
if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) {
count++;
}
}
return count;
}
```
---
## 第三步:在 `parseFloorInfo` 中调用
`parseFloorInfo` 函数内,`topo` 构建完毕、`floorInfo` 对象构造之前,调用上述四个函数:
```typescript
// ── 结构标签计算 ─────────────────────────────────
const width = map[0]?.length ?? 0;
const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map);
const outerWall = computeOuterWall(
map,
config.classes.wall,
config.classes.entry
);
const roomCount = computeRoomCount(topo.graph, width);
const highDegBranchCount = computeHighDegBranchCount(topo.graph);
```
然后在 `floorInfo` 字面量中追加这些字段:
```typescript
const floorInfo: IFloorInfo = {
// ...(原有字段保持不变)
symmetryH,
symmetryV,
symmetryC,
outerWall,
roomCount,
highDegBranchCount
};
```
---
## 第四步:在 `data/src/auto.ts` 中序列化新字段
在构建 `GinkaTrainData` 的对象字面量中追加:
```typescript
const data: GinkaTrainData = {
map: floor.data.map,
size: [width, height],
heatmap: [
/* ...原有热力图通道... */
],
val: [
/* ...原有标量... */
],
// ── 新增结构标签 ──────────────────────────────
symmetry: [
info.symmetryH ? 1 : 0,
info.symmetryV ? 1 : 0,
info.symmetryC ? 1 : 0
],
outerWall: info.outerWall ? 1 : 0,
roomCount: info.roomCount,
highDegBranchCount: info.highDegBranchCount
};
```
布尔值存为 `0/1` 整数,便于 Python 侧直接读取,不需要类型转换。
---
## 第五步Python 侧两趟扫描(`ginka/dataset.py`
### 5.1 等频分箱函数
`dataset.py` 顶部添加通用的等频分箱 helper
```python
def assign_level(values: list[int]) -> list[int]:
"""
将整数列表按等频分箱映射为 0/1/2 三档等级。
分位数阈值基于当前列表计算(训练集与验证集应共用训练集的阈值)。
Args:
values: 原始统计值列表,顺序与数据集一一对应
Returns:
与输入等长的等级列表,每项为 0 / 1 / 2
"""
n = len(values)
if n == 0:
return []
sorted_vals = sorted(values)
th1 = sorted_vals[n // 3]
th2 = sorted_vals[2 * n // 3]
return [
0 if v < th1 else (1 if v < th2 else 2)
for v in values
]
```
### 5.2 修改 `GinkaVQDataset.__init__`
在加载数据之后立即进行两趟扫描,并将等级结果回填到各 item 中:
```python
class GinkaVQDataset(Dataset):
def __init__(
self,
data_path: str,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_min: float = 0.0,
wall_mask_max: float = 0.5,
# 以下两个参数仅在 train 模式下使用,验证集传入训练集的阈值即可
room_thresholds: tuple[int, int] | None = None,
branch_thresholds: tuple[int, int] | None = None,
):
self.data = load_data(data_path)
# ...(原有初始化逻辑)
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
room_counts = [item['roomCount'] for item in self.data]
branch_counts = [item['highDegBranchCount'] for item in self.data]
if room_thresholds is None:
# 训练集:自行计算阈值
n = len(room_counts)
rs = sorted(room_counts)
bs = sorted(branch_counts)
self.room_th = (rs[n // 3], rs[2 * n // 3])
self.branch_th = (bs[n // 3], bs[2 * n // 3])
else:
# 验证集/测试集:直接使用训练集的阈值,避免数据泄露
self.room_th = room_thresholds
self.branch_th = branch_thresholds
def to_level(v: int, th: tuple[int, int]) -> int:
return 0 if v < th[0] else (1 if v < th[1] else 2)
# 回填等级字段
for item in self.data:
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
```
**调用方式**`train_vq.py` 中):
```python
dataset_train = GinkaVQDataset(args.train)
dataset_val = GinkaVQDataset(
args.validate,
room_thresholds=dataset_train.room_th,
branch_thresholds=dataset_train.branch_th
)
```
### 5.3 `__getitem__` 中读取结构标签
对称性标签在数据增强(旋转/翻转)后需要**重新从增强后的地图中计算**,因为 `rot90(k=1/3)` 会交换 `symmetryH``symmetryV`。其他三个标签在旋转/翻转下保持不变,可直接读取。
```python
def _compute_symmetry(target_np: np.ndarray) -> tuple[int, int, int]:
"""从 numpy 地图矩阵中直接计算三种对称性O(H*W)"""
H, W = target_np.shape
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
sym_v = bool(np.all(target_np == target_np[::-1, :]))
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
return int(sym_h), int(sym_v), int(sym_c)
```
`__getitem__` 数据增强完成后,读取所有标签:
```python
def __getitem__(self, idx):
# ...原有增强逻辑target_np 已经过 rot90 / flip
# 对称性:在增强后重新计算
sym_h, sym_v, sym_c = _compute_symmetry(target_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
# 其余标签:增强不改变拓扑结构,直接读取
item = self.data[idx]
cond_room = item['roomCountLevel'] # 0/1/2
cond_branch = item['branchLevel'] # 0/1/2
cond_outer = item['outerWall'] # 0/1
# 封装为 tensor
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
return {
"raw_map": ...,
"masked_map": ...,
"target_map": ...,
"subset": ...,
"struct_cond": struct_cond # [4],供模型 Embedding 查表
}
```
---
## 第六步:模型侧条件注入(`ginka/maskGIT/model.py`
`struct_cond` 的四个维度分别对应不同词表大小的 Embedding
```python
# 词表大小
SYM_VOCAB = 8 # cond_sym: [0, 7]
ROOM_VOCAB = 4 # cond_room: [0, 2] + 1 个 nulldropout 用)
BRANCH_VOCAB = 4 # cond_branch: [0, 2] + 1 个 null
OUTER_VOCAB = 3 # cond_outer: [0, 1] + 1 个 null
# 在 GinkaMaskGIT.__init__ 中
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
```
`forward` 中,将四个 Embedding 结果与 VQ-VAE 的 z 序列拼接,作为 Cross-Attention 的 memory
```python
def forward(self, map_tokens, z, struct_cond, dropout_struct=False):
# z: [B, L, d_z]
# struct_cond: [B, 4] — [sym, room, branch, outer]
B = z.size(0)
if dropout_struct or (self.training and torch.rand(1) < self.struct_dropout_prob):
# 条件 dropout全部替换为 null index各词表最后一个 index
e_sym = self.sym_embed(torch.full((B,), SYM_VOCAB - 1, device=z.device))
e_room = self.room_embed(torch.full((B,), ROOM_VOCAB - 1, device=z.device))
e_branch = self.branch_embed(torch.full((B,), BRANCH_VOCAB - 1, device=z.device))
e_outer = self.outer_embed(torch.full((B,), OUTER_VOCAB - 1, device=z.device))
else:
e_sym = self.sym_embed(struct_cond[:, 0]) # [B, d_z]
e_room = self.room_embed(struct_cond[:, 1])
e_branch = self.branch_embed(struct_cond[:, 2])
e_outer = self.outer_embed(struct_cond[:, 3])
# 将四个结构标签嵌入拼接为序列,与 z 合并
# 每个 e_* 形状为 [B, d_z]unsqueeze 后变为 [B, 1, d_z]
struct_seq = torch.stack(
[e_sym, e_room, e_branch, e_outer], dim=1
) # [B, 4, d_z]
memory = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
# 后续 Cross-Attention 正常进行
# query = map_token_embeddingskey/value = memory
...
```
**null index 规则**:各 Embedding 的最后一个 index 保留为"无条件"占位符,词表大小因此比有效类别数多 1。
---
## 边界情况与注意事项
### 关于两趟扫描的时机
两趟扫描在 `Dataset.__init__` 中完成,整个过程仅需遍历两次 Python 列表,耗时可忽略不计(数千条数据 < 1ms)。不建议延迟到 `__getitem__` 中逐条计算
### 关于对称性在数据增强下的变化
| 增强操作 | `symmetryH` | `symmetryV` | `symmetryC` |
| ----------------- | ----------- | ----------- | ----------- |
| `fliplr` | 不变 | 不变 | 不变 |
| `flipud` | 不变 | 不变 | 不变 |
| `rot90(k=1 or 3)` | 与 V 交换 | 与 H 交换 | 不变 |
| `rot90(k=2)` | 不变 | 不变 | 不变 |
因此在 `__getitem__` 中对增强后的地图**重新计算**对称性是最简洁正确的方案,无需记录增强历史。
### 关于极端分布
若训练集中某一标签的样本极度不均匀(如 90% 的地图无对称性),可以在条件 Dropout 中对不常见的条件值适当降低 dropout 概率,以确保模型充分学习该条件。初始阶段统一使用相同 dropout 概率即可,后续根据生成效果调整。
### 关于 `roomCount = 0``highDegBranchCount = 0` 的地图
这类地图在等频分箱后会进入 Low0等级。如果训练集中大量地图的值为 0`th1` 可能也为 0导致 Low 等级极少。可以在分箱前加一步检查:若 `th1 == th2`,则手动将 `th2 = th1 + 1` 以避免 Medium 等级为空。
```python
th1 = sorted_vals[n // 3]
th2 = sorted_vals[2 * n // 3]
if th1 == th2:
th2 = th1 + 1 # 防止 Medium 等级为空
```

View File

@ -376,7 +376,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm
## 待探索事项
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验)
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验)K=1, L=64 可能比较合适。
- z dropout 的最优概率
- 若后续 codebook collapse引入 EMA 更新 + 重置机制
- 若后续需要更细粒度控制:加入标量 cond需对推理侧标量做随机扰动处理

View File

@ -232,6 +232,14 @@ class GinkaJointDataset(Dataset):
}
def _compute_symmetry(target_np: np.ndarray) -> tuple:
"""从 numpy 地图矩阵中直接计算三种对称性O(H*W)"""
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
sym_v = bool(np.all(target_np == target_np[::-1, :]))
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
return int(sym_h), int(sym_v), int(sym_c)
class GinkaVQDataset(Dataset):
"""
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集
@ -259,13 +267,17 @@ class GinkaVQDataset(Dataset):
data_path: str,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_ratio: float = 0.3,
room_thresholds: tuple = None,
branch_thresholds: tuple = None,
):
"""
Args:
data_path: JSON 数据文件路径
subset_weights: 子集 (A, B, C, D) 的采样权重自动归一化
wall_mask_ratio: Subset C 中额外随机 mask wall tile 比例上限
每次从 [0, wall_mask_ratio] 均匀采样实际比例
data_path: JSON 数据文件路径
subset_weights: 子集 (A, B, C, D) 的采样权重自动归一化
wall_mask_ratio: Subset C 中额外随机 mask wall tile 比例上限
每次从 [0, wall_mask_ratio] 均匀采样实际比例
room_thresholds: (th1, th2) 房间数量等频分箱阈值 None 时自动从当前数据计算训练集
branch_thresholds: (th1, th2) 分支数量等频分箱阈值 None 时自动从当前数据计算训练集
"""
self.data = load_data(data_path)
self.wall_mask_ratio = wall_mask_ratio
@ -275,6 +287,35 @@ class GinkaVQDataset(Dataset):
normalized = [x / total_w for x in subset_weights]
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
room_counts = [item['roomCount'] for item in self.data]
branch_counts = [item['highDegBranchCount'] for item in self.data]
if room_thresholds is None:
n = len(room_counts)
rs = sorted(room_counts)
bs = sorted(branch_counts)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
# 防止 Medium 等级为空
if th1_r == th2_r:
th2_r = th1_r + 1
if th1_b == th2_b:
th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
else:
self.room_th = room_thresholds
self.branch_th = branch_thresholds
def to_level(v: int, th: tuple) -> int:
return 0 if v < th[0] else (1 if v < th[1] else 2)
# 回填等级字段
for item in self.data:
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
def __len__(self):
return len(self.data)
@ -407,11 +448,23 @@ class GinkaVQDataset(Dataset):
masked_np = self._apply_subset(raw_np, subset) # [H*W]
raw_flat = raw_np.reshape(-1) # [H*W]
# 对称性:在增强后重新计算
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
# 其余结构标签:增强不改变拓扑结构,直接读取
cond_room = item['roomCountLevel'] # 0/1/2
cond_branch = item['branchLevel'] # 0/1/2
cond_outer = item['outerWall'] # 0/1
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
return {
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
"subset": subset, # 调试/统计用
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
"subset": subset, # 调试/统计用
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
}

View File

@ -4,6 +4,12 @@ import torch.nn as nn
from ..utils import print_memory
from .maskGIT import Transformer
# 结构标签词表大小(最后一个索引为无条件占位符 null
SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-67 = null
ROOM_VOCAB = 4 # roomCountLevel 0-23 = null
BRANCH_VOCAB = 4 # branchLevel 0-23 = null
OUTER_VOCAB = 3 # outerWall 0-12 = null
class GinkaMaskGIT(nn.Module):
"""
@ -26,20 +32,24 @@ class GinkaMaskGIT(nn.Module):
num_layers: int = 4,
map_size: int = 13 * 13,
z_dropout: float = 0.1,
struct_dropout: float = 0.15,
):
"""
Args:
num_classes: tile 类别数 MASK token=15
d_model: Transformer 内部维度
d_z: VQ-VAE 码字嵌入维度需与 GinkaVQVAE.d_z 一致
dim_ff: 前馈网络隐层维度
nhead: 注意力头数
num_layers: Transformer 层数
map_size: 地图 token 总数H * W
z_dropout: 训练时随机替换 z 为随机码字的概率提升鲁棒性
num_classes: tile 类别数 MASK token=15
d_model: Transformer 内部维度
d_z: VQ-VAE 码字嵌入维度需与 GinkaVQVAE.d_z 一致
dim_ff: 前馈网络隐层维度
nhead: 注意力头数
num_layers: Transformer 层数
map_size: 地图 token 总数H * W
z_dropout: 训练时随机替换 z 为随机码字的概率提升鲁棒性
struct_dropout: 训练时以此概率将结构标签替换为 null无条件占位
实现 classifier-free guidance 兼容训练
"""
super().__init__()
self.z_dropout = z_dropout
self.struct_dropout_prob = struct_dropout
# Tile 嵌入 + 位置编码
self.tile_embedding = nn.Embedding(num_classes, d_model)
@ -51,6 +61,12 @@ class GinkaMaskGIT(nn.Module):
nn.LayerNorm(d_model),
)
# 结构标签嵌入(编码到 d_z 维度,与 z 拼接后统一投影到 d_model
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
# Transformerencoder 做 map token 自注意力decoder 做与 z 的 cross-attention
self.transformer = Transformer(
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
@ -58,33 +74,75 @@ class GinkaMaskGIT(nn.Module):
self.output_fc = nn.Linear(d_model, num_classes)
def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
def forward(
self,
map: torch.Tensor,
z: torch.Tensor,
struct_cond: torch.Tensor | None = None,
dropout_struct: bool = False,
) -> torch.Tensor:
"""
Args:
map: [B, H*W] 掩码后的地图 token 序列MASK token = 15
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
map: [B, H*W] 掩码后的地图 token 序列MASK token = 15
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
struct_cond: [B, 4] 结构标签 LongTensor顺序为
[cond_sym, cond_room, cond_branch, cond_outer]
None 时等价于全 null无条件模式
dropout_struct: bool 强制将所有结构标签替换为 null推理时无条件生成
Returns:
logits: [B, H*W, num_classes]
"""
B = z.shape[0]
# z dropout训练时以一定概率将 z 替换为随机均匀噪声,
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
if self.training and self.z_dropout > 0:
mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout
mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout
rand_z = torch.randn_like(z)
z = torch.where(mask, rand_z, z)
# 投影 z 到 d_model 维度
z_mem = self.z_proj(z) # [B, L, d_model]
# 结构标签嵌入
# struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引
if struct_cond is None or dropout_struct:
sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device)
room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device)
branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device)
outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device)
else:
sc = struct_cond.to(z.device)
sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3]
# 训练时对各标签独立做 struct dropout
if self.training and self.struct_dropout_prob > 0:
def _drop(idx, null_val):
drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob
return torch.where(drop_mask, torch.full_like(idx, null_val), idx)
sym_idx = _drop(sym_idx, SYM_VOCAB - 1)
room_idx = _drop(room_idx, ROOM_VOCAB - 1)
branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1)
outer_idx = _drop(outer_idx, OUTER_VOCAB - 1)
# 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z]
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
z_ext = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
# 统一投影到 d_model 维度
z_mem = self.z_proj(z_ext) # [B, L+4, d_model]
# tile embedding + 位置编码
x = self.tile_embedding(map) # [B, H*W, d_model]
x = x + self.pos_embedding # [B, H*W, d_model]
x = self.tile_embedding(map) # [B, H*W, d_model]
x = x + self.pos_embedding # [B, H*W, d_model]
# Transformerencoder 做 map 自注意力decoder cross-attend z
# Transformerencoder 做 map 自注意力decoder cross-attend z+struct
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
logits = self.output_fc(x) # [B, H*W, num_classes]
logits = self.output_fc(x) # [B, H*W, num_classes]
return logits
@ -107,11 +165,19 @@ if __name__ == "__main__":
n = sum(p.numel() for p in module.parameters())
print(f" {name}: {n:,}")
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
struct_input = torch.tensor([[3, 1, 0, 1],
[0, 2, 1, 0],
[7, 3, 3, 2],
[1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4]
model.train()
logits = model(map_input, z_input)
logits = model(map_input, z_input, struct_cond=struct_input)
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
# 无条件模式测试
logits_uncond = model(map_input, z_input, struct_cond=None)
print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16]
print_memory(device, "前向传播后")

View File

@ -59,7 +59,8 @@ MG_D_MODEL = 192
MG_NHEAD = 8
MG_LAYERS = 4
MG_DIM_FF = 512
MG_Z_DROPOUT= 0.15 # 训练时以此概率把 z 替换为随机噪声
MG_Z_DROPOUT = 0.15 # 训练时以此概率把 z 替换为随机噪声
MG_STRUCT_DROPOUT= 0.15 # 训练时以此概率将结构标签替换为 null无条件占位
# 验证时对每条样本额外采样的 z 数量0 = 只用真实 z
N_Z_SAMPLES = 3
@ -106,6 +107,7 @@ def maskgit_generate(
z: torch.Tensor,
steps: int = GENERATE_STEP,
init_map: torch.Tensor = None,
struct_cond: torch.Tensor | None = None,
) -> torch.Tensor:
"""
迭代生成地图cosine schedule unmasking
@ -133,7 +135,7 @@ def maskgit_generate(
if not generatable.any():
break
logits = model_mg(map_seq, z) # [B, S, C]
logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C]
probs = F.softmax(logits, dim=-1)
dist = torch.distributions.Categorical(probs)
sampled = dist.sample() # [B, S]
@ -279,7 +281,8 @@ def validate(
B = raw_map.shape[0]
z_q, _, vq_loss, _, _ = model_vq(raw_map)
logits = model_mg(masked_map, z_q)
struct_cond_b = batch["struct_cond"].to(device) # [B, 4]
logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
mask = (masked_map == MASK_TOKEN)
ce_loss = F.cross_entropy(
@ -294,9 +297,10 @@ def validate(
s = subsets[i]
if captured[s] is None:
captured[s] = {
"raw": raw_map[i:i+1].clone(),
"masked": masked_map[i:i+1].clone(),
"z_q": z_q[i:i+1].clone(),
"raw": raw_map[i:i+1].clone(),
"masked": masked_map[i:i+1].clone(),
"z_q": z_q[i:i+1].clone(),
"struct_cond": struct_cond_b[i:i+1].clone(),
}
if all(v is not None for v in captured.values()):
@ -307,24 +311,24 @@ def validate(
imgs = []
for i in range(n):
z_r = model_vq.sample(1, device)
gen = maskgit_generate(model_mg, z_r, init_map=cond_map)
gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None 无条件
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
return imgs
# ── 场景1标准掩码补全子集 A─────────────────────────────────────────
if captured['A'] is not None:
cap = captured['A']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input")
# 单步 argmax 预测(观察模型对掩码位置的瞬时判断)
pred = model_mg(cond, z_q).argmax(dim=-1)[0]
pred = model_mg(cond, z_q, struct_cond=sc).argmax(dim=-1)[0]
pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred")
# 迭代生成(从掩码输入出发,真实 z
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
@ -333,11 +337,11 @@ def validate(
# ── 场景2墙壁辅助生成子集 B─────────────────────────────────────────
if captured['B'] is not None:
cap = captured['B']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input")
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
@ -346,11 +350,11 @@ def validate(
# ── 场景3稀疏墙壁条件生成子集 C────────────────────────────────────
if captured['C'] is not None:
cap = captured['C']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input")
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
@ -359,11 +363,11 @@ def validate(
# ── 场景4墙壁+入口条件生成(子集 D───────────────────────────────────
if captured['D'] is not None:
cap = captured['D']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input")
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
@ -403,6 +407,7 @@ def train():
num_layers=MG_LAYERS,
map_size=MAP_SIZE,
z_dropout=MG_Z_DROPOUT,
struct_dropout=MG_STRUCT_DROPOUT,
).to(device)
vq_params = sum(p.numel() for p in model_vq.parameters())
@ -419,6 +424,8 @@ def train():
dataset_val = GinkaVQDataset(
args.validate,
subset_weights=SUBSET_WEIGHTS,
room_thresholds=dataset_train.room_th,
branch_thresholds=dataset_train.branch_th,
)
dataloader_train = DataLoader(
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
@ -481,8 +488,9 @@ def train():
# 1. VQ-VAE 编码真实地图 → z_q
z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
# 2. MaskGIT 以掩码地图 + z 预测原始 tile
logits = model_mg(masked_map, z_q) # [B, 169, C]
# 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile
struct_cond = batch["struct_cond"].to(device) # [B, 4]
logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C]
# 3. 只对被 mask 的位置计算 CE loss
mask = (masked_map == MASK_TOKEN) # [B, 169] bool