mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
feat: 添加少数结构性标签
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
21315a6cb0
commit
abbad781ab
@ -361,7 +361,15 @@ const labelConfig: IAutoLabelConfig = {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0
|
0
|
||||||
]
|
],
|
||||||
|
symmetry: [
|
||||||
|
info.symmetryH ? 1 : 0,
|
||||||
|
info.symmetryV ? 1 : 0,
|
||||||
|
info.symmetryC ? 1 : 0
|
||||||
|
],
|
||||||
|
outerWall: info.outerWall ? 1 : 0,
|
||||||
|
roomCount: info.roomCount,
|
||||||
|
highDegBranchCount: info.highDegBranchCount
|
||||||
};
|
};
|
||||||
dataset.data[id] = data;
|
dataset.data[id] = data;
|
||||||
});
|
});
|
||||||
|
|||||||
@ -3,8 +3,10 @@ import {
|
|||||||
GraphNodeType,
|
GraphNodeType,
|
||||||
IAutoLabelConfig,
|
IAutoLabelConfig,
|
||||||
IFloorInfo,
|
IFloorInfo,
|
||||||
|
IMapGraph,
|
||||||
IMapTileConverter,
|
IMapTileConverter,
|
||||||
ITowerInfo,
|
ITowerInfo,
|
||||||
|
MapGraphNode,
|
||||||
TowerColor
|
TowerColor
|
||||||
} from './types';
|
} from './types';
|
||||||
import {
|
import {
|
||||||
@ -153,6 +155,180 @@ export function computeWallDensityStd(
|
|||||||
return Math.sqrt(variance);
|
return Math.sqrt(variance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Detects the three symmetries of a tile matrix by exact cell comparison.
 *
 * A symmetry flag stays `true` only if every cell equals its mirrored
 * counterpart; a single mismatch clears the flag for good.
 *
 * - `symmetryH`: left/right mirror, `map[y][x] === map[y][W - 1 - x]`
 * - `symmetryV`: top/bottom mirror, `map[y][x] === map[H - 1 - y][x]`
 * - `symmetryC`: point symmetry, `map[y][x] === map[H - 1 - y][W - 1 - x]`
 *
 * The width is taken from the first row. An empty matrix has no cell that
 * could mismatch, so all three flags are reported as `true`.
 *
 * @param map Tile matrix indexed as `map[y][x]`.
 * @returns The three symmetry flags.
 */
function computeSymmetry(map: number[][]): {
    symmetryH: boolean;
    symmetryV: boolean;
    symmetryC: boolean;
} {
    const H = map.length;
    const W = H > 0 ? map[0].length : 0;
    let symmetryH = true;
    let symmetryV = true;
    let symmetryC = true;

    outer: for (let y = 0; y < H; y++) {
        const row = map[y];
        // Row that mirrors `row` across the horizontal centre line.
        const mirroredRow = map[H - 1 - y];
        for (let x = 0; x < W; x++) {
            const tile = row[x];
            const mirroredX = W - 1 - x;
            if (symmetryH && tile !== row[mirroredX]) symmetryH = false;
            if (symmetryV && tile !== mirroredRow[x]) symmetryV = false;
            if (symmetryC && tile !== mirroredRow[mirroredX]) {
                symmetryC = false;
            }
            // Nothing left to disprove: stop scanning early.
            if (!symmetryH && !symmetryV && !symmetryC) break outer;
        }
    }

    return { symmetryH, symmetryV, symmetryC };
}
|
||||||
|
|
||||||
|
/**
 * Tests whether the outermost ring of the map is (almost) fully enclosed.
 *
 * A border cell counts as enclosing when it holds the wall tile or the entry
 * tile. The map qualifies when the share of enclosing cells among all border
 * cells is strictly greater than `threshold`.
 *
 * @param map Tile matrix indexed as `map[y][x]`; the width is taken from the
 * first row.
 * @param wall Tile id that represents a wall.
 * @param entry Tile id that represents an entry.
 * @param threshold Ratio that must be exceeded (exclusive). Defaults to 0.9.
 * @returns `true` when the enclosing ratio exceeds `threshold`; always
 * `false` for maps with fewer than 2 rows or fewer than 2 columns.
 */
function computeOuterWall(
    map: number[][],
    wall: number,
    entry: number,
    threshold = 0.9
): boolean {
    const H = map.length;
    const W = H > 0 ? map[0].length : 0;
    // With fewer than 2 rows/columns the two passes below would visit the
    // same cells twice, so such maps are rejected up front.
    if (H < 2 || W < 2) return false;

    let borderCount = 0;
    let wallOrEntry = 0;
    const check = (tile: number): void => {
        borderCount++;
        if (tile === wall || tile === entry) wallOrEntry++;
    };

    // Top and bottom rows, corners included.
    for (let x = 0; x < W; x++) {
        check(map[0][x]);
        check(map[H - 1][x]);
    }
    // Left and right columns, skipping the corners counted above.
    for (let y = 1; y < H - 1; y++) {
        check(map[y][0]);
        check(map[y][W - 1]);
    }

    // borderCount is 2 * (H + W) - 4 >= 4 here, so the division is safe.
    return wallOrEntry / borderCount > threshold;
}
|
||||||
|
|
||||||
|
/**
 * Counts the connected regions of the topology graph that qualify as rooms.
 *
 * Algorithm:
 *   1. Flood-fill across Empty / Resource nodes only. Nodes of any other
 *      type (e.g. Branch) are never entered, so they act as region borders.
 *   2. A region is counted as a room when all of the following hold:
 *      a. at least one node of the region has a Branch neighbour;
 *      b. the region covers at least 4 tiles in total;
 *      c. the bounding box of all its tiles is wider and taller than 1.
 *
 * @param graph Topology graph; every distinct node found in `graph.nodeMap`
 * is considered.
 * @param width Map width, used to decode flat tile indices (`y * width + x`).
 * @returns Number of regions that satisfy conditions a-c.
 */
function computeRoomCount(graph: IMapGraph, width: number): number {
    const isOpen = (node: MapGraphNode): boolean =>
        node.type === GraphNodeType.Empty ||
        node.type === GraphNodeType.Resource;

    const visited = new Set<MapGraphNode>();
    let roomCount = 0;

    for (const startNode of graph.nodeMap.values()) {
        if (!isOpen(startNode) || visited.has(startNode)) continue;

        // Breadth-first flood fill. `region` is both the work queue and the
        // final member list; advancing an index avoids the O(n) cost of
        // Array.prototype.shift on every dequeue.
        const region: MapGraphNode[] = [startNode];
        visited.add(startNode);
        for (let head = 0; head < region.length; head++) {
            for (const nb of region[head].neighbors) {
                if (isOpen(nb) && !visited.has(nb)) {
                    visited.add(nb);
                    region.push(nb);
                }
            }
        }

        // Condition a: some member touches a Branch node.
        let hasBranch = false;
        scan: for (const node of region) {
            for (const nb of node.neighbors) {
                if (nb.type === GraphNodeType.Branch) {
                    hasBranch = true;
                    break scan;
                }
            }
        }
        if (!hasBranch) continue;

        // Gather the tile count and the bounding box of the whole region.
        let totalTiles = 0;
        let minX = Infinity;
        let maxX = -Infinity;
        let minY = Infinity;
        let maxY = -Infinity;
        for (const node of region) {
            totalTiles += node.tiles.size;
            for (const t of node.tiles) {
                const x = t % width;
                const y = (t - x) / width;
                if (x < minX) minX = x;
                if (x > maxX) maxX = x;
                if (y < minY) minY = y;
                if (y > maxY) maxY = y;
            }
        }

        // Condition b: at least 4 tiles.
        if (totalTiles < 4) continue;

        // Condition c: bounding box spans more than one column and one row.
        if (maxX - minX < 1 || maxY - minY < 1) continue;

        roomCount++;
    }

    return roomCount;
}
|
||||||
|
|
||||||
|
/**
 * Counts the Branch nodes that have at least 3 neighbours.
 *
 * Nodes are de-duplicated by identity first, so a node that appears under
 * several keys of `graph.nodeMap` is counted only once.
 *
 * @param graph Topology graph whose `nodeMap` values are inspected.
 * @returns Number of distinct Branch nodes with `neighbors.size >= 3`.
 */
function computeHighDegBranchCount(graph: IMapGraph): number {
    const uniqueNodes = new Set<MapGraphNode>(graph.nodeMap.values());

    let count = 0;
    for (const node of uniqueNodes) {
        if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) {
            count++;
        }
    }
    return count;
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据地图矩阵解析出地图数据
|
* 根据地图矩阵解析出地图数据
|
||||||
* @param tower 地图所属塔信息
|
* @param tower 地图所属塔信息
|
||||||
@ -177,6 +353,17 @@ export function parseFloorInfo(
|
|||||||
);
|
);
|
||||||
const flattened = map.flat();
|
const flattened = map.flat();
|
||||||
const area = flattened.length;
|
const area = flattened.length;
|
||||||
|
const width = map[0]?.length ?? 0;
|
||||||
|
|
||||||
|
// ── 结构标签计算 ─────────────────────────────────
|
||||||
|
const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map);
|
||||||
|
const outerWall = computeOuterWall(
|
||||||
|
map,
|
||||||
|
config.classes.wall,
|
||||||
|
config.classes.entry
|
||||||
|
);
|
||||||
|
const roomCount = computeRoomCount(topo.graph, width);
|
||||||
|
const highDegBranchCount = computeHighDegBranchCount(topo.graph);
|
||||||
|
|
||||||
let hasUselessBranch = false;
|
let hasUselessBranch = false;
|
||||||
|
|
||||||
@ -283,7 +470,13 @@ export function parseFloorInfo(
|
|||||||
doorHeatmap: gaussainHeatmap(
|
doorHeatmap: gaussainHeatmap(
|
||||||
generateHeatmap(map, doorTiles, config.heatmapKernel),
|
generateHeatmap(map, doorTiles, config.heatmapKernel),
|
||||||
config.guassainRadius
|
config.guassainRadius
|
||||||
)
|
),
|
||||||
|
symmetryH,
|
||||||
|
symmetryV,
|
||||||
|
symmetryC,
|
||||||
|
outerWall,
|
||||||
|
roomCount,
|
||||||
|
highDegBranchCount
|
||||||
};
|
};
|
||||||
|
|
||||||
return floorInfo;
|
return floorInfo;
|
||||||
|
|||||||
@ -112,6 +112,20 @@ export interface IFloorInfo {
|
|||||||
readonly entryHeatmap: number[][];
|
readonly entryHeatmap: number[][];
|
||||||
/** 门热力图 */
|
/** 门热力图 */
|
||||||
readonly doorHeatmap: number[][];
|
readonly doorHeatmap: number[][];
|
||||||
|
|
||||||
|
// ── 结构标签(新增)──────────────────────────────
|
||||||
|
/** 左右对称(基于 convertedMap 完全匹配) */
|
||||||
|
readonly symmetryH: boolean;
|
||||||
|
/** 上下对称 */
|
||||||
|
readonly symmetryV: boolean;
|
||||||
|
/** 中心对称 */
|
||||||
|
readonly symmetryC: boolean;
|
||||||
|
/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90%) */
|
||||||
|
readonly outerWall: boolean;
|
||||||
|
/** 房间数量原始值(供 Python 两趟扫描使用) */
|
||||||
|
readonly roomCount: number;
|
||||||
|
/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */
|
||||||
|
readonly highDegBranchCount: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface IMapBlockConfig {
|
export interface IMapBlockConfig {
|
||||||
|
|||||||
@ -48,6 +48,15 @@ export interface GinkaTrainData {
|
|||||||
map: number[][];
|
map: number[][];
|
||||||
size: [number, number];
|
size: [number, number];
|
||||||
heatmap?: number[][][];
|
heatmap?: number[][][];
|
||||||
|
// ── 结构标签(新增)──────────────────────────────
|
||||||
|
/** 对称性:[symmetryH, symmetryV, symmetryC],0 或 1 */
|
||||||
|
symmetry: [number, number, number];
|
||||||
|
/** 是否外包围墙壁,0 或 1 */
|
||||||
|
outerWall: number;
|
||||||
|
/** 房间数量原始值 */
|
||||||
|
roomCount: number;
|
||||||
|
/** 高连接度分支节点数量原始值 */
|
||||||
|
highDegBranchCount: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface GinkaDataset {
|
export interface GinkaDataset {
|
||||||
|
|||||||
248
docs/dataset-labels-design.md
Normal file
248
docs/dataset-labels-design.md
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
# 数据集标签设计文档
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
在当前 VQ-VAE + MaskGIT 的联合训练方案中,VQ-VAE 的 codebook 承担着地图风格与多样性的控制职能,但缺少用户可感知的语义维度。为提升生成的可控性,计划在数据集中添加一组**可程序化标注的结构标签**,作为额外的条件输入参与训练,从而使模型能够接受来自用户的高层语义约束(如"生成一个左右对称的地图"、"高房间数量"等)。
|
||||||
|
|
||||||
|
所有标签均可在数据集构建阶段(`info.ts` / `parseFloorInfo`)或训练前的两趟扫描中自动计算,无需人工标注。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 标签一览
|
||||||
|
|
||||||
|
| 标签名 | 类型 | 取值 | 标注时机 |
|
||||||
|
| ---------------- | --------- | ------------------- | -------- |
|
||||||
|
| `symmetryH` | `boolean` | 左右对称 | 单张地图 |
|
||||||
|
| `symmetryV` | `boolean` | 上下对称 | 单张地图 |
|
||||||
|
| `symmetryC` | `boolean` | 中心对称 | 单张地图 |
|
||||||
|
| `roomCountLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 |
|
||||||
|
| `branchLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 |
|
||||||
|
| `outerWall` | `boolean` | 是否外包围墙壁 | 单张地图 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 标签一:对称性
|
||||||
|
|
||||||
|
### 定义
|
||||||
|
|
||||||
|
针对经过转换后的地图(`convertedMap`)矩阵逐格比较,仅当**完全满足条件**时才标记为对应对称类型。三种对称分别判定,可同时成立。
|
||||||
|
|
||||||
|
| 对称类型 | 条件(对所有 `x ∈ [0, W)`, `y ∈ [0, H)` 成立) |
|
||||||
|
| -------- | ---------------------------------------------- |
|
||||||
|
| 左右对称 | `map[y][x] === map[y][W - 1 - x]` |
|
||||||
|
| 上下对称 | `map[y][x] === map[H - 1 - y][x]` |
|
||||||
|
| 中心对称 | `map[y][x] === map[H - 1 - y][W - 1 - x]` |
|
||||||
|
|
||||||
|
### 实现要点
|
||||||
|
|
||||||
|
- 比较使用 `convertedMap`(标签化图块编号),而非原始 `originMap`,使不同塔的同类图块具有可比性。
|
||||||
|
- 三种对称需分别计算:单独的左右对称或上下对称并不蕴含中心对称,反之亦然;但三者中任意两种同时成立时,第三种必然成立(两次镜像的复合即为 180° 旋转)。
|
||||||
|
- 对于奇数尺寸的地图(如 13×13),中心行/列与自身比较必然成立,无需特殊处理。
|
||||||
|
|
||||||
|
### 字段扩展(`IFloorInfo`)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/** 左右对称 */
|
||||||
|
readonly symmetryH: boolean;
|
||||||
|
/** 上下对称 */
|
||||||
|
readonly symmetryV: boolean;
|
||||||
|
/** 中心对称 */
|
||||||
|
readonly symmetryC: boolean;
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 标签二:房间数量等级
|
||||||
|
|
||||||
|
### 定义
|
||||||
|
|
||||||
|
**房间(Room)** 是地图中由"空白节点"或"资源节点"组成的区域,同时满足以下三个条件:
|
||||||
|
|
||||||
|
1. **位置条件**:区域内至少有一个节点在拓扑图中与 **至少 1 个分支节点(Branch)** 相邻;
|
||||||
|
2. **面积条件**:区域内所有节点所包含的地图格子总数(各节点 `tiles.size` 之和)**≥ 4**;
|
||||||
|
3. **形状条件**:区域内所有格子的**外接矩形的宽和高都大于 1**(避免把单行/单列走廊计入房间)。
|
||||||
|
|
||||||
|
> **为什么这样定义房间**
|
||||||
|
>
|
||||||
|
> 在拓扑图中,空白/资源节点天然是游戏空间的"腔体",而分支节点(门/怪物)是进入腔体的关卡节点。只要与至少一个分支节点相邻,就说明这片空间是需要"先过关才能进入/离开"的区域——包括怪物守着宝箱这类单入口房间,同样是典型的房间结构。面积和形状约束则过滤掉通道和死胡同。
|
||||||
|
|
||||||
|
### 等级划分
|
||||||
|
|
||||||
|
等级为 **三档**:Low(0)/ Medium(1)/ High(2),通过以下两趟扫描确定:
|
||||||
|
|
||||||
|
1. **第一趟**:遍历整个训练集,计算每张地图的房间数量 `roomCount`,收集为数组;
|
||||||
|
2. **第二趟**:对数组升序排序,取 1/3 和 2/3 分位数作为阈值 `[th1, th2]`:
|
||||||
|
- `roomCount < th1` → Low(0)
|
||||||
|
- `th1 ≤ roomCount < th2` → Medium(1)
|
||||||
|
- `roomCount ≥ th2` → High(2)
|
||||||
|
|
||||||
|
等级划分力求三档样本数量均等(等频分箱),而非等距分箱。由于计数是离散整数,存在大量并列值时各档样本数不会严格相等(例如当 `th1` 等于数据中的最小值时,Low 档为空)。
|
||||||
|
|
||||||
|
### 外接矩形计算
|
||||||
|
|
||||||
|
给定节点的所有格子坐标集合(`tiles`,存储 `y * width + x` 的平坦坐标),还原为 `(x, y)` 后:
|
||||||
|
|
||||||
|
$$
|
||||||
|
x_{\min} = \min_{t \in tiles}(t \bmod W), \quad x_{\max} = \max_{t \in tiles}(t \bmod W)
|
||||||
|
$$
|
||||||
|
|
||||||
|
$$
|
||||||
|
y_{\min} = \min_{t \in tiles}(\lfloor t / W \rfloor), \quad y_{\max} = \max_{t \in tiles}(\lfloor t / W \rfloor)
|
||||||
|
$$
|
||||||
|
|
||||||
|
外接矩形宽 = $x_{\max} - x_{\min} + 1$,高 = $y_{\max} - y_{\min} + 1$,两者均需 $> 1$。
|
||||||
|
|
||||||
|
### 字段扩展
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/** 房间数量(原始统计值,供两趟扫描使用) */
|
||||||
|
readonly roomCount: number;
|
||||||
|
/** 房间数量等级:0=Low, 1=Medium, 2=High(需两趟扫描后赋值) */
|
||||||
|
roomCountLevel: 0 | 1 | 2;
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 标签三:分支数量等级
|
||||||
|
|
||||||
|
### 定义
|
||||||
|
|
||||||
|
**高连接度分支节点**:拓扑图中 `type === Branch` 的节点,其**非墙邻居节点**总数(`neighbors.size`)**≥ 3**。由于当前拓扑图中已不含墙节点,`neighbors.size` 即等于非墙邻居数,两者在实现上等价。
|
||||||
|
|
||||||
|
这类节点是地图中的"交叉口"——一个门或怪物后方至少有三条不同路线,是地图分叉度和策略深度的指征。
|
||||||
|
|
||||||
|
等级划分方式与房间数量等级相同:先统计每张地图中高连接度分支节点的数量,再等频分箱为 Low(0)/ Medium(1)/ High(2)。
|
||||||
|
|
||||||
|
### 与房间数量的区别
|
||||||
|
|
||||||
|
| 维度 | 房间数量等级 | 分支数量等级 |
|
||||||
|
| -------- | -------------------------- | ---------------------------- |
|
||||||
|
| 度量对象 | 空白/资源节点区域的封闭性 | 分支节点的路径分叉度 |
|
||||||
|
| 反映特征 | 地图内封闭房间的数量与密度 | 关键路口/多分支门怪的复杂度 |
|
||||||
|
| 典型高值 | 多房间迷宫风格地图 | 高度分叉、策略选择丰富的地图 |
|
||||||
|
|
||||||
|
### 字段扩展
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/** 高连接度分支节点数量(原始统计值) */
|
||||||
|
readonly highDegBranchCount: number;
|
||||||
|
/** 分支数量等级:0=Low, 1=Medium, 2=High(需两趟扫描后赋值) */
|
||||||
|
branchLevel: 0 | 1 | 2;
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 标签四:外包围墙壁
|
||||||
|
|
||||||
|
### 定义
|
||||||
|
|
||||||
|
地图**最外圈**(最外一圈格子)的格子中,**墙壁格子与入口格子**之和占外圈总格子数的比例 > **90%**,则标记为 `outerWall = true`。
|
||||||
|
|
||||||
|
$$
|
||||||
|
\text{outerWall} = \frac{|\{(x,y) \in \text{border} : \text{isWall}(x,y) \lor \text{isEntry}(x,y)\}|}{|\text{border}|} > 0.9
|
||||||
|
$$
|
||||||
|
|
||||||
|
### 最外圈定义
|
||||||
|
|
||||||
|
对于 $H \times W$ 的地图,最外圈为所有满足下列条件之一的格子:
|
||||||
|
|
||||||
|
$$
|
||||||
|
x = 0 \; \lor \; x = W - 1 \; \lor \; y = 0 \; \lor \; y = H - 1
|
||||||
|
$$
|
||||||
|
|
||||||
|
最外圈格子总数为 $2(H + W) - 4$(对于 13×13,共 48 格)。
|
||||||
|
|
||||||
|
### 为什么入口也算"通过"
|
||||||
|
|
||||||
|
入口格子在游戏中是楼梯/传送点,不属于可通行的空地,在视觉和结构上等价于边界开口,属于外圈围合结构的合理组成部分,不应被视为"破坏围合"的元素。
|
||||||
|
|
||||||
|
### 实现要点
|
||||||
|
|
||||||
|
- 使用 `convertedMap` 判断墙壁(`tile === config.classes.wall`);
|
||||||
|
- 入口可通过 `originMap` + `converter.isEntry()` 判断,或直接在 `convertedMap` 上判断(`tile === config.classes.entry`);
|
||||||
|
- 格子为墙壁或入口(满足其一即可)即计入分子。
|
||||||
|
|
||||||
|
### 字段扩展
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/** 是否外包围墙壁 */
|
||||||
|
readonly outerWall: boolean;
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实现方案
|
||||||
|
|
||||||
|
### 单张地图可直接计算的标签
|
||||||
|
|
||||||
|
以下标签可在 `parseFloorInfo` 中直接计算,加入 `IFloorInfo`:
|
||||||
|
|
||||||
|
- `symmetryH`、`symmetryV`、`symmetryC`
|
||||||
|
- `outerWall`
|
||||||
|
- `roomCount`(原始值)
|
||||||
|
- `highDegBranchCount`(原始值)
|
||||||
|
|
||||||
|
### 需要两趟扫描的等级标签
|
||||||
|
|
||||||
|
`roomCountLevel` 和 `branchLevel` 依赖全局分位数,须在数据集构建完成后进行二次处理:
|
||||||
|
|
||||||
|
```
|
||||||
|
第一趟:构建所有楼层的 IFloorInfo,写入 roomCount / highDegBranchCount
|
||||||
|
第二趟:收集所有楼层的原始值 → 计算 1/3, 2/3 分位 → 回填等级
|
||||||
|
```
|
||||||
|
|
||||||
|
在 Python 训练侧,推荐方式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 Dataset.__init__ 中完成两趟计算
|
||||||
|
counts = [item['roomCount'] for item in raw_data]
|
||||||
|
counts_sorted = sorted(counts)
|
||||||
|
th1 = counts_sorted[len(counts_sorted) // 3]
|
||||||
|
th2 = counts_sorted[2 * len(counts_sorted) // 3]
|
||||||
|
|
||||||
|
for item in raw_data:
|
||||||
|
c = item['roomCount']
|
||||||
|
item['roomCountLevel'] = 0 if c < th1 else (1 if c < th2 else 2)
|
||||||
|
```
|
||||||
|
|
||||||
|
同理处理 `branchLevel`。
|
||||||
|
|
||||||
|
> **注意**:分位数阈值应仅基于**训练集**统计,验证集 / 测试集使用相同的阈值映射,避免数据泄露。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 训练集成
|
||||||
|
|
||||||
|
### 条件嵌入
|
||||||
|
|
||||||
|
将上述标签作为离散条件与 VQ-VAE 的 z 一同注入 MaskGIT:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 对称性:三个独立布尔值,可合并为 0~7 的整数 cond_sym
|
||||||
|
cond_sym = symmetryH * 4 + symmetryV * 2 + symmetryC * 1 # [0, 7]
|
||||||
|
|
||||||
|
# 房间等级:0 / 1 / 2
|
||||||
|
cond_room = roomCountLevel
|
||||||
|
|
||||||
|
# 分支等级:0 / 1 / 2
|
||||||
|
cond_branch = branchLevel
|
||||||
|
|
||||||
|
# 外包围墙壁:0 / 1
|
||||||
|
cond_outer = int(outerWall)
|
||||||
|
```
|
||||||
|
|
||||||
|
每个条件通过独立的 `nn.Embedding` 映射为固定维度向量,与 VQ-VAE 的 z 序列沿序列维度拼接后,一同经 Cross-Attention 注入 MaskGIT。
|
||||||
|
|
||||||
|
### 条件 Dropout
|
||||||
|
|
||||||
|
与 z dropout 类似,训练时以一定概率(如 10~20%)将部分或全部结构标签替换为"无条件"(null embedding),使模型在推理时支持条件缺省(CFG 风格)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 待细化事项
|
||||||
|
|
||||||
|
- [x] 90% 阈值是否合适?——**保持 90%**,如后续数据分布分析发现问题再调整。
|
||||||
|
- [x] 房间定义中分支节点邻居数量——**改为至少 1 个**,覆盖怪物守宝箱的单入口房间场景。
|
||||||
|
- [x] 分支等级邻居计数口径——**使用非墙邻居**(当前图结构中无墙节点,与 `neighbors.size` 等价)。
|
||||||
|
- [x] 是否新增通道数量/路径长度标签——**暂不考虑**。
|
||||||
|
- [x] 条件嵌入维度对齐——**各标签 Embedding 与 z 序列拼接后统一经 Cross-Attention 注入**。
|
||||||
563
docs/dataset-labels-impl.md
Normal file
563
docs/dataset-labels-impl.md
Normal file
@ -0,0 +1,563 @@
|
|||||||
|
# 数据集标签实现指南
|
||||||
|
|
||||||
|
## 总览
|
||||||
|
|
||||||
|
本文档描述将[标签设计文档](./dataset-labels-design.md)中定义的四类结构标签落地到代码中的具体实现步骤,涉及以下文件:
|
||||||
|
|
||||||
|
| 文件 | 变更性质 |
|
||||||
|
| ------------------------ | ------------------------------------------- |
|
||||||
|
| `data/src/auto/types.ts` | 扩展 `IFloorInfo` 接口 |
|
||||||
|
| `data/src/types.ts` | 扩展 `GinkaTrainData` 接口 |
|
||||||
|
| `data/src/auto/info.ts` | 新增四个 helper 函数,修改 `parseFloorInfo` |
|
||||||
|
| `data/src/auto.ts` | 在数据集序列化时写入新字段 |
|
||||||
|
| `ginka/dataset.py` | 两趟扫描 + `__getitem__` 读取新标签 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第一步:扩展 TypeScript 类型
|
||||||
|
|
||||||
|
### `data/src/auto/types.ts` — 扩展 `IFloorInfo`
|
||||||
|
|
||||||
|
在 `IFloorInfo` 接口的 `doorHeatmap` 字段之后追加以下字段:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ── 结构标签(新增)──────────────────────────────
|
||||||
|
/** 左右对称(基于 convertedMap 完全匹配) */
|
||||||
|
readonly symmetryH: boolean;
|
||||||
|
/** 上下对称 */
|
||||||
|
readonly symmetryV: boolean;
|
||||||
|
/** 中心对称 */
|
||||||
|
readonly symmetryC: boolean;
|
||||||
|
/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90%) */
|
||||||
|
readonly outerWall: boolean;
|
||||||
|
/** 房间数量原始值(供 Python 两趟扫描使用) */
|
||||||
|
readonly roomCount: number;
|
||||||
|
/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */
|
||||||
|
readonly highDegBranchCount: number;
|
||||||
|
```
|
||||||
|
|
||||||
|
### `data/src/types.ts` — 扩展 `GinkaTrainData`
|
||||||
|
|
||||||
|
在 `GinkaTrainData` 接口中追加序列化后写入 JSON 的字段:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export interface GinkaTrainData {
|
||||||
|
tag?: number[];
|
||||||
|
val: number[];
|
||||||
|
map: number[][];
|
||||||
|
size: [number, number];
|
||||||
|
heatmap?: number[][][];
|
||||||
|
// ── 结构标签(新增)──────────────────────────────
|
||||||
|
/** 对称性:[symmetryH, symmetryV, symmetryC],0 或 1 */
|
||||||
|
symmetry: [number, number, number];
|
||||||
|
/** 是否外包围墙壁,0 或 1 */
|
||||||
|
outerWall: number;
|
||||||
|
/** 房间数量原始值 */
|
||||||
|
roomCount: number;
|
||||||
|
/** 高连接度分支节点数量原始值 */
|
||||||
|
highDegBranchCount: number;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第二步:实现 Helper 函数(`data/src/auto/info.ts`)
|
||||||
|
|
||||||
|
以下四个函数添加到 `computeWallDensityStd` 之后、`parseFloorInfo` 之前。
|
||||||
|
|
||||||
|
### 2.1 对称性计算
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/**
|
||||||
|
* 计算地图的三种对称性(基于 convertedMap,完全匹配才标记为 true)
|
||||||
|
*/
|
||||||
|
function computeSymmetry(map: number[][]): {
|
||||||
|
symmetryH: boolean;
|
||||||
|
symmetryV: boolean;
|
||||||
|
symmetryC: boolean;
|
||||||
|
} {
|
||||||
|
const H = map.length;
|
||||||
|
const W = H > 0 ? map[0].length : 0;
|
||||||
|
let symmetryH = true;
|
||||||
|
let symmetryV = true;
|
||||||
|
let symmetryC = true;
|
||||||
|
|
||||||
|
outer: for (let y = 0; y < H; y++) {
|
||||||
|
for (let x = 0; x < W; x++) {
|
||||||
|
const tile = map[y][x];
|
||||||
|
if (symmetryH && tile !== map[y][W - 1 - x]) symmetryH = false;
|
||||||
|
if (symmetryV && tile !== map[H - 1 - y][x]) symmetryV = false;
|
||||||
|
if (symmetryC && tile !== map[H - 1 - y][W - 1 - x])
|
||||||
|
symmetryC = false;
|
||||||
|
// 三种对称均已排除,提前退出
|
||||||
|
if (!symmetryH && !symmetryV && !symmetryC) break outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { symmetryH, symmetryV, symmetryC };
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**复杂度**:O(H × W),最坏情况遍历全图一次,实际有短路优化。
|
||||||
|
|
||||||
|
**注意**:对于奇数尺寸地图(如 13×13),中心行/列的格子与自身比较恒为 true,不影响结果。
|
||||||
|
|
||||||
|
### 2.2 外包围墙壁检测
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/**
|
||||||
|
* 检测地图最外圈中墙壁 + 入口的占比是否超过 90%
|
||||||
|
* @param map convertedMap
|
||||||
|
* @param wall 墙壁图块编号
|
||||||
|
* @param entry 入口图块编号
|
||||||
|
*/
|
||||||
|
function computeOuterWall(
|
||||||
|
map: number[][],
|
||||||
|
wall: number,
|
||||||
|
entry: number
|
||||||
|
): boolean {
|
||||||
|
const H = map.length;
|
||||||
|
const W = H > 0 ? map[0].length : 0;
|
||||||
|
if (H < 2 || W < 2) return false;
|
||||||
|
|
||||||
|
let borderCount = 0;
|
||||||
|
let wallOrEntry = 0;
|
||||||
|
|
||||||
|
const check = (tile: number) => {
|
||||||
|
borderCount++;
|
||||||
|
if (tile === wall || tile === entry) wallOrEntry++;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 顶行 + 底行
|
||||||
|
for (let x = 0; x < W; x++) {
|
||||||
|
check(map[0][x]);
|
||||||
|
check(map[H - 1][x]);
|
||||||
|
}
|
||||||
|
// 左列 + 右列(排除角格,已由上面计入)
|
||||||
|
for (let y = 1; y < H - 1; y++) {
|
||||||
|
check(map[y][0]);
|
||||||
|
check(map[y][W - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 13×13 地图: borderCount = 2*(13+13)-4 = 48
|
||||||
|
return borderCount > 0 && wallOrEntry / borderCount > 0.9;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 房间数量统计
|
||||||
|
|
||||||
|
**核心思路**:拓扑图中 Empty 节点和 Resource 节点在物理上可能彼此相邻,共同构成一个连续的游戏空间(如 2×3 的房间内放了一个宝物)。不能逐节点独立判断,而应先通过 BFS 将相邻的 Empty/Resource 节点合并为一个"候选区域",再对整个区域判断三个条件。Branch 节点作为边界,不会被合并进区域。
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/**
|
||||||
|
* 统计拓扑图中符合"房间"定义的连通区域数量。
|
||||||
|
*
|
||||||
|
* 算法:
|
||||||
|
* 1. 以 Empty / Resource 节点为顶点,在它们之间 BFS,
|
||||||
|
* 得到若干"候选区域"(Branch 节点作为边界,不被合并)。
|
||||||
|
* 2. 对每个候选区域检查三个条件:
|
||||||
|
* a. 区域内至少一个节点有 Branch 类型邻居
|
||||||
|
* b. 区域内所有格子总数 >= 4
|
||||||
|
* c. 所有格子的外接矩形宽 > 1 且高 > 1
|
||||||
|
*
|
||||||
|
* @param graph 拓扑图
|
||||||
|
* @param width 地图宽度(用于平坦坐标解码)
|
||||||
|
*/
|
||||||
|
function computeRoomCount(graph: IMapGraph, width: number): number {
|
||||||
|
// Step 1: 收集所有 Empty 和 Resource 节点(去重)
|
||||||
|
const allEmptyResource = new Set<MapGraphNode>();
|
||||||
|
for (const node of graph.nodeMap.values()) {
|
||||||
|
if (
|
||||||
|
node.type === GraphNodeType.Empty ||
|
||||||
|
node.type === GraphNodeType.Resource
|
||||||
|
) {
|
||||||
|
allEmptyResource.add(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: BFS 将相邻的 Empty/Resource 节点合并为连通区域
|
||||||
|
let roomCount = 0;
|
||||||
|
const visited = new Set<MapGraphNode>();
|
||||||
|
|
||||||
|
for (const startNode of allEmptyResource) {
|
||||||
|
if (visited.has(startNode)) continue;
|
||||||
|
|
||||||
|
// BFS:只在 Empty/Resource 节点间传播,Branch 节点阻断
|
||||||
|
const regionNodes = new Set<MapGraphNode>();
|
||||||
|
const queue: MapGraphNode[] = [startNode];
|
||||||
|
visited.add(startNode);
|
||||||
|
|
||||||
|
while (queue.length > 0) {
|
||||||
|
const current = queue.shift()!;
|
||||||
|
regionNodes.add(current);
|
||||||
|
for (const nb of current.neighbors) {
|
||||||
|
if (
|
||||||
|
!visited.has(nb) &&
|
||||||
|
(nb.type === GraphNodeType.Empty ||
|
||||||
|
nb.type === GraphNodeType.Resource)
|
||||||
|
) {
|
||||||
|
visited.add(nb);
|
||||||
|
queue.push(nb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: 对合并后的区域检查三个条件
|
||||||
|
|
||||||
|
// 条件 a:区域内任一节点有 Branch 邻居
|
||||||
|
let hasBranch = false;
|
||||||
|
outer: for (const node of regionNodes) {
|
||||||
|
for (const nb of node.neighbors) {
|
||||||
|
if (nb.type === GraphNodeType.Branch) {
|
||||||
|
hasBranch = true;
|
||||||
|
break outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!hasBranch) continue;
|
||||||
|
|
||||||
|
// 收集区域内所有格子,计算总数和外接矩形
|
||||||
|
let totalTiles = 0;
|
||||||
|
let minX = Infinity,
|
||||||
|
maxX = -Infinity,
|
||||||
|
minY = Infinity,
|
||||||
|
maxY = -Infinity;
|
||||||
|
|
||||||
|
for (const node of regionNodes) {
|
||||||
|
totalTiles += node.tiles.size;
|
||||||
|
for (const t of node.tiles) {
|
||||||
|
const x = t % width;
|
||||||
|
const y = (t - x) / width;
|
||||||
|
if (x < minX) minX = x;
|
||||||
|
if (x > maxX) maxX = x;
|
||||||
|
if (y < minY) minY = y;
|
||||||
|
if (y > maxY) maxY = y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 条件 b:总格子数 >= 4
|
||||||
|
if (totalTiles < 4) continue;
|
||||||
|
|
||||||
|
// 条件 c:外接矩形宽高均 > 1(即 maxX - minX >= 1 && maxY - minY >= 1)
|
||||||
|
if (maxX - minX < 1 || maxY - minY < 1) continue;
|
||||||
|
|
||||||
|
roomCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return roomCount;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**边界情况讨论**:
|
||||||
|
|
||||||
|
- **混合节点房间**:2×3 房间(5 个空地 + 1 个资源)→ 1 个 Empty 节点 + 1 个 Resource 节点,BFS 合并后总格子数=6,外接矩形 2×3,**计入房间**。
|
||||||
|
- **纯条形走廊**:1×4 的 Empty 节点 → 外接矩形高=1,**不计入房间**。
|
||||||
|
- **孤立资源死角**:Resource 节点的邻居全是 Entry 而无 Branch → 条件 a 不满足,**不计入房间**。
|
||||||
|
- **两个相邻的 Empty 区域被 Branch 隔开**:BFS 不会跨越 Branch,各自单独判断。
|
||||||
|
|
||||||
|
### 2.4 高连接度分支节点统计
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/**
|
||||||
|
* 统计邻居数 >= 3 的分支节点数量
|
||||||
|
*
|
||||||
|
* 由于拓扑图中已不含墙节点,neighbors.size 即等于非墙邻居数。
|
||||||
|
*
|
||||||
|
* @param graph 拓扑图
|
||||||
|
*/
|
||||||
|
function computeHighDegBranchCount(graph: IMapGraph): number {
|
||||||
|
let count = 0;
|
||||||
|
const visited = new Set<MapGraphNode>();
|
||||||
|
|
||||||
|
for (const node of graph.nodeMap.values()) {
|
||||||
|
if (visited.has(node)) continue;
|
||||||
|
visited.add(node);
|
||||||
|
|
||||||
|
if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) {
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第三步:在 `parseFloorInfo` 中调用
|
||||||
|
|
||||||
|
在 `parseFloorInfo` 函数内,`topo` 构建完毕、`floorInfo` 对象构造之前,调用上述四个函数:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ── 结构标签计算 ─────────────────────────────────
|
||||||
|
const width = map[0]?.length ?? 0;
|
||||||
|
const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map);
|
||||||
|
const outerWall = computeOuterWall(
|
||||||
|
map,
|
||||||
|
config.classes.wall,
|
||||||
|
config.classes.entry
|
||||||
|
);
|
||||||
|
const roomCount = computeRoomCount(topo.graph, width);
|
||||||
|
const highDegBranchCount = computeHighDegBranchCount(topo.graph);
|
||||||
|
```
|
||||||
|
|
||||||
|
然后在 `floorInfo` 字面量中追加这些字段:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const floorInfo: IFloorInfo = {
|
||||||
|
// ...(原有字段保持不变)
|
||||||
|
symmetryH,
|
||||||
|
symmetryV,
|
||||||
|
symmetryC,
|
||||||
|
outerWall,
|
||||||
|
roomCount,
|
||||||
|
highDegBranchCount
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第四步:在 `data/src/auto.ts` 中序列化新字段
|
||||||
|
|
||||||
|
在构建 `GinkaTrainData` 的对象字面量中追加:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const data: GinkaTrainData = {
|
||||||
|
map: floor.data.map,
|
||||||
|
size: [width, height],
|
||||||
|
heatmap: [
|
||||||
|
/* ...原有热力图通道... */
|
||||||
|
],
|
||||||
|
val: [
|
||||||
|
/* ...原有标量... */
|
||||||
|
],
|
||||||
|
// ── 新增结构标签 ──────────────────────────────
|
||||||
|
symmetry: [
|
||||||
|
info.symmetryH ? 1 : 0,
|
||||||
|
info.symmetryV ? 1 : 0,
|
||||||
|
info.symmetryC ? 1 : 0
|
||||||
|
],
|
||||||
|
outerWall: info.outerWall ? 1 : 0,
|
||||||
|
roomCount: info.roomCount,
|
||||||
|
highDegBranchCount: info.highDegBranchCount
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
布尔值存为 `0/1` 整数,便于 Python 侧直接读取,不需要类型转换。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第五步:Python 侧两趟扫描(`ginka/dataset.py`)
|
||||||
|
|
||||||
|
### 5.1 等频分箱函数
|
||||||
|
|
||||||
|
在 `dataset.py` 顶部添加通用的等频分箱 helper:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def assign_level(values: list[int]) -> list[int]:
|
||||||
|
"""
|
||||||
|
将整数列表按等频分箱映射为 0/1/2 三档等级。
|
||||||
|
分位数阈值基于当前列表计算(训练集与验证集应共用训练集的阈值)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: 原始统计值列表,顺序与数据集一一对应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
与输入等长的等级列表,每项为 0 / 1 / 2
|
||||||
|
"""
|
||||||
|
n = len(values)
|
||||||
|
if n == 0:
|
||||||
|
return []
|
||||||
|
sorted_vals = sorted(values)
|
||||||
|
th1 = sorted_vals[n // 3]
|
||||||
|
th2 = sorted_vals[2 * n // 3]
|
||||||
|
return [
|
||||||
|
0 if v < th1 else (1 if v < th2 else 2)
|
||||||
|
for v in values
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 修改 `GinkaVQDataset.__init__`
|
||||||
|
|
||||||
|
在加载数据之后立即进行两趟扫描,并将等级结果回填到各 item 中:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class GinkaVQDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_path: str,
|
||||||
|
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||||||
|
wall_mask_min: float = 0.0,
|
||||||
|
wall_mask_max: float = 0.5,
|
||||||
|
# 以下两个参数在构建训练集时保持为 None(此时自动计算阈值);验证集/测试集应传入训练集的阈值
|
||||||
|
room_thresholds: tuple[int, int] | None = None,
|
||||||
|
branch_thresholds: tuple[int, int] | None = None,
|
||||||
|
):
|
||||||
|
self.data = load_data(data_path)
|
||||||
|
# ...(原有初始化逻辑)
|
||||||
|
|
||||||
|
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
|
||||||
|
room_counts = [item['roomCount'] for item in self.data]
|
||||||
|
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||||||
|
|
||||||
|
if room_thresholds is None:
|
||||||
|
# 训练集:自行计算阈值
|
||||||
|
n = len(room_counts)
|
||||||
|
rs = sorted(room_counts)
|
||||||
|
bs = sorted(branch_counts)
|
||||||
|
self.room_th = (rs[n // 3], rs[2 * n // 3])
|
||||||
|
self.branch_th = (bs[n // 3], bs[2 * n // 3])
|
||||||
|
else:
|
||||||
|
# 验证集/测试集:直接使用训练集的阈值,避免数据泄露
|
||||||
|
self.room_th = room_thresholds
|
||||||
|
self.branch_th = branch_thresholds
|
||||||
|
|
||||||
|
def to_level(v: int, th: tuple[int, int]) -> int:
|
||||||
|
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||||
|
|
||||||
|
# 回填等级字段
|
||||||
|
for item in self.data:
|
||||||
|
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||||||
|
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||||
|
```
|
||||||
|
|
||||||
|
**调用方式**(`train_vq.py` 中):
|
||||||
|
|
||||||
|
```python
|
||||||
|
dataset_train = GinkaVQDataset(args.train)
|
||||||
|
dataset_val = GinkaVQDataset(
|
||||||
|
args.validate,
|
||||||
|
room_thresholds=dataset_train.room_th,
|
||||||
|
branch_thresholds=dataset_train.branch_th
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 `__getitem__` 中读取结构标签
|
||||||
|
|
||||||
|
对称性标签在数据增强(旋转/翻转)后需要**重新从增强后的地图中计算**,因为 `rot90`(`k=1` 或 `k=3`)会交换 `symmetryH` 和 `symmetryV`。其他三个标签在旋转/翻转下保持不变,可直接读取。
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _compute_symmetry(target_np: np.ndarray) -> tuple[int, int, int]:
|
||||||
|
"""从 numpy 地图矩阵中直接计算三种对称性,O(H*W)"""
|
||||||
|
H, W = target_np.shape
|
||||||
|
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
|
||||||
|
sym_v = bool(np.all(target_np == target_np[::-1, :]))
|
||||||
|
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
|
||||||
|
return int(sym_h), int(sym_v), int(sym_c)
|
||||||
|
```
|
||||||
|
|
||||||
|
在 `__getitem__` 数据增强完成后,读取所有标签:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# ...(原有增强逻辑,target_np 已经过 rot90 / flip)
|
||||||
|
|
||||||
|
# 对称性:在增强后重新计算
|
||||||
|
sym_h, sym_v, sym_c = _compute_symmetry(target_np)
|
||||||
|
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
|
||||||
|
|
||||||
|
# 其余标签:增强不改变拓扑结构,直接读取
|
||||||
|
item = self.data[idx]
|
||||||
|
cond_room = item['roomCountLevel'] # 0/1/2
|
||||||
|
cond_branch = item['branchLevel'] # 0/1/2
|
||||||
|
cond_outer = item['outerWall'] # 0/1
|
||||||
|
|
||||||
|
# 封装为 tensor
|
||||||
|
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"raw_map": ...,
|
||||||
|
"masked_map": ...,
|
||||||
|
"target_map": ...,
|
||||||
|
"subset": ...,
|
||||||
|
"struct_cond": struct_cond # [4],供模型 Embedding 查表
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第六步:模型侧条件注入(`ginka/maskGIT/model.py`)
|
||||||
|
|
||||||
|
`struct_cond` 的四个维度分别对应不同词表大小的 Embedding:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 词表大小
|
||||||
|
SYM_VOCAB = 8 # cond_sym: [0, 7]
|
||||||
|
ROOM_VOCAB = 4 # cond_room: [0, 2] + 1 个 null(dropout 用)
|
||||||
|
BRANCH_VOCAB = 4 # cond_branch: [0, 2] + 1 个 null
|
||||||
|
OUTER_VOCAB = 3 # cond_outer: [0, 1] + 1 个 null
|
||||||
|
|
||||||
|
# 在 GinkaMaskGIT.__init__ 中
|
||||||
|
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||||
|
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||||
|
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||||
|
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||||
|
```
|
||||||
|
|
||||||
|
在 `forward` 中,将四个 Embedding 结果与 VQ-VAE 的 z 序列拼接,作为 Cross-Attention 的 memory:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def forward(self, map_tokens, z, struct_cond, dropout_struct=False):
|
||||||
|
# z: [B, L, d_z]
|
||||||
|
# struct_cond: [B, 4] — [sym, room, branch, outer]
|
||||||
|
|
||||||
|
B = z.size(0)
|
||||||
|
|
||||||
|
if dropout_struct or (self.training and torch.rand(1) < self.struct_dropout_prob):
|
||||||
|
# 条件 dropout:全部替换为 null index(各词表最后一个 index)
|
||||||
|
e_sym = self.sym_embed(torch.full((B,), SYM_VOCAB - 1, device=z.device))
|
||||||
|
e_room = self.room_embed(torch.full((B,), ROOM_VOCAB - 1, device=z.device))
|
||||||
|
e_branch = self.branch_embed(torch.full((B,), BRANCH_VOCAB - 1, device=z.device))
|
||||||
|
e_outer = self.outer_embed(torch.full((B,), OUTER_VOCAB - 1, device=z.device))
|
||||||
|
else:
|
||||||
|
e_sym = self.sym_embed(struct_cond[:, 0]) # [B, d_z]
|
||||||
|
e_room = self.room_embed(struct_cond[:, 1])
|
||||||
|
e_branch = self.branch_embed(struct_cond[:, 2])
|
||||||
|
e_outer = self.outer_embed(struct_cond[:, 3])
|
||||||
|
|
||||||
|
# 将四个结构标签嵌入拼接为序列,与 z 合并
|
||||||
|
# 每个 e_* 形状为 [B, d_z],沿 dim=1 stack 后得到 [B, 4, d_z]
|
||||||
|
struct_seq = torch.stack(
|
||||||
|
[e_sym, e_room, e_branch, e_outer], dim=1
|
||||||
|
) # [B, 4, d_z]
|
||||||
|
|
||||||
|
memory = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
|
||||||
|
|
||||||
|
# 后续 Cross-Attention 正常进行
|
||||||
|
# query = map_token_embeddings,key/value = memory
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
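**null index 规则**:各 Embedding 的最后一个 index 保留为"无条件"占位符,词表大小因此应比有效类别数多 1。注意:`cond_sym = sym_h * 4 + sym_v * 2 + sym_c` 的取值范围是 [0, 7],其中 7 表示三种对称性同时成立;按上文 `SYM_VOCAB = 8` 的设置,index 7 会同时充当有效类别和 null 占位符,若要严格满足本规则,需将 `SYM_VOCAB` 设为 9。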
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 边界情况与注意事项
|
||||||
|
|
||||||
|
### 关于两趟扫描的时机
|
||||||
|
|
||||||
|
两趟扫描在 `Dataset.__init__` 中完成,整个过程仅需遍历两次 Python 列表,耗时可忽略不计(数千条数据 < 1ms)。不建议延迟到 `__getitem__` 中逐条计算。
|
||||||
|
|
||||||
|
### 关于对称性在数据增强下的变化
|
||||||
|
|
||||||
|
| 增强操作 | `symmetryH` | `symmetryV` | `symmetryC` |
|
||||||
|
| ----------------- | ----------- | ----------- | ----------- |
|
||||||
|
| `fliplr` | 不变 | 不变 | 不变 |
|
||||||
|
| `flipud` | 不变 | 不变 | 不变 |
|
||||||
|
| `rot90(k=1 or 3)` | 与 V 交换 | 与 H 交换 | 不变 |
|
||||||
|
| `rot90(k=2)` | 不变 | 不变 | 不变 |
|
||||||
|
|
||||||
|
因此在 `__getitem__` 中对增强后的地图**重新计算**对称性是最简洁正确的方案,无需记录增强历史。
|
||||||
|
|
||||||
|
### 关于极端分布
|
||||||
|
|
||||||
|
若训练集中某一标签的样本极度不均匀(如 90% 的地图无对称性),可以在条件 Dropout 中对不常见的条件值适当降低 dropout 概率,以确保模型充分学习该条件。初始阶段统一使用相同 dropout 概率即可,后续根据生成效果调整。
|
||||||
|
|
||||||
|
### 关于 `roomCount = 0` 和 `highDegBranchCount = 0` 的地图
|
||||||
|
|
||||||
|
这类地图在等频分箱后通常会进入 Low(0)等级。但如果训练集中大量地图的值为 0,`th1` 可能也为 0;由于分级规则是 `v < th1` 才算 Low,此时 Low 等级为空,值为 0 的地图会落入 Medium(1)等级。可以在分箱前加一步检查:若 `th1 == th2`,则手动将 `th2 = th1 + 1` 以避免 Medium 等级为空。
|
||||||
|
|
||||||
|
```python
|
||||||
|
th1 = sorted_vals[n // 3]
|
||||||
|
th2 = sorted_vals[2 * n // 3]
|
||||||
|
if th1 == th2:
|
||||||
|
th2 = th1 + 1 # 防止 Medium 等级为空
|
||||||
|
```
|
||||||
@ -376,7 +376,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm
|
|||||||
|
|
||||||
## 待探索事项
|
## 待探索事项
|
||||||
|
|
||||||
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验)
|
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验),K=1, L=64 可能比较合适。
|
||||||
- z dropout 的最优概率
|
- z dropout 的最优概率
|
||||||
- 若后续 codebook collapse:引入 EMA 更新 + 重置机制
|
- 若后续 codebook collapse:引入 EMA 更新 + 重置机制
|
||||||
- 若后续需要更细粒度控制:加入标量 cond(需对推理侧标量做随机扰动处理)
|
- 若后续需要更细粒度控制:加入标量 cond(需对推理侧标量做随机扰动处理)
|
||||||
|
|||||||
@ -232,6 +232,14 @@ class GinkaJointDataset(Dataset):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_symmetry(target_np: np.ndarray) -> tuple:
|
||||||
|
"""从 numpy 地图矩阵中直接计算三种对称性,O(H*W)"""
|
||||||
|
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
|
||||||
|
sym_v = bool(np.all(target_np == target_np[::-1, :]))
|
||||||
|
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
|
||||||
|
return int(sym_h), int(sym_v), int(sym_c)
|
||||||
|
|
||||||
|
|
||||||
class GinkaVQDataset(Dataset):
|
class GinkaVQDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。
|
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。
|
||||||
@ -259,6 +267,8 @@ class GinkaVQDataset(Dataset):
|
|||||||
data_path: str,
|
data_path: str,
|
||||||
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||||||
wall_mask_ratio: float = 0.3,
|
wall_mask_ratio: float = 0.3,
|
||||||
|
room_thresholds: tuple = None,
|
||||||
|
branch_thresholds: tuple = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -266,6 +276,8 @@ class GinkaVQDataset(Dataset):
|
|||||||
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
||||||
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
|
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
|
||||||
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
|
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
|
||||||
|
room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||||||
|
branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||||||
"""
|
"""
|
||||||
self.data = load_data(data_path)
|
self.data = load_data(data_path)
|
||||||
self.wall_mask_ratio = wall_mask_ratio
|
self.wall_mask_ratio = wall_mask_ratio
|
||||||
@ -275,6 +287,35 @@ class GinkaVQDataset(Dataset):
|
|||||||
normalized = [x / total_w for x in subset_weights]
|
normalized = [x / total_w for x in subset_weights]
|
||||||
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
|
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
|
||||||
|
|
||||||
|
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
|
||||||
|
room_counts = [item['roomCount'] for item in self.data]
|
||||||
|
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||||||
|
|
||||||
|
if room_thresholds is None:
|
||||||
|
n = len(room_counts)
|
||||||
|
rs = sorted(room_counts)
|
||||||
|
bs = sorted(branch_counts)
|
||||||
|
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
|
||||||
|
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
|
||||||
|
# 防止 Medium 等级为空
|
||||||
|
if th1_r == th2_r:
|
||||||
|
th2_r = th1_r + 1
|
||||||
|
if th1_b == th2_b:
|
||||||
|
th2_b = th1_b + 1
|
||||||
|
self.room_th = (th1_r, th2_r)
|
||||||
|
self.branch_th = (th1_b, th2_b)
|
||||||
|
else:
|
||||||
|
self.room_th = room_thresholds
|
||||||
|
self.branch_th = branch_thresholds
|
||||||
|
|
||||||
|
def to_level(v: int, th: tuple) -> int:
|
||||||
|
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||||
|
|
||||||
|
# 回填等级字段
|
||||||
|
for item in self.data:
|
||||||
|
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||||||
|
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
@ -407,11 +448,23 @@ class GinkaVQDataset(Dataset):
|
|||||||
masked_np = self._apply_subset(raw_np, subset) # [H*W]
|
masked_np = self._apply_subset(raw_np, subset) # [H*W]
|
||||||
raw_flat = raw_np.reshape(-1) # [H*W]
|
raw_flat = raw_np.reshape(-1) # [H*W]
|
||||||
|
|
||||||
|
# 对称性:在增强后重新计算
|
||||||
|
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
|
||||||
|
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
|
||||||
|
|
||||||
|
# 其余结构标签:增强不改变拓扑结构,直接读取
|
||||||
|
cond_room = item['roomCountLevel'] # 0/1/2
|
||||||
|
cond_branch = item['branchLevel'] # 0/1/2
|
||||||
|
cond_outer = item['outerWall'] # 0/1
|
||||||
|
|
||||||
|
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
|
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
|
||||||
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
||||||
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
||||||
"subset": subset, # 调试/统计用
|
"subset": subset, # 调试/统计用
|
||||||
|
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,12 @@ import torch.nn as nn
|
|||||||
from ..utils import print_memory
|
from ..utils import print_memory
|
||||||
from .maskGIT import Transformer
|
from .maskGIT import Transformer
|
||||||
|
|
||||||
|
# 结构标签词表大小(最后一个索引为无条件占位符 null)
|
||||||
|
SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-6,7 = null
|
||||||
|
ROOM_VOCAB = 4 # roomCountLevel 0-2,3 = null
|
||||||
|
BRANCH_VOCAB = 4 # branchLevel 0-2,3 = null
|
||||||
|
OUTER_VOCAB = 3 # outerWall 0-1,2 = null
|
||||||
|
|
||||||
|
|
||||||
class GinkaMaskGIT(nn.Module):
|
class GinkaMaskGIT(nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -26,6 +32,7 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
num_layers: int = 4,
|
num_layers: int = 4,
|
||||||
map_size: int = 13 * 13,
|
map_size: int = 13 * 13,
|
||||||
z_dropout: float = 0.1,
|
z_dropout: float = 0.1,
|
||||||
|
struct_dropout: float = 0.15,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -37,9 +44,12 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
num_layers: Transformer 层数
|
num_layers: Transformer 层数
|
||||||
map_size: 地图 token 总数(H * W)
|
map_size: 地图 token 总数(H * W)
|
||||||
z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性)
|
z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性)
|
||||||
|
struct_dropout: 训练时以此概率将结构标签替换为 null(无条件占位),
|
||||||
|
实现 classifier-free guidance 兼容训练
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.z_dropout = z_dropout
|
self.z_dropout = z_dropout
|
||||||
|
self.struct_dropout_prob = struct_dropout
|
||||||
|
|
||||||
# Tile 嵌入 + 位置编码
|
# Tile 嵌入 + 位置编码
|
||||||
self.tile_embedding = nn.Embedding(num_classes, d_model)
|
self.tile_embedding = nn.Embedding(num_classes, d_model)
|
||||||
@ -51,6 +61,12 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
nn.LayerNorm(d_model),
|
nn.LayerNorm(d_model),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 结构标签嵌入(编码到 d_z 维度,与 z 拼接后统一投影到 d_model)
|
||||||
|
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||||
|
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||||
|
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||||
|
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||||
|
|
||||||
# Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention
|
# Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention
|
||||||
self.transformer = Transformer(
|
self.transformer = Transformer(
|
||||||
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
|
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
|
||||||
@ -58,30 +74,72 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
|
|
||||||
self.output_fc = nn.Linear(d_model, num_classes)
|
self.output_fc = nn.Linear(d_model, num_classes)
|
||||||
|
|
||||||
def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
map: torch.Tensor,
|
||||||
|
z: torch.Tensor,
|
||||||
|
struct_cond: torch.Tensor | None = None,
|
||||||
|
dropout_struct: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
||||||
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
|
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
|
||||||
|
struct_cond: [B, 4] 结构标签 LongTensor,顺序为
|
||||||
|
[cond_sym, cond_room, cond_branch, cond_outer];
|
||||||
|
为 None 时等价于全 null(无条件模式)
|
||||||
|
dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits: [B, H*W, num_classes]
|
logits: [B, H*W, num_classes]
|
||||||
"""
|
"""
|
||||||
|
B = z.shape[0]
|
||||||
|
|
||||||
# z dropout:训练时以一定概率将 z 替换为随机均匀噪声,
|
# z dropout:训练时以一定概率将 z 替换为随机均匀噪声,
|
||||||
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
|
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
|
||||||
if self.training and self.z_dropout > 0:
|
if self.training and self.z_dropout > 0:
|
||||||
mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout
|
mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout
|
||||||
rand_z = torch.randn_like(z)
|
rand_z = torch.randn_like(z)
|
||||||
z = torch.where(mask, rand_z, z)
|
z = torch.where(mask, rand_z, z)
|
||||||
|
|
||||||
# 投影 z 到 d_model 维度
|
# 结构标签嵌入
|
||||||
z_mem = self.z_proj(z) # [B, L, d_model]
|
# struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引
|
||||||
|
if struct_cond is None or dropout_struct:
|
||||||
|
sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||||
|
room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||||
|
branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||||
|
outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||||
|
else:
|
||||||
|
sc = struct_cond.to(z.device)
|
||||||
|
sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3]
|
||||||
|
|
||||||
|
# 训练时对各标签独立做 struct dropout
|
||||||
|
if self.training and self.struct_dropout_prob > 0:
|
||||||
|
def _drop(idx, null_val):
|
||||||
|
drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob
|
||||||
|
return torch.where(drop_mask, torch.full_like(idx, null_val), idx)
|
||||||
|
sym_idx = _drop(sym_idx, SYM_VOCAB - 1)
|
||||||
|
room_idx = _drop(room_idx, ROOM_VOCAB - 1)
|
||||||
|
branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1)
|
||||||
|
outer_idx = _drop(outer_idx, OUTER_VOCAB - 1)
|
||||||
|
|
||||||
|
# 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾
|
||||||
|
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
|
||||||
|
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
|
||||||
|
e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z]
|
||||||
|
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
|
||||||
|
|
||||||
|
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
|
||||||
|
z_ext = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
|
||||||
|
|
||||||
|
# 统一投影到 d_model 维度
|
||||||
|
z_mem = self.z_proj(z_ext) # [B, L+4, d_model]
|
||||||
|
|
||||||
# tile embedding + 位置编码
|
# tile embedding + 位置编码
|
||||||
x = self.tile_embedding(map) # [B, H*W, d_model]
|
x = self.tile_embedding(map) # [B, H*W, d_model]
|
||||||
x = x + self.pos_embedding # [B, H*W, d_model]
|
x = x + self.pos_embedding # [B, H*W, d_model]
|
||||||
|
|
||||||
# Transformer:encoder 做 map 自注意力,decoder cross-attend z
|
# Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct
|
||||||
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
|
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
|
||||||
|
|
||||||
logits = self.output_fc(x) # [B, H*W, num_classes]
|
logits = self.output_fc(x) # [B, H*W, num_classes]
|
||||||
@ -109,9 +167,17 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
|
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
|
||||||
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
|
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
|
||||||
|
struct_input = torch.tensor([[3, 1, 0, 1],
|
||||||
|
[0, 2, 1, 0],
|
||||||
|
[7, 3, 3, 2],
|
||||||
|
[1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4]
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
logits = model(map_input, z_input)
|
logits = model(map_input, z_input, struct_cond=struct_input)
|
||||||
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
|
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
|
||||||
|
|
||||||
|
# 无条件模式测试
|
||||||
|
logits_uncond = model(map_input, z_input, struct_cond=None)
|
||||||
|
print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16]
|
||||||
|
|
||||||
print_memory(device, "前向传播后")
|
print_memory(device, "前向传播后")
|
||||||
|
|||||||
@ -59,7 +59,8 @@ MG_D_MODEL = 192
|
|||||||
MG_NHEAD = 8
|
MG_NHEAD = 8
|
||||||
MG_LAYERS = 4
|
MG_LAYERS = 4
|
||||||
MG_DIM_FF = 512
|
MG_DIM_FF = 512
|
||||||
MG_Z_DROPOUT= 0.15 # 训练时以此概率把 z 替换为随机噪声
|
MG_Z_DROPOUT = 0.15 # 训练时以此概率把 z 替换为随机噪声
|
||||||
|
MG_STRUCT_DROPOUT= 0.15 # 训练时以此概率将结构标签替换为 null(无条件占位)
|
||||||
|
|
||||||
# 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z)
|
# 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z)
|
||||||
N_Z_SAMPLES = 3
|
N_Z_SAMPLES = 3
|
||||||
@ -106,6 +107,7 @@ def maskgit_generate(
|
|||||||
z: torch.Tensor,
|
z: torch.Tensor,
|
||||||
steps: int = GENERATE_STEP,
|
steps: int = GENERATE_STEP,
|
||||||
init_map: torch.Tensor = None,
|
init_map: torch.Tensor = None,
|
||||||
|
struct_cond: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
迭代生成地图(cosine schedule unmasking)。
|
迭代生成地图(cosine schedule unmasking)。
|
||||||
@ -133,7 +135,7 @@ def maskgit_generate(
|
|||||||
if not generatable.any():
|
if not generatable.any():
|
||||||
break
|
break
|
||||||
|
|
||||||
logits = model_mg(map_seq, z) # [B, S, C]
|
logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C]
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits, dim=-1)
|
||||||
dist = torch.distributions.Categorical(probs)
|
dist = torch.distributions.Categorical(probs)
|
||||||
sampled = dist.sample() # [B, S]
|
sampled = dist.sample() # [B, S]
|
||||||
@ -279,7 +281,8 @@ def validate(
|
|||||||
B = raw_map.shape[0]
|
B = raw_map.shape[0]
|
||||||
|
|
||||||
z_q, _, vq_loss, _, _ = model_vq(raw_map)
|
z_q, _, vq_loss, _, _ = model_vq(raw_map)
|
||||||
logits = model_mg(masked_map, z_q)
|
struct_cond_b = batch["struct_cond"].to(device) # [B, 4]
|
||||||
|
logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
|
||||||
mask = (masked_map == MASK_TOKEN)
|
mask = (masked_map == MASK_TOKEN)
|
||||||
|
|
||||||
ce_loss = F.cross_entropy(
|
ce_loss = F.cross_entropy(
|
||||||
@ -297,6 +300,7 @@ def validate(
|
|||||||
"raw": raw_map[i:i+1].clone(),
|
"raw": raw_map[i:i+1].clone(),
|
||||||
"masked": masked_map[i:i+1].clone(),
|
"masked": masked_map[i:i+1].clone(),
|
||||||
"z_q": z_q[i:i+1].clone(),
|
"z_q": z_q[i:i+1].clone(),
|
||||||
|
"struct_cond": struct_cond_b[i:i+1].clone(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if all(v is not None for v in captured.values()):
|
if all(v is not None for v in captured.values()):
|
||||||
@ -307,24 +311,24 @@ def validate(
|
|||||||
imgs = []
|
imgs = []
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
z_r = model_vq.sample(1, device)
|
z_r = model_vq.sample(1, device)
|
||||||
gen = maskgit_generate(model_mg, z_r, init_map=cond_map)
|
gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None 无条件
|
||||||
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
|
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
|
||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
# ── 场景1:标准掩码补全(子集 A)─────────────────────────────────────────
|
# ── 场景1:标准掩码补全(子集 A)─────────────────────────────────────────
|
||||||
if captured['A'] is not None:
|
if captured['A'] is not None:
|
||||||
cap = captured['A']
|
cap = captured['A']
|
||||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||||
|
|
||||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input")
|
cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input")
|
||||||
|
|
||||||
# 单步 argmax 预测(观察模型对掩码位置的瞬时判断)
|
# 单步 argmax 预测(观察模型对掩码位置的瞬时判断)
|
||||||
pred = model_mg(cond, z_q).argmax(dim=-1)[0]
|
pred = model_mg(cond, z_q, struct_cond=sc).argmax(dim=-1)[0]
|
||||||
pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred")
|
pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred")
|
||||||
|
|
||||||
# 迭代生成(从掩码输入出发,真实 z)
|
# 迭代生成(从掩码输入出发,真实 z)
|
||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
@ -333,11 +337,11 @@ def validate(
|
|||||||
# ── 场景2:墙壁辅助生成(子集 B)─────────────────────────────────────────
|
# ── 场景2:墙壁辅助生成(子集 B)─────────────────────────────────────────
|
||||||
if captured['B'] is not None:
|
if captured['B'] is not None:
|
||||||
cap = captured['B']
|
cap = captured['B']
|
||||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||||
|
|
||||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input")
|
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input")
|
||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
@ -346,11 +350,11 @@ def validate(
|
|||||||
# ── 场景3:稀疏墙壁条件生成(子集 C)────────────────────────────────────
|
# ── 场景3:稀疏墙壁条件生成(子集 C)────────────────────────────────────
|
||||||
if captured['C'] is not None:
|
if captured['C'] is not None:
|
||||||
cap = captured['C']
|
cap = captured['C']
|
||||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||||
|
|
||||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input")
|
cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input")
|
||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
@ -359,11 +363,11 @@ def validate(
|
|||||||
# ── 场景4:墙壁+入口条件生成(子集 D)───────────────────────────────────
|
# ── 场景4:墙壁+入口条件生成(子集 D)───────────────────────────────────
|
||||||
if captured['D'] is not None:
|
if captured['D'] is not None:
|
||||||
cap = captured['D']
|
cap = captured['D']
|
||||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||||
|
|
||||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input")
|
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input")
|
||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
@ -403,6 +407,7 @@ def train():
|
|||||||
num_layers=MG_LAYERS,
|
num_layers=MG_LAYERS,
|
||||||
map_size=MAP_SIZE,
|
map_size=MAP_SIZE,
|
||||||
z_dropout=MG_Z_DROPOUT,
|
z_dropout=MG_Z_DROPOUT,
|
||||||
|
struct_dropout=MG_STRUCT_DROPOUT,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
vq_params = sum(p.numel() for p in model_vq.parameters())
|
vq_params = sum(p.numel() for p in model_vq.parameters())
|
||||||
@ -419,6 +424,8 @@ def train():
|
|||||||
dataset_val = GinkaVQDataset(
|
dataset_val = GinkaVQDataset(
|
||||||
args.validate,
|
args.validate,
|
||||||
subset_weights=SUBSET_WEIGHTS,
|
subset_weights=SUBSET_WEIGHTS,
|
||||||
|
room_thresholds=dataset_train.room_th,
|
||||||
|
branch_thresholds=dataset_train.branch_th,
|
||||||
)
|
)
|
||||||
dataloader_train = DataLoader(
|
dataloader_train = DataLoader(
|
||||||
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
|
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
|
||||||
@ -481,8 +488,9 @@ def train():
|
|||||||
# 1. VQ-VAE 编码真实地图 → z_q
|
# 1. VQ-VAE 编码真实地图 → z_q
|
||||||
z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
|
z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
|
||||||
|
|
||||||
# 2. MaskGIT 以掩码地图 + z 预测原始 tile
|
# 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile
|
||||||
logits = model_mg(masked_map, z_q) # [B, 169, C]
|
struct_cond = batch["struct_cond"].to(device) # [B, 4]
|
||||||
|
logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C]
|
||||||
|
|
||||||
# 3. 只对被 mask 的位置计算 CE loss
|
# 3. 只对被 mask 的位置计算 CE loss
|
||||||
mask = (masked_map == MASK_TOKEN) # [B, 169] bool
|
mask = (masked_map == MASK_TOKEN) # [B, 169] bool
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user